This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 335239e929 [SYSTEMDS-3714] Extended N-gram statistics based on lineage
335239e929 is described below

commit 335239e929c27c65aa59af6f22375826d61a662f
Author: Jaybit0 <[email protected]>
AuthorDate: Tue Sep 3 12:35:13 2024 +0200

    [SYSTEMDS-3714] Extended N-gram statistics based on lineage
    
    Closes #2062.
---
 src/main/java/org/apache/sysds/api/DMLOptions.java |  11 +-
 src/main/java/org/apache/sysds/api/DMLScript.java  |   2 +
 .../sysds/runtime/controlprogram/ProgramBlock.java |   8 +-
 .../sysds/runtime/instructions/Instruction.java    |   3 +
 .../sysds/runtime/lineage/LineageItemUtils.java    |  39 ++++
 .../apache/sysds/runtime/lineage/LineageMap.java   |   4 +
 .../sysds/runtime/matrix/data/MatrixBlock.java     |   2 -
 .../java/org/apache/sysds/utils/Statistics.java    | 225 ++++++++++++++++++++-
 .../org/apache/sysds/utils/stats/NGramBuilder.java |  32 +++
 .../sysds/performance/matrix/MatrixAggregate.java  |   5 +-
 .../apache/sysds/test/applications/L2SVMTest.java  |  19 +-
 11 files changed, 330 insertions(+), 20 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java 
b/src/main/java/org/apache/sysds/api/DMLOptions.java
index acacc39572..5bd5e019d0 100644
--- a/src/main/java/org/apache/sysds/api/DMLOptions.java
+++ b/src/main/java/org/apache/sysds/api/DMLOptions.java
@@ -57,6 +57,7 @@ public class DMLOptions {
        public int                  statsCount    = 10;               // 
Default statistics count
        public int[]                statsNGramSizes = { 3 };          // 
Default n-gram tuple sizes
        public int                  statsTopKNGrams = 10;             // How 
many of the most heavy hitting n-grams are displayed
+       public boolean              statsNGramsUseLineage = true;     // If 
N-Grams use lineage for data-dependent tracking
        public boolean              fedStats      = false;            // 
Whether to record and print the federated statistics
        public int                  fedStatsCount = 10;               // 
Default federated statistics count
        public boolean              memStats      = false;            // max 
memory statistics
@@ -219,7 +220,7 @@ public class DMLOptions {
                dmlOptions.statsNGrams = line.hasOption("ngrams");
                if (dmlOptions.statsNGrams){
                        String[] nGramArgs = line.getOptionValues("ngrams");
-                       if (nGramArgs.length == 2) {
+                       if (nGramArgs.length >= 2) {
                                try {
                                        String[] nGramSizeSplit = 
nGramArgs[0].split(",");
                                        dmlOptions.statsNGramSizes = new 
int[nGramSizeSplit.length];
@@ -229,10 +230,18 @@ public class DMLOptions {
                                        }
 
                                        dmlOptions.statsTopKNGrams = 
Integer.parseInt(nGramArgs[1]);
+
+                                       if (nGramArgs.length == 3) {
+                                               
dmlOptions.statsNGramsUseLineage = Boolean.parseBoolean(nGramArgs[2]);
+                                       }
                                } catch (NumberFormatException e) {
                                        throw new 
org.apache.commons.cli.ParseException("Invalid argument specified for -ngrams 
option, must be a valid integer");
                                }
                        }
+
+                       if (dmlOptions.statsNGramsUseLineage) {
+                               dmlOptions.lineage = true;
+                       }
                }
 
                dmlOptions.fedStats = line.hasOption("fedStats");
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java 
b/src/main/java/org/apache/sysds/api/DMLScript.java
index 2137915f22..81ce1f04b0 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -104,6 +104,8 @@ public class DMLScript
        public static int[]         STATISTICS_NGRAM_SIZES   = 
DMLOptions.defaultOptions.statsNGramSizes;
        // Set top k displayed n-grams limit
        public static int         STATISTICS_TOP_K_NGRAMS    = 
DMLOptions.defaultOptions.statsTopKNGrams;
+       // Set if N-Grams use lineage for data-dependent tracking
+       public static boolean     STATISTICS_NGRAMS_USE_LINEAGE = 
DMLOptions.defaultOptions.statsNGramsUseLineage;
        // Set statistics maximum wrap length
        public static int         STATISTICS_MAX_WRAP_LEN    = 30;
        // Enable/disable to print federated statistics
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
index 4e75d5456f..0739334680 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
@@ -241,7 +241,8 @@ public abstract class ProgramBlock implements ParseInfo {
        private void executeSingleInstruction(Instruction currInst, 
ExecutionContext ec) {
                try {
                        // start time measurement for statistics
-                       long t0 = (DMLScript.STATISTICS || 
DMLScript.STATISTICS_NGRAMS || LOG.isTraceEnabled()) ? System.nanoTime() : 0;
+                       long t0 = (DMLScript.STATISTICS || 
DMLScript.STATISTICS_NGRAMS || LOG.isTraceEnabled())
+                               ? System.nanoTime() : 0;
 
                        // pre-process instruction (inst patching, listeners, 
lineage)
                        Instruction tmp = currInst.preprocessInstruction(ec);
@@ -264,9 +265,8 @@ public abstract class ProgramBlock implements ParseInfo {
                                        
Statistics.maintainCPHeavyHitters(tmp.getExtendedOpcode(), System.nanoTime() - 
t0);
                                }
 
-                               if (DMLScript.STATISTICS_NGRAMS) {
-                                       
Statistics.maintainNGrams(tmp.getExtendedOpcode(), System.nanoTime() - t0);
-                               }
+                               if (DMLScript.STATISTICS_NGRAMS)
+                                       
Statistics.maintainNGramsFromLineage(tmp, ec, t0);
                        }
 
                        // optional trace information (instruction and runtime)
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
index 969dfaf5c2..50238aadd8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
@@ -26,6 +26,7 @@ import org.apache.sysds.lops.Lop;
 import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.utils.Statistics;
 
 public abstract class Instruction 
 {
@@ -214,6 +215,8 @@ public abstract class Instruction
         * @return instruction
         */
        public Instruction preprocessInstruction(ExecutionContext ec) {
+               if (DMLScript.STATISTICS_NGRAMS && 
DMLScript.STATISTICS_NGRAMS_USE_LINEAGE)
+                       Statistics.prepareNGramInst(null); // Reset the current 
LineageItem for this thread
                // Lineage tracing
                if (DMLScript.LINEAGE)
                        ec.traceLineage(this);
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index 58dab47534..5766437fe1 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -63,6 +63,7 @@ import 
org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
 import 
org.apache.sysds.runtime.instructions.fed.ReorgFEDInstruction.DiagMatrix;
 import org.apache.sysds.runtime.instructions.fed.ReorgFEDInstruction.Rdiag;
 import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.utils.Statistics;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -124,6 +125,44 @@ public class LineageItemUtils {
        public static boolean isFunctionDebugging () {
                return FUNCTION_DEBUGGING;
        }
+
+       public static String explainLineageType(LineageItem li, 
Statistics.LineageNGramExtension ext) {
+               if (li.getType() == LineageItemType.Literal) {
+                       String[] splt = li.getData().split("·");
+                       if (splt.length >= 3)
+                               return splt[1] + "·" + splt[2];
+                       return "·";
+               }
+               return ext != null ? ext.getDataType() + "·" + 
ext.getValueType() : "··";
+       }
+
+       public static String explainLineageWithTypes(LineageItem li, 
Statistics.LineageNGramExtension ext) {
+               if (li.getType() == LineageItemType.Literal) {
+                       String[] splt = li.getData().split("·");
+                       if (splt.length >= 3)
+                               return "L·" + splt[1] + "·" + splt[2];
+                       return "L··";
+               }
+               return li.getOpcode() + "·" + (ext != null ? ext.getDataType() 
+ "·" + ext.getValueType() : "·");
+       }
+
+       public static String explainLineageAsInstruction(LineageItem li, 
Statistics.LineageNGramExtension ext) {
+               StringBuilder sb = new 
StringBuilder(explainLineageWithTypes(li, ext));
+               sb.append("(");
+               if (li.getInputs() != null) {
+                       int ctr = 0;
+                       for (LineageItem liIn : li.getInputs()) {
+                               if (ctr++ != 0)
+                                       sb.append(" ° ");
+                               if (liIn.getType() == LineageItemType.Literal)
+                                       sb.append("L_" + 
explainLineageType(liIn, Statistics.getExtendedLineage(li)));
+                               else
+                                       sb.append(explainLineageType(liIn, 
Statistics.getExtendedLineage(li)));
+                       }
+               }
+               sb.append(")");
+               return sb.toString();
+       }
        
        public static String explainSingleLineageItem(LineageItem li) {
                StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
index 41875bdfdf..2b3c981d9e 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
@@ -32,6 +32,7 @@ import 
org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
 import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
 import org.apache.sysds.runtime.lineage.LineageItem.LineageItemType;
 import org.apache.sysds.utils.Explain;
+import org.apache.sysds.utils.Statistics;
 
 import java.util.HashMap;
 import java.util.Map;
@@ -146,6 +147,9 @@ public class LineageMap {
        }
        
        private void trace(Instruction inst, ExecutionContext ec, Pair<String, 
LineageItem> li) {
+               if (li != null && li.getValue() != null && 
DMLScript.STATISTICS_NGRAMS && DMLScript.STATISTICS_NGRAMS_USE_LINEAGE)
+                       Statistics.prepareNGramInst(li);
+
                if (inst instanceof VariableCPInstruction) {
                        VariableCPInstruction vcp_inst = 
((VariableCPInstruction) inst);
                        
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 79eb73ba0a..f76502ef7c 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -73,7 +73,6 @@ import 
org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysds.runtime.functionobjects.CM;
 import org.apache.sysds.runtime.functionobjects.CTable;
 import org.apache.sysds.runtime.functionobjects.DiagIndex;
-import org.apache.sysds.runtime.functionobjects.Divide;
 import org.apache.sysds.runtime.functionobjects.FunctionObject;
 import org.apache.sysds.runtime.functionobjects.IfElse;
 import org.apache.sysds.runtime.functionobjects.KahanFunction;
@@ -96,7 +95,6 @@ import org.apache.sysds.runtime.instructions.cp.KahanObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
-import org.apache.sysds.runtime.matrix.data.LibMatrixBincell.BinaryAccessType;
 import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java 
b/src/main/java/org/apache/sysds/utils/Statistics.java
index 3ad613c842..c0f087d0b0 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -26,13 +26,19 @@ import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.fedplanner.FederatedCompilationTimer;
 import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysds.runtime.instructions.spark.SPInstruction;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.utils.stats.CodegenStatistics;
 import org.apache.sysds.utils.stats.NGramBuilder;
 import org.apache.sysds.utils.stats.NativeStatistics;
@@ -46,6 +52,7 @@ import java.lang.management.CompilationMXBean;
 import java.lang.management.GarbageCollectorMXBean;
 import java.lang.management.ManagementFactory;
 import java.text.DecimalFormat;
+import java.util.AbstractMap;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
@@ -54,10 +61,12 @@ import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.DoubleAdder;
 import java.util.concurrent.atomic.LongAdder;
+import java.util.function.Consumer;
 
 /**
  * This class captures all statistics.
@@ -74,6 +83,7 @@ public class Statistics
                public final long n;
                public final long cumTimeNanos;
                public final double m2;
+               public final HashMap<String, Double> meta;
 
                public static <T> Comparator<NGramBuilder.NGramEntry<T, 
NGramStats>> getComparator() {
                        return Comparator.comparingLong(entry -> 
entry.getCumStats().cumTimeNanos);
@@ -91,13 +101,25 @@ public class Statistics
 
                        double newM2 = stats1.m2 + stats2.m2 + delta * delta * 
stats1.n * stats2.n / (double)newN;
 
-                       return new NGramStats(newN, cumTimeNanos, newM2);
+                       HashMap<String, Double> cpy = null;
+
+                       if (stats1.meta != null) {
+                               cpy = new HashMap<>(stats1.meta);
+                               final HashMap<String, Double> mCpy = cpy;
+                               if (stats2.meta != null)
+                                       stats2.meta.forEach((key, value) -> 
mCpy.merge(key, value, Double::sum));
+                       } else if (stats2.meta != null) {
+                               cpy = new HashMap<>(stats2.meta);
+                       }
+
+                       return new NGramStats(newN, cumTimeNanos, newM2, cpy);
                }
 
-               public NGramStats(final long n, final long cumTimeNanos, final 
double m2) {
+               public NGramStats(final long n, final long cumTimeNanos, final 
double m2, HashMap<String, Double> meta) {
                        this.n = n;
                        this.cumTimeNanos = cumTimeNanos;
                        this.m2 = m2;
+                       this.meta = meta;
                }
 
                public double getTimeVariance() {
@@ -107,6 +129,54 @@ public class Statistics
                public String toString() {
                        return String.format(Locale.US, "%.5f", (cumTimeNanos / 
1000000000d));
                }
+
+               public HashMap<String, Double> getMeta() {
+                       return meta;
+               }
+       }
+
+       public static class LineageNGramExtension {
+               private String _datatype;
+               private String _valuetype;
+               private long _execNanos;
+
+               private HashMap<String, Double> _meta;
+
+               public void setDataType(String dataType) {
+                       _datatype = dataType;
+               }
+
+               public String getDataType() {
+                       return _datatype == null ? "" : _datatype;
+               }
+
+               public void setValueType(String valueType) {
+                       _valuetype = valueType;
+               }
+
+               public String getValueType() {
+                       return _valuetype == null ? "" : _valuetype;
+               }
+
+               public void setExecNanos(long nanos) {
+                       _execNanos = nanos;
+               }
+
+               public long getExecNanos() {
+                       return _execNanos;
+               }
+
+               public void setMeta(String key, Double value) {
+                       if (_meta == null)
+                               _meta = new HashMap<>();
+                       _meta.put(key, value);
+               }
+
+               public Object getMeta(String key) {
+                       if (_meta == null)
+                               return null;
+                       return _meta.get(key);
+               }
        }
        
        private static long compileStartTime = 0;
@@ -117,6 +187,8 @@ public class Statistics
        //heavy hitter counts and times 
        private static final ConcurrentHashMap<String,InstStats> _instStats = 
new ConcurrentHashMap<>();
        private static final ConcurrentHashMap<String, NGramBuilder<String, 
NGramStats>[]> _instStatsNGram = new ConcurrentHashMap<>();
+       private static final ConcurrentHashMap<Long, Entry<String, 
LineageItem>> _instStatsLineageTracker = new ConcurrentHashMap<>();
+       private static final ConcurrentHashMap<LineageItem, 
LineageNGramExtension> _lineageExtensions = new ConcurrentHashMap<>();
 
        // number of compiled/executed SP instructions
        private static final LongAdder numExecutedSPInst = new LongAdder();
@@ -299,6 +371,8 @@ public class Statistics
                FederatedStatistics.reset();
 
                _instStatsNGram.clear();
+               _instStatsLineageTracker.clear();
+               _instStats.clear();
        }
 
        public static void resetJITCompileTime(){
@@ -401,6 +475,120 @@ public class Statistics
                tmp.count.increment();
        }
 
+       public static void prepareNGramInst(Entry<String, LineageItem> li) {
+               if (li == null)
+                       
_instStatsLineageTracker.remove(Thread.currentThread().getId());
+               else
+                       
_instStatsLineageTracker.put(Thread.currentThread().getId(), li);
+       }
+
+       public static Optional<Entry<String, LineageItem>> 
getCurrentLineageItem() {
+               Entry<String, LineageItem> item = 
_instStatsLineageTracker.get(Thread.currentThread().getId());
+               return item == null ? Optional.empty() : Optional.of(item);
+       }
+
+       public static synchronized void clearNGramRecording() {
+               NGramBuilder<String, NGramStats>[] bl = 
_instStatsNGram.get(Thread.currentThread().getName());
+               for (NGramBuilder<String, NGramStats> b : bl)
+                       b.clearCurrentRecording();
+       }
+
+       public static synchronized void extendLineageItem(LineageItem li, 
LineageNGramExtension ext) {
+               _lineageExtensions.put(li, ext);
+       }
+
+       public static synchronized LineageNGramExtension 
getExtendedLineage(LineageItem li) {
+               return _lineageExtensions.get(li);
+       }
+       
+       public static synchronized void maintainNGramsFromLineage(Instruction 
tmp, ExecutionContext ec, long t0) {
+               final long nanoTime = System.nanoTime() - t0;
+               if (DMLScript.STATISTICS_NGRAMS_USE_LINEAGE) {
+                       Statistics.getCurrentLineageItem().ifPresent(li -> {
+                               Data data = ec.getVariable(li.getKey());
+                               Statistics.LineageNGramExtension ext = new 
Statistics.LineageNGramExtension();
+                               if (data != null) {
+                                       
ext.setDataType(data.getDataType().toString());
+                                       
ext.setValueType(data.getValueType().toString());
+                                       if (data instanceof CacheableData) {
+                                               DataCharacteristics dc = 
((CacheableData<?>)data).getDataCharacteristics();
+                                               ext.setMeta("NDims", 
(double)dc.getNumDims());
+                                               ext.setMeta("NumRows", 
(double)dc.getRows());
+                                               ext.setMeta("NumCols", 
(double)dc.getCols());
+                                               ext.setMeta("NonZeros", 
(double)dc.getNonZeros());
+                                       }
+                               }
+                               ext.setExecNanos(nanoTime);
+                               Statistics.extendLineageItem(li.getValue(), 
ext);
+                               
Statistics.maintainNGramsFromLineage(li.getValue());
+                       });
+               } else
+                       Statistics.maintainNGrams(tmp.getExtendedOpcode(), 
nanoTime);
+       }
+
+       @SuppressWarnings("unchecked")
+       public static synchronized void maintainNGramsFromLineage(LineageItem 
li) {
+               NGramBuilder<String, NGramStats>[] tmp = 
_instStatsNGram.computeIfAbsent(Thread.currentThread().getName(), k -> {
+                       NGramBuilder<String, NGramStats>[] threadEntry = new 
NGramBuilder[DMLScript.STATISTICS_NGRAM_SIZES.length];
+                       for (int i = 0; i < threadEntry.length; i++) {
+                               threadEntry[i] = new NGramBuilder<String, 
NGramStats>(String.class, NGramStats.class, 
DMLScript.STATISTICS_NGRAM_SIZES[i], s -> s, NGramStats::merge);
+                       }
+                       return threadEntry;
+               });
+               addLineagePaths(li, new ArrayList<>(), new ArrayList<>(), tmp);
+       }
+
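+       /**
+        * Adds the corresponding sequences of instructions to the n-grams.
+        * <p>
+        * Example: 2-grams from (a*b + a/c) will add [(*,+), (/,+)]
+        * @param li lineage item to start the traversal from
+        * @param currentPath lineage items on the path walked so far
+        * @param indexes input positions taken along the current path
+        * @param builders n-gram builders to append matching paths to
+        */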
+       private static void addLineagePaths(LineageItem li, 
ArrayList<Entry<LineageItem, LineageNGramExtension>> currentPath, 
ArrayList<Integer> indexes, NGramBuilder<String, NGramStats>[] builders) {
+               if (li.getType() == LineageItem.LineageItemType.Literal)
+                       return; // Skip literals as they are no real instruction
+
+               currentPath.add(new AbstractMap.SimpleEntry<>(li, 
getExtendedLineage(li)));
+
+               int maxSize = 0;
+               NGramBuilder<String, NGramStats> matchingBuilder = null;
+
+               for (NGramBuilder<String, NGramStats> builder : builders) {
+                       if (builder.getSize() == currentPath.size())
+                               matchingBuilder = builder;
+                       if (builder.getSize() > maxSize)
+                               maxSize = builder.getSize();
+               }
+
+               if (matchingBuilder != null) {
+                       // If we have an n-gram builder with n = 
currentPath.size(), then we want to insert the entry
+                       // As we cannot incrementally add the instructions (we 
have a DAG rather than a sequence of instructions)
+                       // we need to clear the current n-grams
+                       clearNGramRecording();
+                       // We then record a new n-gram with all the 
LineageItems of the current lineage path
+                       Entry<LineageItem, LineageNGramExtension> currentEntry 
= currentPath.get(currentPath.size()-1);
+                       
matchingBuilder.append(LineageItemUtils.explainLineageAsInstruction(currentEntry.getKey(),
 currentEntry.getValue()) + (indexes.size() > 0 ? ("[" + 
indexes.get(currentPath.size()-2) + "]") : ""), new NGramStats(1, 
currentEntry.getValue() != null ? currentEntry.getValue().getExecNanos() : 0, 
0, currentEntry.getValue() != null ? currentEntry.getValue()._meta : null));
+                       for (int i = currentPath.size()-2; i >= 0; i--) {
+                               currentEntry = currentPath.get(i);
+                               
matchingBuilder.append(LineageItemUtils.explainLineageAsInstruction(currentEntry.getKey(),
 currentEntry.getValue()) + (i > 0 ? ("[" + indexes.get(i-1) + "]") : ""), new 
NGramStats(1, currentEntry.getValue() != null ? 
currentEntry.getValue().getExecNanos() : 0, 0, currentEntry.getValue() != null 
? currentEntry.getValue()._meta : null));
+                       }
+               }
+
+               if (currentPath.size() < maxSize && li.getInputs() != null) {
+                       int idx = 0;
+                       for (LineageItem input : li.getInputs()) {
+                               indexes.add(idx++);
+                               addLineagePaths(input, currentPath, indexes, 
builders);
+                               indexes.remove(indexes.size()-1);
+                       }
+               }
+
+               currentPath.remove(currentPath.size()-1);
+       }
+
        @SuppressWarnings("unchecked")
        public static void maintainNGrams(String instName, long timeNanos) {
                NGramBuilder<String, NGramStats>[] tmp = 
_instStatsNGram.computeIfAbsent(Thread.currentThread().getName(), k -> {
@@ -412,7 +600,7 @@ public class Statistics
                });
 
                for (int i = 0; i < tmp.length; i++)
-                       tmp[i].append(instName, new NGramStats(1, timeNanos, 
0));
+                       tmp[i].append(instName, new NGramStats(1, timeNanos, 0, 
null));
        }
 
        @SuppressWarnings("unchecked")
@@ -467,7 +655,7 @@ public class Statistics
                return sb.toString();
        }
 
-       public static String nGramToCSV(final NGramBuilder<String, NGramStats> 
mbuilder) {
+       public static void toCSVStream(final NGramBuilder<String, NGramStats> 
mbuilder, final Consumer<String> lineConsumer) {
                ArrayList<String> colList = new ArrayList<>();
                colList.add("N-Gram");
                colList.add("Time[s]");
@@ -478,10 +666,12 @@ public class Statistics
                        colList.add("Col" + (j + 1) + "::Mean(Time[s])");
                for (int j = 0; j < mbuilder.getSize(); j++)
                        colList.add("Col" + (j + 1) + "::StdDev(Time[s])/Col" + 
(j + 1) + "::Mean(Time[s])");
+               for (int j = 0; j < mbuilder.getSize(); j++)
+                       colList.add("Col" + (j + 1) + "_Meta");
 
                colList.add("Count");
 
-               return NGramBuilder.toCSV(colList.toArray(new 
String[colList.size()]), mbuilder.getTopK(100000, 
Statistics.NGramStats.getComparator(), true), e -> {
+               NGramBuilder.toCSVStream(colList.toArray(new 
String[colList.size()]), mbuilder.getTopK(100000, 
Statistics.NGramStats.getComparator(), true), e -> {
                        StringBuilder builder = new StringBuilder();
                        builder.append(e.getIdentifier().replace("(", 
"").replace(")", "").replace(", ", ","));
                        builder.append(",");
@@ -494,8 +684,31 @@ public class Statistics
                        } else {
                                builder.append(stdDevs);
                        }
+                       //builder.append(",");
+                       boolean first = true;
+                       NGramStats[] stats = e.getStats();
+                       for (int i = 0; i < stats.length; i++) {
+                               builder.append(",");
+                               NGramStats stat = stats[i];
+                               if (stat.getMeta() != null) {
+                                       for (Entry<String, Double> metaData : 
stat.getMeta().entrySet()) {
+                                               if (first)
+                                                       first = false;
+                                               else
+                                                       builder.append("&");
+                                               if (metaData.getValue() != null)
+                                                       
builder.append(metaData.getKey()).append(":").append(metaData.getValue());
+                                       }
+                               }
+                       }
                        return builder.toString();
-               });
+               }, lineConsumer);
+       }
+
+       public static String nGramToCSV(final NGramBuilder<String, NGramStats> 
mbuilder) {
+               final StringBuilder b = new StringBuilder();
+               toCSVStream(mbuilder, b::append);
+               return b.toString();
        }
 
        public static String getCommonNGrams(NGramBuilder<String, NGramStats> 
builder, int num) {
diff --git a/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java 
b/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
index e0212e5c73..85d8012789 100644
--- a/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
+++ b/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysds.utils.stats;
 
+import org.apache.commons.lang3.function.TriFunction;
+
 import java.lang.reflect.Array;
 import java.util.Arrays;
 import java.util.Comparator;
@@ -26,6 +28,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.BiFunction;
+import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -53,6 +56,30 @@ public class NGramBuilder<T, U> {
                return builder.toString();
        }
 
+       public static <T, U> void toCSVStream(String[] columnNames, List<NGramEntry<T, U>> entries, Function<NGramEntry<T, U>, String> statsMapper, Consumer<String> lineConsumer) {
+               StringBuilder builder = new StringBuilder(String.join(",", columnNames));
+               builder.append("\n");
+               lineConsumer.accept(builder.toString());
+               builder.setLength(0);
+
+               for (NGramEntry<T, U> entry : entries) {
+                       builder.append(entry.getIdentifier().replace(",", ";"));
+                       builder.append(",");
+                       builder.append(entry.getCumStats());
+                       builder.append(",");
+
+                       if (statsMapper != null) {
+                               builder.append(statsMapper.apply(entry));
+                               builder.append(",");
+                       }
+
+                       builder.append(entry.getOccurrences());
+                       builder.append("\n");
+                       lineConsumer.accept(builder.toString());
+                       builder.setLength(0);
+               }
+       }
+
        public static class NGramEntry<T, U> {
                private final String identifier;
                private final T[] entry;
@@ -209,6 +236,11 @@ public class NGramBuilder<T, U> {
                                .collect(Collectors.toList());
        }
 
+       public synchronized void clearCurrentRecording() {
+               currentIndex = 0;
+               currentSize = 0;
+       }
+
        private synchronized void registerElement(String id, U stat) {
                nGrams.compute(id, (key, entry) ->  {
                        if (entry == null) {
diff --git a/src/test/java/org/apache/sysds/performance/matrix/MatrixAggregate.java b/src/test/java/org/apache/sysds/performance/matrix/MatrixAggregate.java
index f6a466efaa..8e60ee97cb 100644
--- a/src/test/java/org/apache/sysds/performance/matrix/MatrixAggregate.java
+++ b/src/test/java/org/apache/sysds/performance/matrix/MatrixAggregate.java
@@ -21,7 +21,6 @@ package org.apache.sysds.performance.matrix;
 
 import org.apache.sysds.performance.compression.APerfTest;
 import org.apache.sysds.performance.generators.ConstMatrix;
-import org.apache.sysds.performance.generators.GenPair;
 import org.apache.sysds.performance.generators.IGenerate;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.test.TestUtils;
@@ -39,8 +38,8 @@ public class MatrixAggregate extends APerfTest<Object, MatrixBlock> {
        public void run() throws Exception {
                MatrixBlock mb = gen.take();
 
-               String info = String.format("rows: %5d cols: %5d sp: %5.3f par: %2d", mb.getNumRows(), mb.getNumColumns(),
-                       mb.getSparsity(), k);
+               String info = String.format("rows: %5d cols: %5d sp: %5.3f par: %2d",
+                       mb.getNumRows(), mb.getNumColumns(), mb.getSparsity(), k);
                warmup(() -> sum(), 100);
                execute(() -> sum(), info + " sum");
        }
diff --git a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java
index dbb98160e2..534b058425 100644
--- a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java
@@ -67,7 +67,16 @@ public class L2SVMTest extends AutomatedTestBase
        }
        
        @Test
-       public void testL2SVM()
+       public void testL2SVM1() {
+               testL2SVM(true);
+       }
+
+       @Test
+       public void testL2SVM2() {
+               testL2SVM(false);
+       }
+
+       private void testL2SVM(boolean ngrams)
        {
                System.out.println("------------ BEGIN " + TEST_NAME 
                        + " TEST WITH {" + numRecords + ", " + numFeatures
@@ -83,9 +92,11 @@ public class L2SVMTest extends AutomatedTestBase
 
                List<String> proArgs = new ArrayList<>();
                proArgs.add("-stats");
-               proArgs.add("-ngrams");
-               proArgs.add("3,2");
-               proArgs.add("10");
+               if (ngrams) {
+                       proArgs.add("-ngrams");
+                       proArgs.add("3,2");
+                       proArgs.add("10");
+               }
                proArgs.add("-nvargs");
                proArgs.add("X=" + input("X"));
                proArgs.add("Y=" + input("Y"));

Reply via email to