This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 14bb8a5  [SYSTEMDS-3066] CLA Spark Decompress
14bb8a5 is described below

commit 14bb8a5ba37fa2a6b81028797f24223a54297fc7
Author: baunsgaard <[email protected]>
AuthorDate: Sat Jul 17 13:05:33 2021 +0200

    [SYSTEMDS-3066] CLA Spark Decompress
    
    This commit adds/fixes Spark decompression.
    Also contained in this commit is the ability to see the compression
    size if the logging level is Trace while compressing with Spark instructions.
---
 .../runtime/instructions/SPInstructionParser.java  |  4 +++
 .../instructions/cp/CompressionCPInstruction.java  | 13 +++++++--
 .../spark/CompressionSPInstruction.java            | 34 ++++++++++++++++++++++
 .../spark/DeCompressionSPInstruction.java          | 12 ++++----
 .../sysds/utils/DMLCompressionStatistics.java      | 23 +++++++++++----
 .../compress/CompressInstructionRewrite.java       |  4 +--
 .../compress/configuration/CompressBase.java       |  4 +--
 .../compress/workload/WorkloadAlgorithmTest.java   | 20 ++++++++++---
 .../workload/SystemDS-config-compress-workload.xml |  3 +-
 .../compress/workload/WorkloadAnalysisMLogReg.dml  |  2 +-
 10 files changed, 93 insertions(+), 26 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 08a6998..d1eb4a7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -59,6 +59,7 @@ import 
org.apache.sysds.runtime.instructions.spark.CpmmSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.CtableSPInstruction;
 import 
org.apache.sysds.runtime.instructions.spark.CumulativeAggregateSPInstruction;
 import 
org.apache.sysds.runtime.instructions.spark.CumulativeOffsetSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.DeCompressionSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.DnnSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.MapmmChainSPInstruction;
@@ -500,6 +501,9 @@ public class SPInstructionParser extends InstructionParser
                        case Compression:
                                return 
CompressionSPInstruction.parseInstruction(str);
 
