This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 9f96718939 [SYSTEMDS-3714] N-gram statistics of operation sequences
9f96718939 is described below
commit 9f96718939bd43fedcf942ac71f6fd70bdcf48d6
Author: Jaybit0 <[email protected]>
AuthorDate: Wed Jul 24 18:31:43 2024 +0200
[SYSTEMDS-3714] N-gram statistics of operation sequences
Closes #2045.
---
src/main/java/org/apache/sysds/api/DMLOptions.java | 27 +++
src/main/java/org/apache/sysds/api/DMLScript.java | 9 +
.../sysds/runtime/controlprogram/ProgramBlock.java | 6 +-
.../java/org/apache/sysds/utils/Statistics.java | 231 ++++++++++++++++++-
.../org/apache/sysds/utils/stats/NGramBuilder.java | 248 +++++++++++++++++++++
.../test/applications/ApplyTransformTest.java | 9 +
.../apache/sysds/test/applications/L2SVMTest.java | 3 +
7 files changed, 529 insertions(+), 4 deletions(-)
diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java
index 70af5ba9e8..acacc39572 100644
--- a/src/main/java/org/apache/sysds/api/DMLOptions.java
+++ b/src/main/java/org/apache/sysds/api/DMLOptions.java
@@ -53,7 +53,10 @@ public class DMLOptions {
public String configFile = null; // Path
to config file if default config and default config is to be overridden
public boolean clean = false; //
Whether to clean up all SystemDS working directories (FS, DFS)
public boolean stats = false; //
Whether to record and print the statistics
+ public boolean statsNGrams = false; // Whether
to record and print the statistics n-grams
public int statsCount = 10; //
Default statistics count
+ public int[] statsNGramSizes = { 3 }; //
Default n-gram tuple sizes
+ public int statsTopKNGrams = 10; // How
many of the most heavy hitting n-grams are displayed
public boolean fedStats = false; //
Whether to record and print the federated statistics
public int fedStatsCount = 10; //
Default federated statistics count
public boolean memStats = false; // max
memory statistics
@@ -212,6 +215,26 @@ public class DMLOptions {
}
}
}
+
+		dmlOptions.statsNGrams = line.hasOption("ngrams");
+		if (dmlOptions.statsNGrams) {
+			// Both arguments of -ngrams are optional. commons-cli returns null from
+			// getOptionValues when the flag is given without any value, so guard
+			// against it; the default sizes and top-k then stay in effect.
+			String[] nGramArgs = line.getOptionValues("ngrams");
+			if (nGramArgs != null && nGramArgs.length == 2) {
+				try {
+					String[] nGramSizeSplit = nGramArgs[0].split(",");
+					// Parse into a local array first so that the option is only
+					// overwritten once every size has been validated.
+					int[] nGramSizes = new int[nGramSizeSplit.length];
+
+					for (int i = 0; i < nGramSizeSplit.length; i++) {
+						nGramSizes[i] = Integer.parseInt(nGramSizeSplit[i]);
+						// An n-gram needs at least one element; reject smaller sizes
+						// here instead of failing later while recording instructions.
+						if (nGramSizes[i] < 1)
+							throw new org.apache.commons.cli.ParseException("Invalid argument specified for -ngrams option, n-gram sizes must be positive");
+					}
+
+					dmlOptions.statsNGramSizes = nGramSizes;
+					dmlOptions.statsTopKNGrams = Integer.parseInt(nGramArgs[1]);
+				} catch (NumberFormatException e) {
+					throw new org.apache.commons.cli.ParseException("Invalid argument specified for -ngrams option, must be a valid integer");
+				}
+			}
+		}
+
dmlOptions.fedStats = line.hasOption("fedStats");
if (dmlOptions.fedStats) {
String fedStatsCount = line.getOptionValue("fedStats");
@@ -335,6 +358,9 @@ public class DMLOptions {
Option statsOpt = OptionBuilder.withArgName("count")
.withDescription("monitors and reports summary
execution statistics; heavy hitter <count> is 10 unless overridden; default
off")
.hasOptionalArg().create("stats");
+ Option ngramsOpt = OptionBuilder//.withArgName("ngrams")
+ .withDescription("monitors and reports the most
occurring n-grams; -ngrams <comma separated n's> <topK>")
+ .hasOptionalArgs(2).create("ngrams");
Option fedStatsOpt = OptionBuilder.withArgName("count")
.withDescription("monitors and reports summary
execution statistics of federated workers; heavy hitter <count> is 10 unless
overridden; default off")
.hasOptionalArg().create("fedStats");
@@ -396,6 +422,7 @@ public class DMLOptions {
options.addOption(configOpt);
options.addOption(cleanOpt);
options.addOption(statsOpt);
+ options.addOption(ngramsOpt);
options.addOption(fedStatsOpt);
options.addOption(memOpt);
options.addOption(explainOpt);
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java
index 3443f68740..cd86426a42 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -94,10 +94,16 @@ public class DMLScript
private static ExecMode EXEC_MODE =
DMLOptions.defaultOptions.execMode;
// Enable/disable to print statistics
public static boolean STATISTICS =
DMLOptions.defaultOptions.stats;
+ // Enable/disable to print statistics n-grams
+ public static boolean STATISTICS_NGRAMS =
DMLOptions.defaultOptions.statsNGrams;
// Enable/disable to gather memory use stats in JMLC
public static boolean JMLC_MEM_STATISTICS = false;
// Set maximum heavy hitter count
public static int STATISTICS_COUNT =
DMLOptions.defaultOptions.statsCount;
+ // The sizes of recorded n-gram tuples
+ public static int[] STATISTICS_NGRAM_SIZES =
DMLOptions.defaultOptions.statsNGramSizes;
+ // Set top k displayed n-grams limit
+ public static int STATISTICS_TOP_K_NGRAMS =
DMLOptions.defaultOptions.statsTopKNGrams;
// Set statistics maximum wrap length
public static int STATISTICS_MAX_WRAP_LEN = 30;
// Enable/disable to print federated statistics
@@ -250,6 +256,9 @@ public class DMLScript
{
STATISTICS = dmlOptions.stats;
STATISTICS_COUNT = dmlOptions.statsCount;
+ STATISTICS_NGRAMS = dmlOptions.statsNGrams;
+ STATISTICS_NGRAM_SIZES = dmlOptions.statsNGramSizes;
+ STATISTICS_TOP_K_NGRAMS = dmlOptions.statsTopKNGrams;
FED_STATISTICS = dmlOptions.fedStats;
FED_STATISTICS_COUNT = dmlOptions.fedStatsCount;
JMLC_MEM_STATISTICS = dmlOptions.memStats;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
index 34b954148b..4e75d5456f 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
@@ -241,7 +241,7 @@ public abstract class ProgramBlock implements ParseInfo {
private void executeSingleInstruction(Instruction currInst,
ExecutionContext ec) {
try {
// start time measurement for statistics
- long t0 = (DMLScript.STATISTICS ||
LOG.isTraceEnabled()) ? System.nanoTime() : 0;
+ long t0 = (DMLScript.STATISTICS ||
DMLScript.STATISTICS_NGRAMS || LOG.isTraceEnabled()) ? System.nanoTime() : 0;
// pre-process instruction (inst patching, listeners,
lineage)
Instruction tmp = currInst.preprocessInstruction(ec);
@@ -263,6 +263,10 @@ public abstract class ProgramBlock implements ParseInfo {
if(DMLScript.STATISTICS) {
Statistics.maintainCPHeavyHitters(tmp.getExtendedOpcode(), System.nanoTime() -
t0);
}
+
+ if (DMLScript.STATISTICS_NGRAMS) {
+
Statistics.maintainNGrams(tmp.getExtendedOpcode(), System.nanoTime() - t0);
+ }
}
// optional trace information (instruction and runtime)
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index 5cab7dbd30..a7c764cf78 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -34,10 +34,11 @@ import
org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
import org.apache.sysds.utils.stats.CodegenStatistics;
-import org.apache.sysds.utils.stats.RecompileStatistics;
+import org.apache.sysds.utils.stats.NGramBuilder;
import org.apache.sysds.utils.stats.NativeStatistics;
-import org.apache.sysds.utils.stats.ParamServStatistics;
import org.apache.sysds.utils.stats.ParForStatistics;
+import org.apache.sysds.utils.stats.ParamServStatistics;
+import org.apache.sysds.utils.stats.RecompileStatistics;
import org.apache.sysds.utils.stats.SparkStatistics;
import org.apache.sysds.utils.stats.TransformStatistics;
@@ -45,10 +46,13 @@ import java.lang.management.CompilationMXBean;
import java.lang.management.GarbageCollectorMXBean;
import java.lang.management.ManagementFactory;
import java.text.DecimalFormat;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
+import java.util.Locale;
+import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
@@ -64,6 +68,46 @@ public class Statistics
private final LongAdder time = new LongAdder();
private final LongAdder count = new LongAdder();
}
+
+	/**
+	 * Immutable timing statistics of one n-gram position (or of a whole n-gram):
+	 * sample count, cumulative time and the sum of squared deviations needed to
+	 * derive the variance. Instances are combined with {@link #merge}.
+	 */
+	public static class NGramStats {
+
+		/** Number of aggregated time samples. */
+		public final long n;
+		/** Sum of all aggregated times in nanoseconds. */
+		public final long cumTimeNanos;
+		/** Sum of squared deviations from the mean, computed on times in seconds. */
+		public final double m2;
+
+		/**
+		 * @param <T> element type of the n-gram entries to compare
+		 * @return comparator ordering n-gram entries by ascending cumulative time
+		 */
+		public static <T> Comparator<NGramBuilder.NGramEntry<T, NGramStats>> getComparator() {
+			return Comparator.comparingLong(entry -> entry.getCumStats().cumTimeNanos);
+		}
+
+		/**
+		 * Combines two statistics into one covering the samples of both.
+		 *
+		 * @param stats1 first statistics
+		 * @param stats2 second statistics
+		 * @return the combined statistics; one of the inputs if the other is empty
+		 */
+		public static NGramStats merge(NGramStats stats1, NGramStats stats2) {
+			// An empty side contributes nothing; returning the other side also
+			// avoids computing its mean as 0/0 = NaN, which would poison m2.
+			if (stats1.n == 0)
+				return stats2;
+			if (stats2.n == 0)
+				return stats1;
+
+			// Using the algorithm from:
+			// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+			long newN = stats1.n + stats2.n;
+			long cumTimeNanos = stats1.cumTimeNanos + stats2.cumTimeNanos;
+
+			// Ensure the calculation uses floating-point arithmetic (means in seconds)
+			double mean1 = (double) stats1.cumTimeNanos / 1000000000d / stats1.n;
+			double mean2 = (double) stats2.cumTimeNanos / 1000000000d / stats2.n;
+			double delta = mean2 - mean1;
+
+			double newM2 = stats1.m2 + stats2.m2 + delta * delta * stats1.n * stats2.n / (double) newN;
+
+			return new NGramStats(newN, cumTimeNanos, newM2);
+		}
+
+		public NGramStats(final long n, final long cumTimeNanos, final double m2) {
+			this.n = n;
+			this.cumTimeNanos = cumTimeNanos;
+			this.m2 = m2;
+		}
+
+		/**
+		 * @return m2 divided by (n - 1), using a divisor of 1 when n is at most 1
+		 */
+		public double getTimeVariance() {
+			return m2 / Math.max(n - 1, 1);
+		}
+
+		/**
+		 * @return the cumulative time in seconds with five decimal places
+		 */
+		@Override
+		public String toString() {
+			return String.format(Locale.US, "%.5f", (cumTimeNanos / 1000000000d));
+		}
+	}
private static long compileStartTime = 0;
private static long compileEndTime = 0;
@@ -71,7 +115,8 @@ public class Statistics
private static long execEndTime = 0;
//heavy hitter counts and times
- private static final ConcurrentHashMap<String,InstStats>_instStats =
new ConcurrentHashMap<>();
+ private static final ConcurrentHashMap<String,InstStats> _instStats =
new ConcurrentHashMap<>();
+ private static final ConcurrentHashMap<String, NGramBuilder<String,
NGramStats>[]> _instStatsNGram = new ConcurrentHashMap<>();
// number of compiled/executed SP instructions
private static final LongAdder numExecutedSPInst = new LongAdder();
@@ -252,6 +297,8 @@ public class Statistics
DMLCompressionStatistics.reset();
FederatedStatistics.reset();
+
+ _instStatsNGram.clear();
}
public static void resetJITCompileTime(){
@@ -353,6 +400,177 @@ public class Statistics
tmp.time.add(timeNanos);
tmp.count.increment();
}
+
+	/**
+	 * Records one executed instruction in the n-gram builders of the calling thread.
+	 * The builders (one per configured n-gram size) are looked up by thread name and
+	 * created on first use.
+	 *
+	 * @param instName  name of the executed instruction
+	 * @param timeNanos execution time of the instruction in nanoseconds
+	 */
+	public static void maintainNGrams(String instName, long timeNanos) {
+		final String threadKey = Thread.currentThread().getName();
+		NGramBuilder<String, NGramStats>[] perThread = _instStatsNGram.computeIfAbsent(threadKey, key -> {
+			NGramBuilder<String, NGramStats>[] created = new NGramBuilder[DMLScript.STATISTICS_NGRAM_SIZES.length];
+			int pos = 0;
+			while (pos < created.length) {
+				created[pos] = new NGramBuilder<String, NGramStats>(String.class, NGramStats.class, DMLScript.STATISTICS_NGRAM_SIZES[pos], s -> s, NGramStats::merge);
+				pos++;
+			}
+			return created;
+		});
+
+		// every builder receives its own single-sample statistics object
+		for (NGramBuilder<String, NGramStats> ngramBuilder : perThread)
+			ngramBuilder.append(instName, new NGramStats(1, timeNanos, 0));
+	}
+
+	/**
+	 * Merges the per-thread n-gram builders into one builder per configured size.
+	 *
+	 * @return array with one merged builder per entry of STATISTICS_NGRAM_SIZES
+	 */
+	public static NGramBuilder<String, NGramStats>[] mergeNGrams() {
+		final int numSizes = DMLScript.STATISTICS_NGRAM_SIZES.length;
+		NGramBuilder<String, NGramStats>[] merged = new NGramBuilder[numSizes];
+
+		// create all target builders up front, then fill them size by size
+		int pos = 0;
+		while (pos < numSizes) {
+			merged[pos] = new NGramBuilder<String, NGramStats>(String.class, NGramStats.class, DMLScript.STATISTICS_NGRAM_SIZES[pos], s -> s, NGramStats::merge);
+			pos++;
+		}
+
+		for (int sizeIdx = 0; sizeIdx < DMLScript.STATISTICS_NGRAM_SIZES.length; sizeIdx++) {
+			NGramBuilder<String, NGramStats> target = merged[sizeIdx];
+			for (Map.Entry<String, NGramBuilder<String, NGramStats>[]> perThread : _instStatsNGram.entrySet())
+				target.merge(perThread.getValue()[sizeIdx]);
+		}
+
+		return merged;
+	}
+
+	/**
+	 * Formats, for every position of an n-gram, the ratio of the standard deviation
+	 * to the mean of the recorded times as a parenthesized, comma separated list.
+	 *
+	 * @param stats       per-position statistics, stored as a ring starting at offset
+	 * @param offset      index in stats of the first n-gram position
+	 * @param prec        number of decimal places
+	 * @param displayZero if false, values below 10^-prec are left blank
+	 * @return the formatted list, or "-" if no value was written
+	 */
+	public static String getNGramStdDevs(NGramStats[] stats, int offset, int prec, boolean displayZero) {
+		// loop-invariant pieces: number format and the smallest displayable value
+		final String format = "%." + prec + "f";
+		final double threshold = Math.pow(10, -prec);
+
+		StringBuilder sb = new StringBuilder();
+		sb.append("(");
+		boolean containsData = false;
+		for (int i = 0; i < stats.length; i++) {
+			if (i != 0)
+				sb.append(", ");
+			NGramStats cur = stats[(offset + i) % stats.length];
+			// stddev / mean with mean = cumTimeNanos / 1e9 / n. A cumulative time
+			// of zero would evaluate to 0/0 = NaN, so report a ratio of 0 instead.
+			double relStdDev = (cur.cumTimeNanos == 0) ? 0d
+				: 1000000000d * cur.n * Math.sqrt(cur.getTimeVariance()) / cur.cumTimeNanos;
+			if (displayZero || relStdDev >= threshold) {
+				sb.append(String.format(Locale.US, format, relStdDev));
+				containsData = true;
+			}
+		}
+		sb.append(")");
+		return containsData ? sb.toString() : "-";
+	}
+
+	/**
+	 * Formats the mean time in seconds of every n-gram position as a parenthesized,
+	 * comma separated list.
+	 *
+	 * @param stats  per-position statistics, stored as a ring starting at offset
+	 * @param offset index in stats of the first n-gram position
+	 * @param prec   number of decimal places
+	 * @return the formatted list of mean times
+	 */
+	public static String getNGramAvgTimes(NGramStats[] stats, int offset, int prec) {
+		StringBuilder out = new StringBuilder("(");
+		int pos = 0;
+		while (pos < stats.length) {
+			if (pos > 0)
+				out.append(", ");
+			NGramStats cur = stats[(offset + pos) % stats.length];
+			double meanSeconds = (cur.cumTimeNanos / 1000000000d) / cur.n;
+			out.append(String.format(Locale.US, "%." + prec + "f", meanSeconds));
+			pos++;
+		}
+		return out.append(")").toString();
+	}
+
+	/**
+	 * Renders the n-grams of a builder as CSV, ordered by descending cumulative time
+	 * and limited to 100000 rows. Besides the columns written by NGramBuilder.toCSV,
+	 * each row lists the n-gram elements, their mean times and their
+	 * standard-deviation-to-mean ratios.
+	 *
+	 * @param mbuilder builder whose n-grams are exported
+	 * @return the CSV text including the header line
+	 */
+	public static String nGramToCSV(final NGramBuilder<String, NGramStats> mbuilder) {
+		ArrayList<String> header = new ArrayList<>();
+		header.add("N-Gram");
+		header.add("Time[s]");
+
+		// three column groups, each with one column per n-gram position
+		for (int c = 1; c <= mbuilder.getSize(); c++)
+			header.add("Col" + c);
+		for (int c = 1; c <= mbuilder.getSize(); c++)
+			header.add("Col" + c + "::Mean(Time[s])");
+		for (int c = 1; c <= mbuilder.getSize(); c++)
+			header.add("Col" + c + "::StdDev(Time[s])/Col" + c + "::Mean(Time[s])");
+
+		header.add("Count");
+
+		String[] columnNames = header.toArray(new String[header.size()]);
+		return NGramBuilder.toCSV(columnNames, mbuilder.getTopK(100000, Statistics.NGramStats.getComparator(), true), row -> {
+			StringBuilder line = new StringBuilder();
+
+			// n-gram elements: drop the parentheses and the blank after each comma
+			String elements = row.getIdentifier().replace("(", "").replace(")", "").replace(", ", ",");
+			line.append(elements).append(",");
+
+			String means = Statistics.getNGramAvgTimes(row.getStats(), row.getOffset(), 9);
+			line.append(means.replace("-", "").replace("(", "").replace(")", "")).append(",");
+
+			String ratios = Statistics.getNGramStdDevs(row.getStats(), row.getOffset(), 9, true);
+			ratios = ratios.replace("-", "").replace("(", "").replace(")", "");
+			if (!ratios.isEmpty()) {
+				line.append(ratios);
+			} else {
+				// keep the column count stable when there is nothing to print
+				int missing = mbuilder.getSize() - 1;
+				while (missing-- > 0)
+					line.append(",");
+			}
+			return line.toString();
+		});
+	}
+
+ /**
+ * Renders the top n-grams of the given builder as an aligned text table with the
+ * columns #, N-Gram, Time(s), StdDev(t)/Mean(t) and Count, one line per n-gram
+ * (plus continuation lines for wrapped identifiers).
+ *
+ * @param builder builder holding the n-grams to display
+ * @param num maximum number of n-grams to display
+ * @return the formatted table, or "-" if num is not positive or no n-gram
+ * statistics have been recorded
+ */
+ public static String getCommonNGrams(NGramBuilder<String, NGramStats>
builder, int num) {
+ // nothing requested or nothing recorded in the per-thread map
+ if (num <= 0 || _instStatsNGram.size() <= 0)
+ return "-";
+
+ //NGramBuilder<String, Long> builder = mergeNGrams();
+ // top entries by cumulative time; reversed = true requests descending order
+ @SuppressWarnings("unchecked")
+ NGramBuilder.NGramEntry<String, NGramStats>[] topNGrams =
builder.getTopK(num, NGramStats.getComparator(),
true).toArray(NGramBuilder.NGramEntry[]::new);
+
+ final String numCol = "#";
+ final String instCol = "N-Gram";
+ final String timeSCol = "Time(s)";
+ final String timeSVar = "StdDev(t)/Mean(t)";
+ final String countCol = "Count";
+ StringBuilder sb = new StringBuilder();
+ int len = topNGrams.length;
+ int numHittersToDisplay = Math.min(num, len);
+ int maxNumLen = String.valueOf(numHittersToDisplay).length();
+ int maxInstLen = instCol.length();
+ int maxTimeSLen = timeSCol.length();
+ int maxTimeSVarLen = timeSVar.length();
+ int maxCountLen = countCol.length();
+ DecimalFormat sFormat = new DecimalFormat("#,##0.000");
+
+ // first pass: determine the width of every column from the values to print
+ for (int i = 0; i < numHittersToDisplay; i++) {
+ long timeNs = topNGrams[i].getCumStats().cumTimeNanos;
+ String instruction = topNGrams[i].getIdentifier();
+ double timeS = timeNs / 1000000000d;
+
+
+ maxInstLen = Math.max(maxInstLen, instruction.length()
+ 1);
+
+ String timeSString = sFormat.format(timeS);
+ String timeSVarString =
getNGramStdDevs(topNGrams[i].getStats(), topNGrams[i].getOffset(), 3, false);
+ maxTimeSLen = Math.max(maxTimeSLen,
timeSString.length());
+ maxTimeSVarLen = Math.max(maxTimeSVarLen,
timeSVarString.length());
+
+ maxCountLen = Math.max(maxCountLen,
String.valueOf(topNGrams[i].getOccurrences()).length());
+ }
+
+ // cap the identifier column; longer identifiers are wrapped in the second pass
+ maxInstLen = Math.min(maxInstLen,
DMLScript.STATISTICS_MAX_WRAP_LEN);
+ sb.append(String.format( " %" + maxNumLen + "s %-" +
maxInstLen + "s %"
+ maxTimeSLen + "s %" + maxTimeSVarLen + "s
%" + maxCountLen + "s", numCol, instCol, timeSCol, timeSVar, countCol));
+ sb.append("\n");
+ // second pass: one table row per n-gram, followed by its continuation lines
+ for (int i = 0; i < numHittersToDisplay; i++) {
+ String instruction = topNGrams[i].getIdentifier();
+ String [] wrappedInstruction = wrap(instruction,
maxInstLen);
+
+ //long timeNs = tmp[len - 1 -
i].getValue().time.longValue();
+ double timeS = topNGrams[i].getCumStats().cumTimeNanos
/ 1000000000d;
+ // NOTE(review): timeVar is never read below; the variance column is filled
+ // from getNGramStdDevs instead
+ double timeVar =
topNGrams[i].getCumStats().getTimeVariance();
+ String timeSString = sFormat.format(timeS);
+ String timeVarString =
getNGramStdDevs(topNGrams[i].getStats(), topNGrams[i].getOffset(), 3,
false);//sFormat.format(timeVar);
+
+ long count = topNGrams[i].getOccurrences();
+ int numLines = wrappedInstruction.length;
+
+ // NOTE(review): numLines equals wrappedInstruction.length, so the bounds
+ // check in the ternary below always holds
+ for(int wrapIter = 0; wrapIter < numLines; wrapIter++) {
+ String instStr = (wrapIter <
wrappedInstruction.length) ? wrappedInstruction[wrapIter] : "";
+ if(wrapIter == 0) {
+ // Display instruction count
+ sb.append(String.format(
+ " %" + maxNumLen + "d
%-" + maxInstLen + "s %" + maxTimeSLen + "s %" + maxTimeSVarLen + "s %" +
maxCountLen + "d",
+ (i + 1), instStr,
timeSString, timeVarString, count));
+ }
+ else {
+ // continuation line: only the wrapped identifier part is printed
+ sb.append(String.format(
+ " %" + maxNumLen + "s
%-" + maxInstLen + "s %" + maxTimeSLen + "s %" + maxTimeSVarLen + "s %" +
maxCountLen + "s",
+ "", instStr, "", "",
""));
+ }
+ sb.append("\n");
+ }
+ }
+
+ return sb.toString();
+ }
public static void maintainCPFuncCallStats(String instName) {
InstStats tmp = _instStats.get(instName);
@@ -679,6 +897,13 @@ public class Statistics
sb.append("Heavy hitter instructions:\n" +
getHeavyHitters(maxHeavyHitters));
}
+ if (DMLScript.STATISTICS_NGRAMS) {
+ NGramBuilder<String, NGramStats>[] mergedNGrams =
mergeNGrams();
+ for (int i = 0; i <
DMLScript.STATISTICS_NGRAM_SIZES.length; i++) {
+ sb.append("Most common " +
DMLScript.STATISTICS_NGRAM_SIZES[i] + "-grams (sorted by absolute time):\n" +
getCommonNGrams(mergedNGrams[i], DMLScript.STATISTICS_TOP_K_NGRAMS));
+ }
+ }
+
if(DMLScript.FED_STATISTICS) {
sb.append("\n");
sb.append(FederatedStatistics.displayStatistics(DMLScript.FED_STATISTICS_COUNT));
diff --git a/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java b/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
new file mode 100644
index 0000000000..7554fdcd67
--- /dev/null
+++ b/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.utils.stats;
+
+import java.lang.reflect.Array;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/**
+ * Collects n-grams over a stream of elements. The last {@code size} appended
+ * elements are kept in a ring buffer; once the buffer is full, every append
+ * registers the current window as one n-gram occurrence and merges the
+ * per-position statistics into the entry of that n-gram.
+ *
+ * @param <T> element type
+ * @param <U> statistics type attached to each element; instances are shared
+ *            between entries when merging, so they are expected to be immutable
+ */
+public class NGramBuilder<T, U> {
+
+	/**
+	 * Writes n-gram entries as CSV: identifier (commas replaced by semicolons),
+	 * cumulative statistics, the optional mapped statistics and the occurrences.
+	 *
+	 * @param columnNames header columns, joined by commas
+	 * @param entries     entries to write, one line each
+	 * @param statsMapper optional extra column(s) per entry, may be null
+	 * @param <T>         element type of the entries
+	 * @param <U>         statistics type of the entries
+	 * @return the CSV text including the header line
+	 */
+	public static <T, U> String toCSV(String[] columnNames, List<NGramEntry<T, U>> entries, Function<NGramEntry<T, U>, String> statsMapper) {
+		StringBuilder builder = new StringBuilder(String.join(",", columnNames));
+		builder.append("\n");
+
+		for (NGramEntry<T, U> entry : entries) {
+			builder.append(entry.getIdentifier().replace(",", ";"));
+			builder.append(",");
+			builder.append(entry.getCumStats());
+			builder.append(",");
+
+			if (statsMapper != null) {
+				builder.append(statsMapper.apply(entry));
+				builder.append(",");
+			}
+
+			builder.append(entry.getOccurrences());
+			builder.append("\n");
+		}
+
+		return builder.toString();
+	}
+
+	/**
+	 * One distinct n-gram with its occurrence count and statistics. The element
+	 * and statistics arrays are ring buffers; {@code offset} is the array index
+	 * of the first n-gram position.
+	 */
+	public static class NGramEntry<T, U> {
+		private final String identifier;
+		private final T[] entry;
+		private U[] stats;
+		private U cumStats;
+		private long occurrences;
+		private int offset;
+
+		public NGramEntry(String identifier, T[] entry, U[] stats, U cumStats, int offset) {
+			this.identifier = identifier;
+			this.entry = entry;
+			this.stats = stats;
+			this.occurrences = 1;
+			this.offset = offset;
+			this.cumStats = cumStats;
+		}
+
+		public String getIdentifier() {
+			return identifier;
+		}
+
+		public long getOccurrences() {
+			return occurrences;
+		}
+
+		/**
+		 * @param index n-gram position, starting at 0
+		 * @return the statistics of that position
+		 */
+		public U getStat(int index) {
+			if (index < 0 || index >= entry.length)
+				throw new ArrayIndexOutOfBoundsException("Index " + index + " is out of bounds");
+
+			index = (index + offset) % entry.length;
+			return stats[index];
+		}
+
+		public U getCumStats() {
+			return cumStats;
+		}
+
+		/**
+		 * @return the internal statistics ring (not a copy); use getOffset to
+		 *         locate the first position
+		 */
+		public U[] getStats() {
+			return stats;
+		}
+
+		public int getOffset() {
+			return offset;
+		}
+
+		void setCumStats(U cumStats) {
+			this.cumStats = cumStats;
+		}
+
+		/**
+		 * @param index n-gram position, starting at 0
+		 * @return the element at that position
+		 */
+		public T get(int index) {
+			if (index < 0 || index >= entry.length)
+				throw new ArrayIndexOutOfBoundsException("Index " + index + " is out of bounds");
+
+			index = (index + offset) % entry.length;
+			return entry[index];
+		}
+
+		private NGramEntry<T, U> increment() {
+			occurrences++;
+			return this;
+		}
+
+		private NGramEntry<T, U> add(long n) {
+			occurrences += n;
+			return this;
+		}
+
+		// Copy with its own statistics array, so that merging into the copy never
+		// changes this entry. The element array is never modified and is shared.
+		private NGramEntry<T, U> copy() {
+			NGramEntry<T, U> dup = new NGramEntry<>(identifier, entry, Arrays.copyOf(stats, stats.length), cumStats, offset);
+			dup.occurrences = occurrences;
+			return dup;
+		}
+	}
+
+	private final T[] currentNGram;
+	private final U[] currentStats;
+	private int currentIndex = 0;
+	private int currentSize = 0;
+	private final Function<T, String> idGenerator;
+	private final BiFunction<U, U, U> statsMerger;
+	private final ConcurrentHashMap<String, NGramEntry<T, U>> nGrams;
+
+	/**
+	 * @param clazz       element class, used to allocate the element ring
+	 * @param clazz2      statistics class, used to allocate the statistics ring
+	 * @param size        number of elements per n-gram, at least 1
+	 * @param idGenerator maps an element to its textual form in the identifier
+	 * @param statsMerger combines two statistics objects into one
+	 * @throws IllegalArgumentException if size is smaller than 1
+	 */
+	@SuppressWarnings("unchecked")
+	public NGramBuilder(Class<T> clazz, Class<U> clazz2, int size, Function<T, String> idGenerator, BiFunction<U, U, U> statsMerger) {
+		// an empty ring cannot hold any element and would fail on the first append
+		if (size < 1)
+			throw new IllegalArgumentException("N-gram size must be at least 1, but was " + size);
+		currentNGram = (T[]) Array.newInstance(clazz, size);
+		currentStats = (U[]) Array.newInstance(clazz2, size);
+		this.idGenerator = idGenerator;
+		this.nGrams = new ConcurrentHashMap<>();
+		this.statsMerger = statsMerger;
+	}
+
+	public int getSize() {
+		return currentNGram.length;
+	}
+
+	/**
+	 * Adds the n-grams of another builder to this one. The other builder and its
+	 * entries are left unchanged.
+	 *
+	 * @param builder builder whose n-grams are merged into this builder
+	 */
+	public synchronized void merge(NGramBuilder<T, U> builder) {
+		builder.nGrams.forEach((id, other) -> nGrams.compute(id, (key, mine) -> {
+			// First occurrence of this n-gram: store a copy. Storing the foreign
+			// entry itself would let later merges modify the source builder.
+			if (mine == null)
+				return other.copy();
+
+			mine.add(other.occurrences);
+			mine.setCumStats(statsMerger.apply(mine.getCumStats(), other.getCumStats()));
+
+			// walk both rings from their own offsets so that equal n-gram
+			// positions are combined
+			U[] target = mine.getStats();
+			U[] source = other.getStats();
+			int targetIdx = mine.offset;
+			int sourceIdx = other.offset;
+			for (int i = 0; i < target.length; i++) {
+				target[targetIdx] = statsMerger.apply(target[targetIdx], source[sourceIdx]);
+				targetIdx = (targetIdx + 1) % target.length;
+				sourceIdx = (sourceIdx + 1) % source.length;
+			}
+
+			return mine;
+		}));
+	}
+
+	/**
+	 * Appends an element and, once {@code size} elements have been seen, registers
+	 * the n-gram formed by the last {@code size} elements.
+	 *
+	 * @param element the new element
+	 * @param stat    statistics of the new element
+	 */
+	public synchronized void append(T element, U stat) {
+		final int size = currentNGram.length;
+		currentNGram[currentIndex] = element;
+		currentStats[currentIndex] = stat;
+		currentIndex = (currentIndex + 1) % size;
+
+		if (currentSize < size)
+			currentSize++;
+
+		// the window is not yet filled, there is no complete n-gram to register
+		if (currentSize < size)
+			return;
+
+		// currentIndex now points at the oldest element, i.e. the first position
+		StringBuilder builder = new StringBuilder();
+		builder.append("(");
+		for (int i = 0; i < size; i++) {
+			if (i != 0)
+				builder.append(", ");
+			builder.append(idGenerator.apply(currentNGram[(i + currentIndex) % size]));
+		}
+		builder.append(")");
+
+		registerElement(builder.toString());
+	}
+
+	/**
+	 * @param k maximum number of entries
+	 * @return the k entries with the most occurrences, most frequent first
+	 */
+	public synchronized List<NGramEntry<T, U>> getTopK(int k) {
+		return nGrams.entrySet().stream()
+			.sorted(Comparator.comparingLong((Map.Entry<String, NGramEntry<T, U>> v) -> v.getValue().occurrences).reversed())
+			.map(Map.Entry::getValue)
+			.limit(k)
+			.collect(Collectors.toList());
+	}
+
+	/**
+	 * @param k          maximum number of entries
+	 * @param comparator order of the entries
+	 * @param reversed   if true, the comparator order is reversed
+	 * @return the first k entries in the requested order
+	 */
+	public synchronized List<NGramEntry<T, U>> getTopK(int k, Comparator<NGramEntry<T, U>> comparator, boolean reversed) {
+		return nGrams.entrySet().stream()
+			.sorted((e1, e2) -> reversed ? comparator.compare(e2.getValue(), e1.getValue()) : comparator.compare(e1.getValue(), e2.getValue()))
+			.map(Map.Entry::getValue)
+			.limit(k)
+			.collect(Collectors.toList());
+	}
+
+	// Counts one occurrence of the n-gram currently held in the ring buffers.
+	private synchronized void registerElement(String id) {
+		nGrams.compute(id, (key, entry) -> {
+			if (entry == null) {
+				// new n-gram: snapshot the rings and fold all position statistics
+				U cumStat = currentStats[0];
+
+				for (int i = 1; i < currentStats.length; i++) {
+					cumStat = statsMerger.apply(currentStats[i], cumStat);
+				}
+
+				entry = new NGramEntry<T, U>(id, Arrays.copyOf(currentNGram, currentNGram.length), Arrays.copyOf(currentStats, currentStats.length), cumStat, currentIndex);
+			} else {
+				entry.increment();
+				U[] stats = entry.getStats();
+				U cumStat = null;
+
+				// both rings are walked from their first n-gram position
+				int mCurrentIndex = currentIndex;
+				int mIndexEntry = entry.offset;
+
+				for (int i = 0; i < stats.length; i++) {
+					stats[mIndexEntry] = statsMerger.apply(stats[mIndexEntry], currentStats[mCurrentIndex]);
+					if (i == 0) {
+						cumStat = stats[mIndexEntry];
+					} else {
+						cumStat = statsMerger.apply(stats[mIndexEntry], cumStat);
+					}
+
+					mCurrentIndex = (mCurrentIndex + 1) % stats.length;
+					mIndexEntry = (mIndexEntry + 1) % stats.length;
+				}
+
+				entry.setCumStats(cumStat);
+			}
+
+			return entry;
+		});
+	}
+
+}
diff --git a/src/test/java/org/apache/sysds/test/applications/ApplyTransformTest.java b/src/test/java/org/apache/sysds/test/applications/ApplyTransformTest.java
index 4f8af7f149..bdd0a9b415 100644
--- a/src/test/java/org/apache/sysds/test/applications/ApplyTransformTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/ApplyTransformTest.java
@@ -19,6 +19,8 @@
package org.apache.sysds.test.applications;
+import java.io.FileWriter;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -27,6 +29,9 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.utils.Statistics;
+import org.apache.sysds.utils.stats.NGramBuilder;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -81,6 +86,10 @@ public class ApplyTransformTest extends AutomatedTestBase{
getAndLoadTestConfiguration(TEST_NAME);
List<String> proArgs = new ArrayList<>();
+ proArgs.add("-stats");
+ proArgs.add("-ngrams");
+ proArgs.add("1,2,3,4,5,6,7,8,9,10");
+ proArgs.add("10");
proArgs.add("-nvargs");
proArgs.add("X=" + sourceDirectory + X);
proArgs.add("missing_value_maps=" +
(missing_value_maps.equals(" ") ? " " : sourceDirectory + missing_value_maps));
diff --git a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java
index b77bc49a0d..dbb98160e2 100644
--- a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java
@@ -83,6 +83,9 @@ public class L2SVMTest extends AutomatedTestBase
List<String> proArgs = new ArrayList<>();
proArgs.add("-stats");
+ proArgs.add("-ngrams");
+ proArgs.add("3,2");
+ proArgs.add("10");
proArgs.add("-nvargs");
proArgs.add("X=" + input("X"));
proArgs.add("Y=" + input("Y"));