This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 46d8e19ace5bb7d6a7c2636488a553ee7845356d
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Aug 11 00:16:42 2021 +0200

    [SYSTEMDS-3090] Fix nnz mismatch in shuffle (invalid aggregation)
    
    A recent change in the aggregation logic forced the accumulator block to
    dense for performance, but also changed the nnz metadata to invalid
    values (likely to preserve the dense representation). This, however, can
    lead to severe correctness issues and, as surfaced by DBSCAN test
    failures (after other script modifications), can cause crashes during
    shuffle.
    
    This patch corrects the sources of this metadata corruption and adds
    additional utils to simplify debugging of invalid block metadata in
    distributed RDDs (to quickly narrow down the operation that introduces
    the violation).
---
 .../sysds/runtime/controlprogram/ProgramBlock.java | 13 +++++-
 .../instructions/spark/utils/SparkUtils.java       | 19 ++++++++
 .../sysds/runtime/matrix/data/LibMatrixAgg.java    | 50 ++++++++++------------
 .../test/functions/builtin/BuiltinDBSCANTest.java  |  3 +-
 4 files changed, 54 insertions(+), 31 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
index 19b769f..ff33f6d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
@@ -24,6 +24,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.api.jmlc.JMLCUtils;
+import org.apache.sysds.common.Types.FileFormat;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.Hop;
@@ -46,9 +47,12 @@ import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.IntObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysds.runtime.lineage.LineageCache;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MetaData;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
 import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
 import org.apache.sysds.utils.Statistics;
 
@@ -274,7 +278,7 @@ public abstract class ProgramBlock implements ParseInfo {
                        // optional check for correct nnz and sparse/dense 
representation of all
                        // variables in symbol table (for tracking source of 
wrong representation)
                        if(CHECK_MATRIX_PROPERTIES) {
-                               checkSparsity(tmp, ec.getVariables());
+                               checkSparsity(tmp, ec.getVariables(), ec);
                                checkFederated(tmp, ec.getVariables());
                        }
                }
@@ -333,7 +337,7 @@ public abstract class ProgramBlock implements ParseInfo {
                        }
        }
 
-       private static void checkSparsity(Instruction lastInst, 
LocalVariableMap vars) {
+       private static void checkSparsity(Instruction lastInst, 
LocalVariableMap vars, ExecutionContext ec) {
                for(String varname : vars.keySet()) {
                        Data dat = vars.get(varname);
                        if(dat instanceof MatrixObject) {
@@ -364,6 +368,11 @@ public abstract class ProgramBlock implements ParseInfo {
                                                        + ", actual=" + sparse1 
+ ", expected=" + sparse2 + ", nrow=" + mb.getNumRows() + ", ncol="
                                                        + mb.getNumColumns() + 
", nnz=" + nnz1 + ", inst=" + lastInst + ")");
                                }
+                               MetaData meta = mo.getMetaData();
+                               if( mo.getRDDHandle() != null && !(meta 
instanceof MetaDataFormat 
+                                       && 
((MetaDataFormat)meta).getFileFormat() != FileFormat.BINARY) ) {
+                                       SparkUtils.checkSparsity(varname, ec);
+                               }
                        }
                }
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java
index 5e51977..2c15b91 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java
@@ -25,10 +25,12 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.apache.spark.api.java.function.VoidFunction;
 import org.apache.spark.storage.StorageLevel;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.lops.Checkpoint;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.data.IndexedTensorBlock;
@@ -200,6 +202,12 @@ public class SparkUtils
                else //requires key access, so use mappartitions
                        return in.mapPartitionsToPair(new 
CopyTensorBlockPairFunction(deep), true);
        }
+       
+       public static void checkSparsity(String varname, ExecutionContext ec) {
+               SparkExecutionContext sec = (SparkExecutionContext) ec;
+               sec.getBinaryMatrixBlockRDDHandleForVariable(varname)
+                       .foreach(new CheckSparsityFunction());
+       }
 
        // This returns RDD with identifier as well as location
        public static String getStartLineFromSparkDebugInfo(String line) {
@@ -288,6 +296,17 @@ public class SparkUtils
                        mo.acquireReadAndRelease();
        }
        
+       private static class CheckSparsityFunction implements 
VoidFunction<Tuple2<MatrixIndexes,MatrixBlock>>
+       {
+               private static final long serialVersionUID = 
4150132775681848807L;
+
+               @Override
+               public void call(Tuple2<MatrixIndexes, MatrixBlock> arg) throws 
Exception {
+                       arg._2.checkNonZeros();
+                       arg._2.checkSparseRows();
+               }
+       }
+       
        private static class AnalyzeCellDataCharacteristics implements 
