This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 582b9c3  [SYSTEMDS-3039] Tracking and consolidation of federated statistics
582b9c3 is described below

commit 582b9c3f622d87cd9d11b9dd01abcb0a6f179309
Author: ywcb00 <[email protected]>
AuthorDate: Sun Jul 4 19:52:56 2021 +0200

    [SYSTEMDS-3039] Tracking and consolidation of federated statistics
    
    Closes #1321.
---
 src/main/java/org/apache/sysds/api/DMLOptions.java |  19 ++
 src/main/java/org/apache/sysds/api/DMLScript.java  |  72 ++---
 .../federated/FederatedStatistics.java             | 311 +++++++++++++++++++++
 .../instructions/fed/InitFEDInstruction.java       |   5 +
 .../java/org/apache/sysds/utils/Statistics.java    |  22 +-
 .../primitives/FederatedStatisticsTest.java        |   5 +-
 6 files changed, 395 insertions(+), 39 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java 
b/src/main/java/org/apache/sysds/api/DMLOptions.java
index 0949c06..fbdaa90 100644
--- a/src/main/java/org/apache/sysds/api/DMLOptions.java
+++ b/src/main/java/org/apache/sysds/api/DMLOptions.java
@@ -52,6 +52,8 @@ public class DMLOptions {
        public boolean              clean         = false;            // 
Whether to clean up all SystemDS working directories (FS, DFS)
        public boolean              stats         = false;            // 
Whether to record and print the statistics
        public int                  statsCount    = 10;               // 
Default statistics count
+       public boolean              fedStats      = false;            // 
Whether to record and print the federated statistics
+       public int                  fedStatsCount = 10;               // 
Default federated statistics count
        public boolean              memStats      = false;            // max 
memory statistics
        public Explain.ExplainType  explainType   = Explain.ExplainType.NONE;  
// Whether to print the "Explain" and if so, what type
        public ExecMode             execMode      = 
OptimizerUtils.getDefaultExecutionMode();  // Execution mode standalone, MR, 
Spark or a hybrid
@@ -85,6 +87,8 @@ public class DMLOptions {
                        ", clean=" + clean +
                        ", stats=" + stats +
                        ", statsCount=" + statsCount +
+                       ", fedStats=" + fedStats +
+                       ", fedStatsCount=" + fedStatsCount +
                        ", memStats=" + memStats +
                        ", explainType=" + explainType +
                        ", execMode=" + execMode +
@@ -193,6 +197,17 @@ public class DMLOptions {
                                }
                        }
                }
