Repository: systemml Updated Branches: refs/heads/master 7dbbaaa76 -> fe5ed5947
[SYSTEMML-2069] Lock-free maintenance of heavy hitter statistics This patch addresses issues of thread contention on the maintenance of heavy hitter statistics. In multi-threaded applications with many small operations, the synchronized maintenance of these heavy hitters can turn into a bottleneck. Example scenarios are JMLC scoring with multiple JMLC streams per node and local parfor scripts. We simply replace the synchronized maintenance with the use of a concurrent hash map as well as consolidated statistics objects for times and counts, which rely on the scalable LongAdder primitive. On an end-to-end JMLC scoring scenario, this patch improved the total runtime from 261s to 117s (with statistics enabled), which is good compared to the baseline of 101s (without statistics) because the time measurement itself is also expensive. Similarly, end-to-end algorithms such as decision tree and random forest will see similar benefits for the case of enabled statistics. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3737542d Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3737542d Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3737542d Branch: refs/heads/master Commit: 3737542deccf83adafc28edb56ceb3f0a136b370 Parents: 7dbbaaa Author: Matthias Boehm <[email protected]> Authored: Thu Jan 11 18:12:38 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Thu Jan 11 18:12:38 2018 -0800 ---------------------------------------------------------------------- .../java/org/apache/sysml/utils/Statistics.java | 69 +++++++++++--------- 1 file changed, 37 insertions(+), 32 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/3737542d/src/main/java/org/apache/sysml/utils/Statistics.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java index 762c167..414f0e7 100644 --- a/src/main/java/org/apache/sysml/utils/Statistics.java +++ b/src/main/java/org/apache/sysml/utils/Statistics.java @@ -25,10 +25,10 @@ import java.lang.management.ManagementFactory; import java.text.DecimalFormat; import java.util.Arrays; import java.util.Comparator; -import java.util.HashMap; import java.util.List; import java.util.Map.Entry; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; import org.apache.sysml.api.DMLScript; @@ -47,13 +47,20 @@ import org.apache.sysml.runtime.matrix.data.LibMatrixDNN; * This class captures all statistics. */ public class Statistics -{ +{ + private static class InstStats { + private final LongAdder time = new LongAdder(); + private final LongAdder count = new LongAdder(); + } + private static long compileStartTime = 0; private static long compileEndTime = 0; - private static long execStartTime = 0; private static long execEndTime = 0; - + + //heavy hitter counts and times + private static final ConcurrentHashMap<String,InstStats>_instStats = new ConcurrentHashMap<>(); + // number of compiled/executed MR jobs private static final LongAdder numExecutedMRJobs = new LongAdder(); private static final LongAdder numCompiledMRJobs = new LongAdder(); @@ -103,10 +110,6 @@ public class Statistics private static long parforOptCount = 0; //count private static long parforInitTime = 0; //in milli sec private static long parforMergeTime = 0; //in milli sec - - //heavy hitter counts and times - private static HashMap<String,Long> _cpInstTime = new HashMap<>(); - private static HashMap<String,Long> _cpInstCounts = new HashMap<>(); private static final LongAdder lTotalUIPVar = new LongAdder(); private static final LongAdder lTotalLix = new LongAdder(); @@ -463,8 +466,7 @@ public class Statistics } public static void resetCPHeavyHitters(){ 
- _cpInstTime.clear(); - _cpInstCounts.clear(); + _instStats.clear(); } public static void setSparkCtxCreateTime(long ns) { @@ -527,26 +529,29 @@ public class Statistics /** * "Maintains" or adds time to per instruction/op timers, also increments associated count - * @param instructionName name of the instruction/op + * @param instName name of the instruction/op * @param timeNanos time in nano seconds */ - public synchronized static void maintainCPHeavyHitters( String instructionName, long timeNanos ) - { - Long oldVal = _cpInstTime.getOrDefault(instructionName, 0L); - _cpInstTime.put(instructionName, oldVal + timeNanos); - - Long oldCnt = _cpInstCounts.getOrDefault(instructionName, 0L); - _cpInstCounts.put(instructionName, oldCnt + 1); + public static void maintainCPHeavyHitters( String instName, long timeNanos ) { + //maintain instruction entry + InstStats tmp = _instStats.get(instName); + if( tmp == null ) { + tmp = new InstStats(); + _instStats.put(instName, tmp); + } + + //thread-local maintenance of instruction stats + tmp.time.add(timeNanos); + tmp.count.increment(); } - public static Set<String> getCPHeavyHitterOpCodes() { - return _cpInstTime.keySet(); + return _instStats.keySet(); } public static long getCPHeavyHitterCount(String opcode) { - Long tmp = _cpInstCounts.get(opcode); - return (tmp != null) ? tmp : 0; + InstStats tmp = _instStats.get(opcode); + return (tmp != null) ? 
tmp.count.longValue() : 0; } /** @@ -560,15 +565,15 @@ public class Statistics * format */ public static String getHeavyHitters(int num) { - int len = _cpInstTime.size(); + int len = _instStats.size(); if (num <= 0 || len <= 0) return "-"; // get top k via sort - Entry<String, Long>[] tmp = _cpInstTime.entrySet().toArray(new Entry[len]); - Arrays.sort(tmp, new Comparator<Entry<String, Long>>() { - public int compare(Entry<String, Long> e1, Entry<String, Long> e2) { - return e1.getValue().compareTo(e2.getValue()); + Entry<String, InstStats>[] tmp = _instStats.entrySet().toArray(new Entry[len]); + Arrays.sort(tmp, new Comparator<Entry<String, InstStats>>() { + public int compare(Entry<String, InstStats> e1, Entry<String, InstStats> e2) { + return Long.compare(e1.getValue().time.longValue(), e2.getValue().time.longValue()); } }); @@ -585,9 +590,9 @@ public class Statistics int maxCountLen = countCol.length(); DecimalFormat sFormat = new DecimalFormat("#,##0.000"); for (int i = 0; i < numHittersToDisplay; i++) { - Entry<String, Long> hh = tmp[len - 1 - i]; + Entry<String, InstStats> hh = tmp[len - 1 - i]; String instruction = hh.getKey(); - Long timeNs = hh.getValue(); + long timeNs = hh.getValue().time.longValue(); double timeS = (double) timeNs / 1000000000.0; maxInstLen = Math.max(maxInstLen, instruction.length()); @@ -595,7 +600,7 @@ public class Statistics String timeSString = sFormat.format(timeS); maxTimeSLen = Math.max(maxTimeSLen, timeSString.length()); - maxCountLen = Math.max(maxCountLen, String.valueOf(_cpInstCounts.get(instruction)).length()); + maxCountLen = Math.max(maxCountLen, String.valueOf(hh.getValue().count.longValue()).length()); } maxInstLen = Math.min(maxInstLen, DMLScript.STATISTICS_MAX_WRAP_LEN); sb.append(String.format( @@ -610,11 +615,11 @@ public class Statistics String instruction = tmp[len - 1 - i].getKey(); String [] wrappedInstruction = wrap(instruction, maxInstLen); - Long timeNs = tmp[len - 1 - i].getValue(); + long timeNs = tmp[len - 
1 - i].getValue().time.longValue(); double timeS = (double) timeNs / 1000000000.0; String timeSString = sFormat.format(timeS); - Long count = _cpInstCounts.get(instruction); + long count = tmp[len - 1 - i].getValue().count.longValue(); int numLines = wrappedInstruction.length; String [] miscTimers = null;