Function<Tuple2<MatrixIndexes,MatrixCell>, DataCharacteristics>
        {
                private static final long serialVersionUID = 
8899395272683723008L;
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
index 722eeca..608d79c 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
@@ -1153,16 +1153,14 @@ public class LibMatrixAgg
                
                double[] a = in.getDenseBlockValues();
 
-               if(aggVal.isEmpty()){
+               if(aggVal.isEmpty()) {
                        aggVal.allocateDenseBlock();
-                       aggVal.setNonZeros(in.getNonZeros());
                }
                else if(aggVal.isInSparseFormat()){
                        // If for some reason the agg Val is sparse then force 
it to dence,
                        // since the values that are going to be added
                        // will make it dense anyway.
                        aggVal.sparseToDense();
-                       aggVal.setNonZeros(in.getNonZeros()); 
                        if(aggVal.denseBlock == null)
                                aggVal.allocateDenseBlock();
                }
@@ -1171,8 +1169,6 @@ public class LibMatrixAgg
                KahanObject buffer = new KahanObject(0, 0);
                KahanPlus akplus = KahanPlus.getKahanPlusFnObject();
                
-               // Don't include nnz maintenence since this function most 
likely aggregate more than one matrixblock.
-               
                // j is the pointer to column.
                // c is the pointer to correction. 
                for(int j=0, c = n; j<n; j++, c++){
@@ -1182,6 +1178,8 @@ public class LibMatrixAgg
                        t[j] =  buffer._sum;
                        t[c] = buffer._correction;
                }
+               
+               aggVal.recomputeNonZeros();
        }
 
        private static void 
aggregateBinaryMatrixLastRowSparseGeneric(MatrixBlock in, MatrixBlock aggVal) {
@@ -1197,30 +1195,26 @@ public class LibMatrixAgg
                final int m = in.rlen;
                final int rlen = Math.min(a.numRows(), m);
                
-               if(aggVal.isEmpty()){
+               if(aggVal.isEmpty())
                        aggVal.allocateSparseRowsBlock();
-                       aggVal.setNonZeros(in.getNonZeros());
-               }
-
-               for( int i=0; i<rlen-1; i++ )
-               {
-                       if( !a.isEmpty(i) )
-                       {
-                               int apos = a.pos(i);
-                               int alen = a.size(i);
-                               int[] aix = a.indexes(i);
-                               double[] avals = a.values(i);
-                               
-                               for( int j=apos; j<apos+alen; j++ )
-                               {
-                                       int jix = aix[j];
-                                       double corr = in.quickGetValue(m-1, 
jix);
-                                       buffer1._sum        = 
aggVal.quickGetValue(i, jix);
-                                       buffer1._correction = 
aggVal.quickGetValue(m-1, jix);
-                                       akplus.execute(buffer1, avals[j], corr);
-                                       aggVal.quickSetValue(i, jix, 
buffer1._sum);
-                                       aggVal.quickSetValue(m-1, jix, 
buffer1._correction);
-                               }
+               
+               // add to aggVal with implicit nnz maintenance
+               for( int i=0; i<rlen-1; i++ ) {
+                       if( a.isEmpty(i) )
+                               continue;
+                       int apos = a.pos(i);
+                       int alen = a.size(i);
+                       int[] aix = a.indexes(i);
+                       double[] avals = a.values(i);
+                       
+                       for( int j=apos; j<apos+alen; j++ ) {
+                               int jix = aix[j];
+                               double corr = in.quickGetValue(m-1, jix);
+                               buffer1._sum        = aggVal.quickGetValue(i, 
jix);
+                               buffer1._correction = aggVal.quickGetValue(m-1, 
jix);
+                               akplus.execute(buffer1, avals[j], corr);
+                               aggVal.quickSetValue(i, jix, buffer1._sum);
+                               aggVal.quickSetValue(m-1, jix, 
buffer1._correction);
                        }
                }
        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java
index 41ca5d0..dfb79cf 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java
@@ -67,7 +67,8 @@ public class BuiltinDBSCANTest extends AutomatedTestBase
                        String HOME = SCRIPT_DIR + TEST_DIR;
 
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[]{"-nvargs", "X=" + 
input("A"), "Y=" + output("B"), "eps=" + epsDBSCAN, "minPts=" + minPts};
+                       programArgs = new String[]{"-explain","-nvargs",
+                               "X=" + input("A"), "Y=" + output("B"), "eps=" + 
epsDBSCAN, "minPts=" + minPts};
                        fullRScriptName = HOME + TEST_NAME + ".R";
                        rCmd = getRCmd(inputDir(), Double.toString(epsDBSCAN), 
Integer.toString(minPts), expectedDir());
 

Reply via email to