This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new e57dfd0 [SYSTEMDS-2819,2020] Various ctable improvements (rewrites),
part II
e57dfd0 is described below
commit e57dfd0e593052bb84ee388ec516bea578b90227
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Jan 31 22:58:16 2021 +0100
[SYSTEMDS-2819,2020] Various ctable improvements (rewrites), part II
* Fix ctable->rexpand rewrite (failed tests in previous commit)
* Allow the ctable-reshape -> ctable rewrite for Spark in addition to CP
* Fix the ctable estimation of output partitions to account for both
vectors and matrices, not just vectors (otherwise the nnz upper bound
could be substantially underestimated)
* Post-process Spark ctable outputs to obtain the exact sparsity of
ultra-sparse matrices that fit comfortably into the local memory budget.
---
src/main/java/org/apache/sysds/hops/TernaryOp.java | 4 ++--
.../hops/rewrite/RewriteAlgebraicSimplificationDynamic.java | 2 +-
.../sysds/runtime/instructions/spark/CtableSPInstruction.java | 10 +++++++++-
3 files changed, 12 insertions(+), 4 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java
b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index 2989b43..f8c369c 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -570,7 +570,7 @@ public class TernaryOp extends Hop
@Override
public Object clone() throws CloneNotSupportedException
{
- TernaryOp ret = new TernaryOp();
+ TernaryOp ret = new TernaryOp();
//copy generic attributes
ret.clone(this, false);
@@ -723,7 +723,7 @@ public class TernaryOp extends Hop
public boolean isCTableReshapeRewriteApplicable(ExecType et,
Ctable.OperationTypes opType) {
//early abort if rewrite globally not allowed
- if( !ALLOW_CTABLE_SEQUENCE_REWRITES || _op!=OpOp3.CTABLE ||
et!=ExecType.CP )
+ if( !ALLOW_CTABLE_SEQUENCE_REWRITES || _op!=OpOp3.CTABLE ||
(et!=ExecType.CP && et!=ExecType.SPARK) )
return false;
//1) check for ctable CTABLE_TRANSFORM_SCALAR_WEIGHT
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 77efaf2..519c400 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2643,7 +2643,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
{
//pattern: table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v,
max=m, dir=row, ignore=false, cast=true)
//note: this rewrite supports both left/right sequence
- if( hi instanceof TernaryOp && hi.getInput().size()==5
//table without weights
+ if( hi instanceof TernaryOp && hi.getInput().size()==6
//table without weights
&&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(2), 1) ) //i.e., weight of 1
{
Hop first = hi.getInput().get(0);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CtableSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CtableSPInstruction.java
index dd838f7..ec8d637 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CtableSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CtableSPInstruction.java
@@ -24,6 +24,7 @@ import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Ctable;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -130,7 +131,7 @@ public class CtableSPInstruction extends
ComputationSPInstruction {
sec.getScalarInput(input3).getLongValue();
}
mcOut.set(dim1, dim2, mc1.getBlocksize());
- mcOut.setNonZerosBound(mc1.getRows());
+ mcOut.setNonZerosBound(mc1.getLength()); //vector or matrix
if( !mcOut.dimsKnown() )
throw new DMLRuntimeException("Unknown ctable output
dimensions: "+mcOut);
@@ -189,6 +190,13 @@ public class CtableSPInstruction extends
ComputationSPInstruction {
sec.addLineageRDD(output.getName(), input2.getName());
if( ctableOp.hasThirdInput() )
sec.addLineageRDD(output.getName(), input3.getName());
+
+ //post-processing to obtain sparsity of ultra-sparse outputs
+ long memUB = OptimizerUtils.estimateSizeExactSparsity(
+ mcOut.getRows(), mcOut.getCols(),
mcOut.getNonZerosBound());
+ if( !OptimizerUtils.exceedsCachingThreshold(mcOut.getCols(),
memUB) //< mem budget
+ && memUB <
OptimizerUtils.estimateSizeExactSparsity(mcOut))
+ sec.getMatrixObject(output).acquireReadAndRelease();
}
private static class CTableFunction implements
PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>,
MatrixIndexes, MatrixBlock>