+                       case DeCompression:
+                               return 
DeCompressionSPInstruction.parseInstruction(str);
+
                        case SpoofFused:
                                return SpoofSPInstruction.parseInstruction(str);
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
index 5ccbc41..b3acc26 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
@@ -19,8 +19,12 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
+import org.apache.sysds.runtime.compress.CompressionStatistics;
 import org.apache.sysds.runtime.compress.SingletonLookupHashMap;
 import org.apache.sysds.runtime.compress.workload.WTreeRoot;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -29,9 +33,11 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
 public class CompressionCPInstruction extends ComputationCPInstruction {
+       private static final Log LOG = 
LogFactory.getLog(CompressionCPInstruction.class.getName());
 
        private final int _singletonLookupID;
 
+
        private CompressionCPInstruction(Operator op, CPOperand in, CPOperand 
out, String opcode, String istr,
                int singletonLookupID) {
                super(CPType.Compression, op, in, null, null, out, opcode, 
istr);
@@ -61,9 +67,12 @@ public class CompressionCPInstruction extends 
ComputationCPInstruction {
 
                WTreeRoot root = (_singletonLookupID != 0) ? (WTreeRoot) 
m.get(_singletonLookupID) : null;
                // Compress the matrix block
-               MatrixBlock out = CompressedMatrixBlockFactory.compress(in, 
OptimizerUtils.getConstrainedNumThreads(-1), root)
-                       .getLeft();
+               Pair<MatrixBlock, CompressionStatistics> compResult = 
CompressedMatrixBlockFactory.compress(in, 
OptimizerUtils.getConstrainedNumThreads(-1), root);
 
+               if(LOG.isTraceEnabled())
+                       LOG.trace(compResult.getRight());
+               MatrixBlock out = compResult.getLeft();
+               
                m.removeKey(_singletonLookupID);
                // Set output and release input
                ec.releaseMatrixInput(input1.getName());
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CompressionSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CompressionSPInstruction.java
index 64809cf..e6b62ee 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CompressionSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CompressionSPInstruction.java
@@ -19,6 +19,10 @@
 
 package org.apache.sysds.runtime.instructions.spark;
 
+import java.util.List;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.function.Function;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
@@ -35,7 +39,10 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
+import scala.Tuple2;
+
 public class CompressionSPInstruction extends UnarySPInstruction {
+       private static final Log LOG = 
LogFactory.getLog(CompressionSPInstruction.class.getName());
 
        private final int _singletonLookupID;
 
@@ -79,6 +86,11 @@ public class CompressionSPInstruction extends 
UnarySPInstruction {
 
                // execute compression
                JavaPairRDD<MatrixIndexes, MatrixBlock> out = 
in.mapValues(mappingFunction);
+               if(LOG.isTraceEnabled()) {
+                       out.checkpoint();
+                       LOG.trace("\nSpark compressed    : " + 
reduceSizes(out.mapValues(new SizeFunction()).collect())
+                               + "\nSpark uncompressed  : " + 
reduceSizes(in.mapValues(new SizeFunction()).collect()));
+               }
 
                // set outputs
                sec.setRDDHandleForVariable(output.getName(), out);
@@ -110,4 +122,26 @@ public class CompressionSPInstruction extends 
UnarySPInstruction {
                                .getLeft();
                }
        }
+
+       public static class SizeFunction implements Function<MatrixBlock, 
Double> {
+               private static final long serialVersionUID = 1L;
+
+               public SizeFunction() {
+
+               }
+
+               @Override
+               public Double call(MatrixBlock arg0) throws Exception {
+                       return (double) arg0.getInMemorySize();
+               }
+       }
+
+       public static String reduceSizes(List<Tuple2<MatrixIndexes, Double>> 
in) {
+               double sum = 0;
+               for(Tuple2<MatrixIndexes, Double> e : in) {
+                       sum += e._2();
+               }
+
+               return "sum: " + sum + " mean: " + (sum / in.size());
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/DeCompressionSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/DeCompressionSPInstruction.java
index bd64775..d002d55 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/DeCompressionSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/DeCompressionSPInstruction.java
@@ -27,10 +27,10 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import 
org.apache.sysds.runtime.instructions.spark.CompressionSPInstruction.CompressionFunction;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.utils.DMLCompressionStatistics;
 
 public class DeCompressionSPInstruction extends UnarySPInstruction {
 
@@ -51,9 +51,10 @@ public class DeCompressionSPInstruction extends 
UnarySPInstruction {
                // get input rdd handle
                JavaPairRDD<MatrixIndexes, MatrixBlock> in = 
sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
 
-               // execute compression
-               JavaPairRDD<MatrixIndexes, MatrixBlock> out = in.mapValues(new 
CompressionFunction());
+               // execute decompression
+               JavaPairRDD<MatrixIndexes, MatrixBlock> out = in.mapValues(new 
DeCompressionFunction());
 
+               DMLCompressionStatistics.addDecompressSparkCount();
                // set outputs
                sec.setRDDHandleForVariable(output.getName(), out);
                sec.addLineageRDD(input1.getName(), output.getName());
@@ -64,11 +65,10 @@ public class DeCompressionSPInstruction extends 
UnarySPInstruction {
 
                @Override
                public MatrixBlock call(MatrixBlock arg0) throws Exception {
-                       if(arg0 instanceof CompressedMatrixBlock){
+                       if(arg0 instanceof CompressedMatrixBlock) 
                                return ((CompressedMatrixBlock) 
arg0).decompress(OptimizerUtils.getConstrainedNumThreads(-1));
-                       }else{
+                       else 
                                return arg0;
-                       }
                }
        }
 }
diff --git a/src/main/java/org/apache/sysds/utils/DMLCompressionStatistics.java 
b/src/main/java/org/apache/sysds/utils/DMLCompressionStatistics.java
index 0f7fda5..92130e8 100644
--- a/src/main/java/org/apache/sysds/utils/DMLCompressionStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/DMLCompressionStatistics.java
@@ -33,6 +33,9 @@ public class DMLCompressionStatistics {
        private static int DecompressMTCount = 0;
        private static double DecompressMT = 0.0;
 
+       private static int DecompressSparkCount = 0;
+       private static int DecompressCacheCount = 0;
+
        public static void reset() {
                Phase0 = 0.0;
                Phase1 = 0.0;
@@ -44,6 +47,8 @@ public class DMLCompressionStatistics {
                DecompressST = 0.0;
                DecompressMTCount = 0;
                DecompressMT = 0.0;
+               DecompressSparkCount = 0;
+               DecompressCacheCount = 0;
        }
 
        public static boolean haveCompressed(){
@@ -85,12 +90,16 @@ public class DMLCompressionStatistics {
                }
        }
 
-       public static int getDecompressionCount() {
-               return DecompressMTCount;
+       public static void addDecompressSparkCount(){
+               DecompressSparkCount++; // Spark counter, not the single-thread one
        }
 
-       public static int getDecompressionSTCount() {
-               return DecompressSTCount;
+       public static void addDecompressCacheCount(){
+               DecompressCacheCount++;
+       }
+
+       public static int getDecompressionCount() {
+               return DecompressMTCount + DecompressSTCount + 
DecompressSparkCount + DecompressCacheCount;
        }
 
        public static void display(StringBuilder sb) {
@@ -102,9 +111,11 @@ public class DMLCompressionStatistics {
                                Phase3 / 1000,
                                Phase4 / 1000,
                                Phase5 / 1000));
-                       sb.append(String.format("Decompression Counts (Single , 
Multi) thread                     :\t%d/%d\n",
+                       sb.append(String.format("Decompression Counts (Single , 
Multi, Spark, Cache) thread       :\t%d/%d/%d/%d\n",
                                DecompressSTCount,
-                               DecompressMTCount));
+                               DecompressMTCount,
+                               DecompressSparkCount,
+                               DecompressCacheCount));
                        sb.append(String.format("Dedicated Decompression Time 
(Single , Multi) thread             :\t%.3f/%.3f\n",
                                DecompressST / 1000,
                                DecompressMT / 1000));
diff --git 
a/src/test/java/org/apache/sysds/test/functions/compress/CompressInstructionRewrite.java
 
b/src/test/java/org/apache/sysds/test/functions/compress/CompressInstructionRewrite.java
index ccc6d79..c005115 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/compress/CompressInstructionRewrite.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/compress/CompressInstructionRewrite.java
@@ -127,9 +127,7 @@ public class CompressInstructionRewrite extends 
AutomatedTestBase {
                        if(LOG.isDebugEnabled())
                                LOG.debug(stdout);
 
-                       int decompressCount = 0;
-                       decompressCount += 
DMLCompressionStatistics.getDecompressionCount();
-                       decompressCount += 
DMLCompressionStatistics.getDecompressionSTCount();
+                       int decompressCount = 
DMLCompressionStatistics.getDecompressionCount();
                        long compressionCount = 
Statistics.getCPHeavyHitterCount("compress");
 
                        Assert.assertEquals(compressionCountsExpected, 
compressionCount);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java
 
b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java
index 07b0441..5e1f3f5 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java
@@ -70,9 +70,7 @@ public abstract class CompressBase extends AutomatedTestBase {
 
                        LOG.debug(runTest(null));
 
-                       int decompressCount = 0;
-                       decompressCount += 
DMLCompressionStatistics.getDecompressionCount();
-                       decompressCount += 
DMLCompressionStatistics.getDecompressionSTCount();
+                       int decompressCount = 
DMLCompressionStatistics.getDecompressionCount();
                        long compressionCount = (instType == ExecType.SPARK) ? 
Statistics
                                .getCPHeavyHitterCount("sp_compress") : 
Statistics.getCPHeavyHitterCount("compress");
                        DMLCompressionStatistics.reset();
diff --git 
a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
 
b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
index e94a4ab..c257a57 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
@@ -55,12 +55,23 @@ public class WorkloadAlgorithmTest extends 
AutomatedTestBase {
                runWorkloadAnalysisTest(TEST_NAME1, ExecMode.HYBRID, 2);
        }
 
+
+       @Test
+       public void testLmSP() {
+               runWorkloadAnalysisTest(TEST_NAME2, ExecMode.SPARK, 2);
+       }
+
        @Test
        public void testLmCP() {
                runWorkloadAnalysisTest(TEST_NAME2, ExecMode.HYBRID, 2);
        }
 
        @Test
+       public void testPCASP() {
+               runWorkloadAnalysisTest(TEST_NAME3, ExecMode.SPARK, 1);
+       }
+
+       @Test
        public void testPCACP() {
                runWorkloadAnalysisTest(TEST_NAME3, ExecMode.HYBRID, 1);
        }
@@ -85,18 +96,19 @@ public class WorkloadAlgorithmTest extends 
AutomatedTestBase {
                        writeInputMatrixWithMTD("y", y, false);
 
                        String ret = runTest(null).toString();
-
                        if(ret.contains("ERROR:"))
                                fail(ret);
 
                        // check various additional expectations
-                       long actualCompressionCount = 
Statistics.getCPHeavyHitterCount("compress");
+                       long actualCompressionCount = mode == ExecMode.HYBRID ? 
Statistics
+                               .getCPHeavyHitterCount("compress") : 
Statistics.getCPHeavyHitterCount("sp_compress");
+
                        Assert.assertEquals(compressionCount, 
actualCompressionCount);
-                       
Assert.assertTrue(heavyHittersContainsString("compress"));
+                       Assert.assertTrue( mode == ExecMode.HYBRID ? 
heavyHittersContainsString("compress") : 
heavyHittersContainsString("sp_compress"));
                        
Assert.assertFalse(heavyHittersContainsString("m_scale"));
 
                }
-               catch(Exception e){
+               catch(Exception e) {
                        resetExecMode(oldPlatform);
                        fail("Failed workload test");
                }
diff --git 
a/src/test/scripts/functions/compress/workload/SystemDS-config-compress-workload.xml
 
b/src/test/scripts/functions/compress/workload/SystemDS-config-compress-workload.xml
index 4e735c6..ed2ab68 100644
--- 
a/src/test/scripts/functions/compress/workload/SystemDS-config-compress-workload.xml
+++ 
b/src/test/scripts/functions/compress/workload/SystemDS-config-compress-workload.xml
@@ -19,6 +19,7 @@
 
 <root>
        <sysds.compressed.linalg>workload</sysds.compressed.linalg>
+       <sysds.defaultblocksize>8000</sysds.defaultblocksize>
        <sysds.cp.parallel.ops>true</sysds.cp.parallel.ops>
        <sysds.scratch>target/force_comp_scratch_space</sysds.scratch>
-</root>
+</root>
\ No newline at end of file
diff --git 
a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml 
b/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
index 78e62c1..77b2959 100644
--- a/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
+++ b/src/test/scripts/functions/compress/workload/WorkloadAnalysisMLogReg.dml
@@ -27,7 +27,7 @@ print("")
 print("MLogReg")
 
 X = scale(X=X, scale=TRUE, center=TRUE);
-B = multiLogReg(X=X, Y=Y, verbose=TRUE, maxi = 10, maxii=10);
+B = multiLogReg(X=X, Y=Y, verbose=FALSE, maxi=3, maxii=2);
 
 [nn, P, acc] = multiLogRegPredict(X=X, B=B, Y=Y)
 

Reply via email to