+               dmlOptions.fedStats = line.hasOption("fedStats");
+               if (dmlOptions.fedStats) {
+                       String fedStatsCount = line.getOptionValue("fedStats");
+                       if(fedStatsCount != null) {
+                               try {
+                                       dmlOptions.fedStatsCount = 
Integer.parseInt(fedStatsCount);
+                               } catch (NumberFormatException e) {
+                                       throw new 
org.apache.commons.cli.ParseException("Invalid argument specified for -fedStats 
option, must be a valid integer");
+                               }
+                       }
+               }
                dmlOptions.memStats = line.hasOption("mem");
 
                dmlOptions.clean = line.hasOption("clean");
@@ -265,6 +280,9 @@ public class DMLOptions {
                Option statsOpt = OptionBuilder.withArgName("count")
                        .withDescription("monitors and reports summary 
execution statistics; heavy hitter <count> is 10 unless overridden; default 
off")
                        .hasOptionalArg().create("stats");
+               Option fedStatsOpt = OptionBuilder.withArgName("count")
+                       .withDescription("monitors and reports summary 
execution statistics of federated workers; heavy hitter <count> is 10 unless 
overridden; default off")
+                       .hasOptionalArg().create("fedStats");
                Option memOpt = OptionBuilder.withDescription("monitors and 
reports max memory consumption in CP; default off")
                        .create("mem");
                Option explainOpt = OptionBuilder.withArgName("level")
@@ -299,6 +317,7 @@ public class DMLOptions {
                options.addOption(configOpt);
                options.addOption(cleanOpt);
                options.addOption(statsOpt);
+               options.addOption(fedStatsOpt);
                options.addOption(memOpt);
                options.addOption(explainOpt);
                options.addOption(execOpt);
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java 
b/src/main/java/org/apache/sysds/api/DMLScript.java
index 7d2bf16..e2e67a5 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -82,26 +82,28 @@ import org.apache.sysds.utils.Explain.ExplainType;
 
 public class DMLScript 
 {
-       private static ExecMode   EXEC_MODE          = 
DMLOptions.defaultOptions.execMode;     // the execution mode
-       public static boolean     STATISTICS          = 
DMLOptions.defaultOptions.stats;       // whether to print statistics
-       public static boolean     JMLC_MEM_STATISTICS = false;                  
               // whether to gather memory use stats in JMLC
-       public static int         STATISTICS_COUNT    = 
DMLOptions.defaultOptions.statsCount;  // statistics maximum heavy hitter count
-       public static int         STATISTICS_MAX_WRAP_LEN = 30;                 
               // statistics maximum wrap length
-       public static ExplainType EXPLAIN             = 
DMLOptions.defaultOptions.explainType; // explain type
-       public static String      DML_FILE_PATH_ANTLR_PARSER = 
DMLOptions.defaultOptions.filePath; // filename of dml/pydml script
-       public static String      FLOATING_POINT_PRECISION = "double";          
               // data type to use internally
-       public static boolean     PRINT_GPU_MEMORY_INFO = false;                
               // whether to print GPU memory-related information
-       public static long        EVICTION_SHADOW_BUFFER_MAX_BYTES = 0;         
               // maximum number of bytes to use for shadow buffer
-       public static long        EVICTION_SHADOW_BUFFER_CURR_BYTES = 0;        
               // number of bytes to use for shadow buffer
-       public static double      GPU_MEMORY_UTILIZATION_FACTOR = 0.9;          
               // fraction of available GPU memory to use
-       public static String      GPU_MEMORY_ALLOCATOR = "cuda";                
               // GPU memory allocator to use
-       public static boolean     LINEAGE = DMLOptions.defaultOptions.lineage;  
               // whether compute lineage trace
-       public static boolean     LINEAGE_DEDUP = 
DMLOptions.defaultOptions.lineage_dedup;     // whether deduplicate lineage 
items
-       public static ReuseCacheType LINEAGE_REUSE = 
DMLOptions.defaultOptions.linReuseType;   // whether lineage-based reuse
-       public static LineageCachePolicy LINEAGE_POLICY = 
DMLOptions.defaultOptions.linCachePolicy; // lineage cache eviction policy
-       public static boolean     LINEAGE_ESTIMATE = 
DMLOptions.defaultOptions.lineage_estimate; // whether estimate reuse benefits
-       public static boolean     LINEAGE_DEBUGGER = 
DMLOptions.defaultOptions.lineage_debugger; // whether enable lineage debugger
-       public static boolean     CHECK_PRIVACY = 
DMLOptions.defaultOptions.checkPrivacy;      // Check which privacy constraints 
are loaded and checked during federated execution
+       private static ExecMode   EXEC_MODE          = 
DMLOptions.defaultOptions.execMode;           // the execution mode
+       public static boolean     STATISTICS          = 
DMLOptions.defaultOptions.stats;             // whether to print statistics
+       public static boolean     JMLC_MEM_STATISTICS = false;                  
                     // whether to gather memory use stats in JMLC
+       public static int         STATISTICS_COUNT    = 
DMLOptions.defaultOptions.statsCount;        // statistics maximum heavy hitter 
count
+       public static int         STATISTICS_MAX_WRAP_LEN = 30;                 
                     // statistics maximum wrap length
+       public static boolean     FED_STATISTICS        = 
DMLOptions.defaultOptions.fedStats;        // whether to print federated 
statistics
+       public static int         FED_STATISTICS_COUNT  = 
DMLOptions.defaultOptions.fedStatsCount;   // federated statistics maximum 
heavy hitter count
+       public static ExplainType EXPLAIN             = 
DMLOptions.defaultOptions.explainType;       // explain type
+       public static String      DML_FILE_PATH_ANTLR_PARSER = 
DMLOptions.defaultOptions.filePath;   // filename of dml/pydml script
+       public static String      FLOATING_POINT_PRECISION = "double";          
                     // data type to use internally
+       public static boolean     PRINT_GPU_MEMORY_INFO = false;                
                     // whether to print GPU memory-related information
+       public static long        EVICTION_SHADOW_BUFFER_MAX_BYTES = 0;         
                     // maximum number of bytes to use for shadow buffer
+       public static long        EVICTION_SHADOW_BUFFER_CURR_BYTES = 0;        
                     // number of bytes to use for shadow buffer
+       public static double      GPU_MEMORY_UTILIZATION_FACTOR = 0.9;          
                     // fraction of available GPU memory to use
+       public static String      GPU_MEMORY_ALLOCATOR = "cuda";                
                     // GPU memory allocator to use
+       public static boolean     LINEAGE = DMLOptions.defaultOptions.lineage;  
                     // whether compute lineage trace
+       public static boolean     LINEAGE_DEDUP = 
DMLOptions.defaultOptions.lineage_dedup;           // whether deduplicate 
lineage items
+       public static ReuseCacheType LINEAGE_REUSE = 
DMLOptions.defaultOptions.linReuseType;         // whether lineage-based reuse
+       public static LineageCachePolicy LINEAGE_POLICY = 
DMLOptions.defaultOptions.linCachePolicy;  // lineage cache eviction policy
+       public static boolean     LINEAGE_ESTIMATE = 
DMLOptions.defaultOptions.lineage_estimate;     // whether estimate reuse 
benefits
+       public static boolean     LINEAGE_DEBUGGER = 
DMLOptions.defaultOptions.lineage_debugger;     // whether enable lineage 
debugger
+       public static boolean     CHECK_PRIVACY = 
DMLOptions.defaultOptions.checkPrivacy;            // Check which privacy 
constraints are loaded and checked during federated execution
 
        public static boolean           USE_ACCELERATOR     = 
DMLOptions.defaultOptions.gpu;
        public static boolean           FORCE_ACCELERATOR   = 
DMLOptions.defaultOptions.forceGPU;
@@ -212,20 +214,22 @@ public class DMLScript
 
                try
                {
-                       STATISTICS          = dmlOptions.stats;
-                       STATISTICS_COUNT    = dmlOptions.statsCount;
-                       JMLC_MEM_STATISTICS = dmlOptions.memStats;
-                       USE_ACCELERATOR     = dmlOptions.gpu;
-                       FORCE_ACCELERATOR   = dmlOptions.forceGPU;
-                       EXPLAIN             = dmlOptions.explainType;
-                       EXEC_MODE           = dmlOptions.execMode;
-                       LINEAGE             = dmlOptions.lineage;
-                       LINEAGE_DEDUP       = dmlOptions.lineage_dedup;
-                       LINEAGE_REUSE       = dmlOptions.linReuseType;
-                       LINEAGE_POLICY      = dmlOptions.linCachePolicy;
-                       LINEAGE_ESTIMATE    = dmlOptions.lineage_estimate;
-                       CHECK_PRIVACY       = dmlOptions.checkPrivacy;
-                       LINEAGE_DEBUGGER        = dmlOptions.lineage_debugger;
+                       STATISTICS            = dmlOptions.stats;
+                       STATISTICS_COUNT      = dmlOptions.statsCount;
+                       FED_STATISTICS        = dmlOptions.fedStats;
+                       FED_STATISTICS_COUNT  = dmlOptions.fedStatsCount;
+                       JMLC_MEM_STATISTICS   = dmlOptions.memStats;
+                       USE_ACCELERATOR       = dmlOptions.gpu;
+                       FORCE_ACCELERATOR     = dmlOptions.forceGPU;
+                       EXPLAIN               = dmlOptions.explainType;
+                       EXEC_MODE             = dmlOptions.execMode;
+                       LINEAGE               = dmlOptions.lineage;
+                       LINEAGE_DEDUP         = dmlOptions.lineage_dedup;
+                       LINEAGE_REUSE         = dmlOptions.linReuseType;
+                       LINEAGE_POLICY        = dmlOptions.linCachePolicy;
+                       LINEAGE_ESTIMATE      = dmlOptions.lineage_estimate;
+                       CHECK_PRIVACY         = dmlOptions.checkPrivacy;
+                       LINEAGE_DEBUGGER      = dmlOptions.lineage_debugger;
 
                        String fnameOptConfig = dmlOptions.configFile;
                        boolean isFile = dmlOptions.filePath != null;
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
new file mode 100644
index 0000000..14f29d9
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated;
+
+import java.io.Serializable;
+import java.net.InetSocketAddress;
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.concurrent.Future;
+import javax.net.ssl.SSLException;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import 
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.CacheStatsCollection;
+import 
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.GCStatsCollection;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.utils.Statistics;
+
+public class FederatedStatistics {
+       private static Set<Pair<String, Integer>> _fedWorkerAddresses = new 
HashSet<>();
+
+       public static void registerFedWorker(String host, int port) { // remembers a worker address for later stats collection
+               _fedWorkerAddresses.add(new ImmutablePair<>(host, Integer.valueOf(port))); // valueOf replaces the deprecated Integer(int) constructor
+       }
+
+       public static String displayFedWorkers() {
+               StringBuilder sb = new StringBuilder();
+               sb.append("Federated Worker Addresses:\n");
+               for(Pair<String, Integer> fedAddr : _fedWorkerAddresses) {
+                       sb.append(String.format("  %s:%d", fedAddr.getLeft(), 
fedAddr.getRight().intValue()));
+                       sb.append("\n");
+               }
+               return sb.toString();
+       }
+
+       public static String displayFedStatistics(int numHeavyHitters) {
+               StringBuilder sb = new StringBuilder();
+               FedStatsCollection fedStats = collectFedStats();
+               sb.append("SystemDS Federated Statistics:\n");
+               sb.append(displayCacheStats(fedStats.cacheStats));
+               sb.append(String.format("Total JIT compile time:\t\t%.3f 
sec.\n", fedStats.jitCompileTime));
+               sb.append(displayGCStats(fedStats.gcStats));
+               sb.append(displayHeavyHitters(fedStats.heavyHitters, 
numHeavyHitters));
+               return sb.toString();
+       }
+
+       public static String displayCacheStats(CacheStatsCollection csc) {
+               StringBuilder sb = new StringBuilder();
+               sb.append(String.format("Cache hits 
(Mem/Li/WB/FS/HDFS):\t%d/%d/%d/%d/%d.\n",
+                       csc.memHits, csc.linHits, csc.fsBuffHits, csc.fsHits, 
csc.hdfsHits));
+               sb.append(String.format("Cache writes 
(Li/WB/FS/HDFS):\t%d/%d/%d/%d.\n",
+                       csc.linWrites, csc.fsBuffWrites, csc.fsWrites, 
csc.hdfsWrites));
+               sb.append(String.format("Cache times (ACQr/m, RLS, 
EXP):\t%.3f/%.3f/%.3f/%.3f sec.\n",
+                       csc.acqRTime, csc.acqMTime, csc.rlsTime, csc.expTime));
+               return sb.toString();
+       }
+
+       public static String displayGCStats(GCStatsCollection gcsc) {
+               StringBuilder sb = new StringBuilder();
+               sb.append(String.format("Total JVM GC count:\t\t%d.\n", 
gcsc.gcCount));
+               sb.append(String.format("Total JVM GC time:\t\t%.3f sec.\n", 
gcsc.gcTime));
+               return sb.toString();
+       }
+
+       public static String displayHeavyHitters(HashMap<String, Pair<Long, 
Double>> heavyHitters) {
+               return displayHeavyHitters(heavyHitters, 10);
+       }
+
+       public static String displayHeavyHitters(HashMap<String, Pair<Long, 
Double>> heavyHitters, int num) {
+               StringBuilder sb = new StringBuilder();
+               @SuppressWarnings("unchecked")
+               Entry<String, Pair<Long, Double>>[] hhArr = 
heavyHitters.entrySet().toArray(new Entry[0]);
+               Arrays.sort(hhArr, new Comparator<Entry<String, Pair<Long, 
Double>>>() {
+                       public int compare(Entry<String, Pair<Long, Double>> 
e1, Entry<String, Pair<Long, Double>> e2) {
+                               return 
e1.getValue().getRight().compareTo(e2.getValue().getRight());
+                       }
+               });
+               
+               sb.append("Heavy hitter instructions:\n");
+               final String numCol = "#";
+               final String instCol = "Instruction";
+               final String timeSCol = "Time(s)";
+               final String countCol = "Count";
+               int numHittersToDisplay = Math.min(num, hhArr.length);
+               int maxNumLen = String.valueOf(numHittersToDisplay).length();
+               int maxInstLen = instCol.length();
+               int maxTimeSLen = timeSCol.length();
+               int maxCountLen = countCol.length();
+               DecimalFormat sFormat = new DecimalFormat("#,##0.000");
+               for (int counter = 0; counter < numHittersToDisplay; counter++) 
{
+                       Entry<String, Pair<Long, Double>> hh = 
hhArr[hhArr.length - 1 - counter];
+                       String instruction = hh.getKey();
+                       maxInstLen = Math.max(maxInstLen, instruction.length());
+                       String timeString = 
sFormat.format(hh.getValue().getRight());
+                       maxTimeSLen = Math.max(maxTimeSLen, 
timeString.length());
+                       maxCountLen = Math.max(maxCountLen, 
String.valueOf(hh.getValue().getLeft()).length());
+               }
+               maxInstLen = Math.min(maxInstLen, 
DMLScript.STATISTICS_MAX_WRAP_LEN);
+               sb.append(String.format( " %" + maxNumLen + "s  %-" + 
maxInstLen + "s  %"
+                       + maxTimeSLen + "s  %" + maxCountLen + "s", numCol, 
instCol, timeSCol, countCol));
+               sb.append("\n");
+
+               for (int counter = 0; counter < numHittersToDisplay; counter++) 
{
+                       String instruction = hhArr[hhArr.length - 1 - 
counter].getKey();
+                       String [] wrappedInstruction = 
Statistics.wrap(instruction, maxInstLen);
+
+                       String timeSString = sFormat.format(hhArr[hhArr.length 
- 1 - counter].getValue().getRight());
+
+                       long count = hhArr[hhArr.length - 1 - 
counter].getValue().getLeft();
+                       int numLines = wrappedInstruction.length;
+                       
+                       for(int wrapIter = 0; wrapIter < numLines; wrapIter++) {
+                               String instStr = (wrapIter < 
wrappedInstruction.length) ? wrappedInstruction[wrapIter] : "";
+                               if(wrapIter == 0) {
+                                       sb.append(String.format(
+                                               " %" + maxNumLen + "d  %-" + 
maxInstLen + "s  %" + maxTimeSLen + "s  %" 
+                                               + maxCountLen + "d", (counter + 
1), instStr, timeSString, count));
+                               }
+                               else {
+                                       sb.append(String.format(
+                                               " %" + maxNumLen + "s  %-" + 
maxInstLen + "s  %" + maxTimeSLen + "s  %" 
+                                               + maxCountLen + "s", "", 
instStr, "", ""));
+                               }
+                               sb.append("\n");
+                       }
+               }
+
+               return sb.toString();
+       }
+
+       private static FedStatsCollection collectFedStats() { // gathers the stats of all registered workers and sums them up
+               Future<FederatedResponse>[] responses = getFederatedResponses();
+               FedStatsCollection aggFedStats = new FedStatsCollection();
+               for(Future<FederatedResponse> res : responses) {
+                       try {
+                               Object[] tmp = res.get().getData();
+                               if(tmp[0] instanceof FedStatsCollection) // payloads of any other type are ignored
+                                       aggFedStats.aggregate((FedStatsCollection)tmp[0]);
+                       } catch(Exception e) {
+                               throw new DMLRuntimeException("Exception of type " + e.getClass().getName() // getName() avoids the "class " prefix of Class.toString()
+                                       + " thrown while getting the federated stats of the federated response: ", e);
+                       }
+               }
+               return aggFedStats;
+       }
+
+       private static Future<FederatedResponse>[] getFederatedResponses() { // sends one stats-collection UDF request to every registered worker
+               List<Future<FederatedResponse>> ret = new ArrayList<>();
+               for(Pair<String, Integer> fedAddr : _fedWorkerAddresses) {
+                       InetSocketAddress isa = new InetSocketAddress(fedAddr.getLeft(), fedAddr.getRight());
+                       FederatedRequest frUDF = new FederatedRequest(RequestType.EXEC_UDF, -1,
+                               new FedStatsCollectFunction());
+                       try {
+                               ret.add(FederatedData.executeFederatedOperation(isa, frUDF));
+                       } catch(SSLException ssle) {
+                               throw new DMLRuntimeException("SSLException while getting the federated stats from "
+                                       + isa.toString() + ": ", ssle);
+                       } catch (Exception e) {
+                               throw new DMLRuntimeException("Exception of type " + e.getClass().getName()
+                                       + " thrown while getting stats from federated worker: ", e);
+                       }
+               }
+               @SuppressWarnings("unchecked")
+               Future<FederatedResponse>[] retArr = ret.toArray(new Future[0]); // generic array creation is not allowed, hence the raw Future[]
+               return retArr;
+       }
+
+       private static class FedStatsCollectFunction extends FederatedUDF {
+               private static final long serialVersionUID = 1L;
+
+               public FedStatsCollectFunction() {
+                       super(new long[] { });
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       FedStatsCollection fedStats = new FedStatsCollection();
+                       fedStats.collectStats();
+                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, fedStats);
+               }
+
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
+       }
+
+       protected static class FedStatsCollection implements Serializable {
+               private static final long serialVersionUID = 1L;
+
+               private void collectStats() {
+                       cacheStats.collectStats();
+                       jitCompileTime = 
((double)Statistics.getJITCompileTime()) / 1000; // in sec
+                       gcStats.collectStats();
+                       heavyHitters = Statistics.getHeavyHittersHashMap();
+               }
+               
+               private void aggregate(FedStatsCollection that) {
+                       cacheStats.aggregate(that.cacheStats);
+                       jitCompileTime += that.jitCompileTime;
+                       gcStats.aggregate(that.gcStats);
+                       that.heavyHitters.forEach(
+                               (key, value) -> heavyHitters.merge(key, value, 
(v1, v2) ->
+                                       new ImmutablePair<>(v1.getLeft() + 
v2.getLeft(), v1.getRight() + v2.getRight()))
+                       );
+               }
+
+               protected static class CacheStatsCollection implements 
Serializable {
+                       private static final long serialVersionUID = 1L;
+
+                       private void collectStats() {
+                               memHits = CacheStatistics.getMemHits();
+                               linHits = CacheStatistics.getLinHits();
+                               fsBuffHits = CacheStatistics.getFSBuffHits();
+                               fsHits = CacheStatistics.getFSHits();
+                               hdfsHits = CacheStatistics.getHDFSHits();
+                               linWrites = CacheStatistics.getLinWrites();
+                               fsBuffWrites = 
CacheStatistics.getFSBuffWrites();
+                               fsWrites = CacheStatistics.getFSWrites();
+                               hdfsWrites = CacheStatistics.getHDFSWrites();
+                               acqRTime = 
((double)CacheStatistics.getAcquireRTime()) / 1000000000; // in sec
+                               acqMTime = 
((double)CacheStatistics.getAcquireMTime()) / 1000000000; // in sec
+                               rlsTime = 
((double)CacheStatistics.getReleaseTime()) / 1000000000; // in sec
+                               expTime = 
((double)CacheStatistics.getExportTime()) / 1000000000; // in sec
+                       }
+
+                       private void aggregate(CacheStatsCollection that) {
+                               memHits += that.memHits;
+                               linHits += that.linHits;
+                               fsBuffHits += that.fsBuffHits;
+                               fsHits += that.fsHits;
+                               hdfsHits += that.hdfsHits;
+                               linWrites += that.linWrites;
+                               fsBuffWrites += that.fsBuffWrites;
+                               fsWrites += that.fsWrites;
+                               hdfsWrites += that.hdfsWrites;
+                               acqRTime += that.acqRTime;
+                               acqMTime += that.acqMTime;
+                               rlsTime += that.rlsTime;
+                               expTime += that.expTime;
+                       }
+
+                       private long memHits = 0;
+                       private long linHits = 0;
+                       private long fsBuffHits = 0;
+                       private long fsHits = 0;
+                       private long hdfsHits = 0;
+                       private long linWrites = 0;
+                       private long fsBuffWrites = 0;
+                       private long fsWrites = 0;
+                       private long hdfsWrites = 0;
+                       private double acqRTime = 0;
+                       private double acqMTime = 0;
+                       private double rlsTime = 0;
+                       private double expTime = 0;
+               }
+
+               protected static class GCStatsCollection implements 
Serializable {
+                       private static final long serialVersionUID = 1L;
+
+                       private void collectStats() {
+                               gcCount = Statistics.getJVMgcCount();
+                               gcTime = ((double)Statistics.getJVMgcTime()) / 
1000; // in sec
+                       }
+
+                       private void aggregate(GCStatsCollection that) {
+                               gcCount += that.gcCount;
+                               gcTime += that.gcTime;
+                       }
+
+                       private long gcCount = 0;
+                       private double gcTime = 0;
+               }
+
+               private CacheStatsCollection cacheStats = new 
CacheStatsCollection();
+               private double jitCompileTime = 0;
+               private GCStatsCollection gcStats = new GCStatsCollection();
+               private HashMap<String, Pair<Long, Double>> heavyHitters = new 
HashMap<>();
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 9b6d3f0..b4d2e04 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -47,6 +47,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -120,6 +121,10 @@ public class InitFEDInstruction extends FEDInstruction 
implements LineageTraceab
                                String host = parsedValues[0];
                                int port = Integer.parseInt(parsedValues[1]);
                                String filePath = parsedValues[2];
+
+                               // register the federated worker for federated 
statistics creation
+                               FederatedStatistics.registerFedWorker(host, 
port);
+
                                // get beginning and end of data ranges
                                List<Data> rangesData = ranges.getData();
                                Data beginData = rangesData.get(i * 2);
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java 
b/src/main/java/org/apache/sysds/utils/Statistics.java
index d4247a7..dd8ddce 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -25,6 +25,7 @@ import java.lang.management.ManagementFactory;
 import java.text.DecimalFormat;
 import java.util.Arrays;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map.Entry;
 import java.util.Set;
@@ -32,12 +33,15 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.DoubleAdder;
 import java.util.concurrent.atomic.LongAdder;
 
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -732,6 +736,17 @@ public class Statistics
                return (tmp != null) ? tmp.count.longValue() : 0;
        }
 
+       /**
+        * Creates a snapshot of the recorded instruction statistics.
+        *
+        * @return a new map from opcode to a pair of (instruction count, time in sec)
+        */
+       public static HashMap<String, Pair<Long, Double>> getHeavyHittersHashMap() {
+               HashMap<String, Pair<Long, Double>> heavyHitters = new HashMap<>();
+               // Iterate over the entries instead of keySet() + get(): this saves a
+               // second lookup per opcode and cannot observe a null value if an entry
+               // is removed between the key iteration and the lookup.
+               for(Entry<String, InstStats> entry : _instStats.entrySet()) {
+                       InstStats val = entry.getValue();
+                       long count = val.count.longValue();
+                       double time = val.time.longValue() / 1000000000d; // in sec
+                       // autoboxing replaces the deprecated new Long(..)/new Double(..)
+                       heavyHitters.put(entry.getKey(), new ImmutablePair<Long, Double>(count, time));
+               }
+               return heavyHitters;
+       }
+
        /**
         * Obtain a string tabular representation of the heavy hitter 
instructions
         * that displays the time, instruction count, and optionally GPU stats 
about
@@ -956,7 +971,7 @@ public class Statistics
        }
        
        
-       private static String [] wrap(String str, int wrapLength) {
+       public static String [] wrap(String str, int wrapLength) {
                int numLines = (int) Math.ceil( ((double)str.length()) / 
wrapLength);
                int len = str.length();
                String [] ret = new String[numLines];
@@ -1105,6 +1120,11 @@ public class Statistics
                if (DMLScript.CHECK_PRIVACY)
                        sb.append(CheckedConstraintsLog.display());
 
+               if(DMLScript.FED_STATISTICS) {
+                       sb.append("\n");
+                       
sb.append(FederatedStatistics.displayFedStatistics(DMLScript.FED_STATISTICS_COUNT));
+               }
+
                return sb.toString();
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
index 09ca19e..54d89e6 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
@@ -30,14 +30,12 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Assert;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
-@Ignore
 public class FederatedStatisticsTest extends AutomatedTestBase {
 
        private final static String TEST_DIR = "functions/federated/";
@@ -105,7 +103,6 @@ public class FederatedStatisticsTest extends 
AutomatedTestBase {
                TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                loadTestConfiguration(config);
                
-
                // Run reference dml script with normal matrix
                fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
                programArgs = new String[] {"-args", input("X1"), input("X2"), 
input("Y"), expected("Z")};
@@ -113,7 +110,7 @@ public class FederatedStatisticsTest extends 
AutomatedTestBase {
 
                // Run actual dml script with federated matrix
                fullDMLScriptName = HOME + TEST_NAME + ".dml";
-               programArgs = new String[] {"-stats", "30", "-nvargs",
+               programArgs = new String[] {"-stats", "30", "-fedStats", 
"-nvargs",
                        "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
                        "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")), "rows=" + rows, "cols=" + cols,
                        "in_Y=" + input("Y"), "out=" + output("Z")};

Reply via email to