This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new e57dfd0  [SYSTEMDS-2819,2020] Various ctable improvements (rewrites), part II
e57dfd0 is described below

commit e57dfd0e593052bb84ee388ec516bea578b90227
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Jan 31 22:58:16 2021 +0100

    [SYSTEMDS-2819,2020] Various ctable improvements (rewrites), part II
    
    * Fix ctable->rexpand rewrite (failed tests in previous commit)
    
    * Allow the ctable-reshape -> ctable rewrite for Spark in addition to CP
    
    * Fix the ctable estimation of output partitions to account for both
    vectors and matrices, not just vectors (otherwise the nnz upper bound
    could be substantially underestimated)
    
    * Post-processing of Spark ctable outputs to obtain the exact sparsity of
    ultra-sparse matrices that fit comfortably into the local memory budget.
---
 src/main/java/org/apache/sysds/hops/TernaryOp.java             |  4 ++--
 .../hops/rewrite/RewriteAlgebraicSimplificationDynamic.java    |  2 +-
 .../sysds/runtime/instructions/spark/CtableSPInstruction.java  | 10 +++++++++-
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index 2989b43..f8c369c 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -570,7 +570,7 @@ public class TernaryOp extends Hop
        @Override
        public Object clone() throws CloneNotSupportedException 
        {
-               TernaryOp ret = new TernaryOp();        
+               TernaryOp ret = new TernaryOp();
                
                //copy generic attributes
                ret.clone(this, false);
@@ -723,7 +723,7 @@ public class TernaryOp extends Hop
        
        public boolean isCTableReshapeRewriteApplicable(ExecType et, 
Ctable.OperationTypes opType) {
                //early abort if rewrite globally not allowed
-               if( !ALLOW_CTABLE_SEQUENCE_REWRITES || _op!=OpOp3.CTABLE || 
et!=ExecType.CP )
+               if( !ALLOW_CTABLE_SEQUENCE_REWRITES || _op!=OpOp3.CTABLE || 
(et!=ExecType.CP && et!=ExecType.SPARK) )
                        return false;
                
                //1) check for ctable CTABLE_TRANSFORM_SCALAR_WEIGHT
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 77efaf2..519c400 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2643,7 +2643,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
        {
                //pattern: table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, 
max=m, dir=row, ignore=false, cast=true)
                //note: this rewrite supports both left/right sequence 
-               if(    hi instanceof TernaryOp && hi.getInput().size()==5 
//table without weights 
+               if(    hi instanceof TernaryOp && hi.getInput().size()==6 
//table without weights 
                        && 
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(2), 1) ) //i.e., weight of 1
                {
                        Hop first = hi.getInput().get(0);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CtableSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CtableSPInstruction.java
index dd838f7..ec8d637 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CtableSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CtableSPInstruction.java
@@ -24,6 +24,7 @@ import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
 import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.lops.Ctable;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -130,7 +131,7 @@ public class CtableSPInstruction extends ComputationSPInstruction {
                                sec.getScalarInput(input3).getLongValue();
                }
                mcOut.set(dim1, dim2, mc1.getBlocksize());
-               mcOut.setNonZerosBound(mc1.getRows());
+               mcOut.setNonZerosBound(mc1.getLength()); //vector or matrix
                if( !mcOut.dimsKnown() )
                        throw new DMLRuntimeException("Unknown ctable output 
dimensions: "+mcOut);
                
@@ -189,6 +190,13 @@ public class CtableSPInstruction extends ComputationSPInstruction {
                        sec.addLineageRDD(output.getName(), input2.getName());
                if( ctableOp.hasThirdInput() )
                        sec.addLineageRDD(output.getName(), input3.getName());
+               
+               //post-processing to obtain sparsity of ultra-sparse outputs
+               long memUB = OptimizerUtils.estimateSizeExactSparsity(
+                       mcOut.getRows(), mcOut.getCols(), 
mcOut.getNonZerosBound());
+               if( !OptimizerUtils.exceedsCachingThreshold(mcOut.getCols(), 
memUB) //< mem budget
+                       && memUB < 
OptimizerUtils.estimateSizeExactSparsity(mcOut))
+                       sec.getMatrixObject(output).acquireReadAndRelease();
        }
 
        private static class CTableFunction implements 
PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, 
MatrixIndexes, MatrixBlock> 

Reply via email to