Repository: incubator-systemml
Updated Branches:
  refs/heads/master af93ca8a4 -> af1a8d852
[SYSTEMML-1438] Extended code generator (celltmpl min/max agg, fixes)

This patch adds support for min and max aggregations (full, row-wise) in cell templates, which includes various runtime extensions because these aggregate functions have different properties compared to the existing sum and sum_sq. Furthermore, this also includes additional tests and fixes for (1) unnecessarily compiled classes of subsumed fused operators, and (2) best plan selection in both plan selection heuristics.

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/af1a8d85
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/af1a8d85
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/af1a8d85

Branch: refs/heads/master
Commit: af1a8d8527816a8812c60959332b23f8a04d7133
Parents: af93ca8
Author: Matthias Boehm <[email protected]>
Authored: Mon Mar 27 18:26:21 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Mar 27 20:37:44 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java | 13 +-
 .../hops/codegen/template/PlanSelection.java | 15 ++
 .../codegen/template/PlanSelectionFuseAll.java | 11 +-
 .../template/PlanSelectionFuseNoRedundancy.java | 11 +-
 .../hops/codegen/template/TemplateCell.java | 8 +-
 .../hops/codegen/template/TemplateUtils.java | 2 +-
 .../sysml/hops/rewrite/HopRewriteUtils.java | 10 +
 .../sysml/runtime/codegen/SpoofCellwise.java | 250 +++++++++++++------
 .../instructions/spark/SpoofSPInstruction.java | 33 ++-
 .../sysml/runtime/util/UtilFunctions.java | 14 ++
 .../test/integration/AutomatedTestBase.java | 5 +-
 .../functions/codegen/CellwiseTmplTest.java | 26 +-
 .../scripts/functions/codegen/cellwisetmpl10.R | 34 +++
 .../functions/codegen/cellwisetmpl10.dml | 30 +++
 14 files changed, 356 insertions(+), 106 deletions(-)
---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index 273e790..187a9ca 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -408,9 +408,16 @@ public class SpoofCompiler Statistics.incrementCodegenCPlanCompile(1); } - //process childs recursively - for( Hop c : hop.getInput() ) - rConstructCPlans(c, memo, cplans, compileLiterals); + //process children recursively, but skip compiled operator + if( cplans.containsKey(hop.getHopID()) ) { + for( Hop c : cplans.get(hop.getHopID()).getKey() ) + rConstructCPlans(c, memo, cplans, compileLiterals); + } + else { + for( Hop c : hop.getInput() ) + rConstructCPlans(c, memo, cplans, compileLiterals); + } + hop.setVisited(); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java index 4abe760..e7ae824 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java @@ -21,7 +21,9 @@ package org.apache.sysml.hops.codegen.template; import java.util.ArrayList; import java.util.Comparator; +import java.util.HashMap; import java.util.HashSet; +import java.util.List; import org.apache.sysml.hops.Hop; import 
org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; @@ -31,6 +33,8 @@ import org.apache.sysml.runtime.util.UtilFunctions; public abstract class PlanSelection { + private final HashMap<Long, List<MemoTableEntry>> _bestPlans = + new HashMap<Long, List<MemoTableEntry>>(); private final HashSet<VisitMark> _visited = new HashSet<VisitMark>(); /** @@ -58,6 +62,17 @@ public abstract class PlanSelection || (me.type == TemplateType.CellTpl); } + protected void addBestPlan(long hopID, MemoTableEntry me) { + if( me == null ) return; + if( !_bestPlans.containsKey(hopID) ) + _bestPlans.put(hopID, new ArrayList<MemoTableEntry>()); + _bestPlans.get(hopID).add(me); + } + + protected HashMap<Long, List<MemoTableEntry>> getBestPlans() { + return _bestPlans; + } + public boolean isVisited(long hopID, TemplateType type) { return _visited.contains(new VisitMark(hopID, type)); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java index 5d8c9ce..a455302 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java @@ -21,7 +21,6 @@ package org.apache.sysml.hops.codegen.template; import java.util.ArrayList; import java.util.Comparator; -import java.util.HashMap; import java.util.Map.Entry; import java.util.HashSet; import java.util.List; @@ -39,9 +38,6 @@ import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; */ public class PlanSelectionFuseAll extends PlanSelection { - private HashMap<Long, List<MemoTableEntry>> _bestPlans = - new HashMap<Long, List<MemoTableEntry>>(); - @Override public void 
selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) { //pruning and collection pass @@ -49,7 +45,7 @@ public class PlanSelectionFuseAll extends PlanSelection rSelectPlans(memo, hop, null); //take all distinct best plans - for( Entry<Long, List<MemoTableEntry>> e : _bestPlans.entrySet() ) + for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() ) memo.setDistinct(e.getKey(), e.getValue()); } @@ -78,11 +74,12 @@ public class PlanSelectionFuseAll extends PlanSelection .min(new BasicPlanComparator()).orElse(null); } else { - best = memo.get(current.getHopID()).stream() .filter(p -> p.type==currentType || p.type==TemplateType.CellTpl) - .min(Comparator.comparing(p -> 3-p.countPlanRefs())).orElse(null); + .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs())) + .orElse(null); } + addBestPlan(current.getHopID(), best); } //step 3: recursively process children http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseNoRedundancy.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseNoRedundancy.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseNoRedundancy.java index 71b74b2..fa8b25a 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseNoRedundancy.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseNoRedundancy.java @@ -21,7 +21,6 @@ package org.apache.sysml.hops.codegen.template; import java.util.ArrayList; import java.util.Comparator; -import java.util.HashMap; import java.util.Map.Entry; import java.util.HashSet; import java.util.List; @@ -42,9 +41,6 @@ import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; */ public class PlanSelectionFuseNoRedundancy extends PlanSelection { - private HashMap<Long, List<MemoTableEntry>> _bestPlans = - 
new HashMap<Long, List<MemoTableEntry>>(); - @Override public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) { //pruning and collection pass @@ -52,7 +48,7 @@ public class PlanSelectionFuseNoRedundancy extends PlanSelection rSelectPlans(memo, hop, null); //take all distinct best plans - for( Entry<Long, List<MemoTableEntry>> e : _bestPlans.entrySet() ) + for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() ) memo.setDistinct(e.getKey(), e.getValue()); } @@ -92,11 +88,12 @@ public class PlanSelectionFuseNoRedundancy extends PlanSelection .min(new BasicPlanComparator()).orElse(null); } else { - best = memo.get(current.getHopID()).stream() .filter(p -> p.type==currentType || p.type==TemplateType.CellTpl) - .min(Comparator.comparing(p -> 3-p.countPlanRefs())).orElse(null); + .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs())) + .orElse(null); } + addBestPlan(current.getHopID(), best); } //step 3: recursively process children http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index f86bf69..46e265c 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -31,6 +31,7 @@ import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.UnaryOp; +import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.IndexingOp; @@ -53,6 +54,9 @@ import org.apache.sysml.runtime.matrix.data.Pair; public class TemplateCell extends TemplateBase 
{ + private static final AggOp[] SUPPORTED_AGG = + new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX}; + public TemplateCell() { super(TemplateType.CellTpl); } @@ -66,7 +70,7 @@ public class TemplateCell extends TemplateBase @Override public boolean fuse(Hop hop, Hop input) { return !isClosed() && (isValidOperation(hop) - || ((HopRewriteUtils.isSum(hop)||HopRewriteUtils.isSumSq(hop)) + || (HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_AGG) && ((AggUnaryOp) hop).getDirection()!= Direction.Col) || (HopRewriteUtils.isMatrixMultiply(hop) && hop.getDim1()==1 && hop.getDim2()==1) && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))); @@ -81,7 +85,7 @@ public class TemplateCell extends TemplateBase @Override public CloseType close(Hop hop) { //need to close cell tpl after aggregation, see fuse for exact properties - if( ((HopRewriteUtils.isSum(hop)||HopRewriteUtils.isSumSq(hop)) + if( (HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_AGG) && ((AggUnaryOp) hop).getDirection()!= Direction.Col) || (HopRewriteUtils.isMatrixMultiply(hop) && hop.getDim1()==1 && hop.getDim2()==1) ) return CloseType.CLOSED_VALID; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index dc42eb9..a172af5 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -217,7 +217,7 @@ public class TemplateUtils return OutProdType.CELLWISE_OUTER_PRODUCT; //should never come here - throw new RuntimeException("Undefined outer product type"); + throw new RuntimeException("Undefined outer product type for hop "+out.getHopID()); } public static boolean isLookup(CNode node) { 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 8ea0355..88ee174 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -795,6 +795,16 @@ public class HopRewriteUtils return hop instanceof AggBinaryOp && ((AggBinaryOp)hop).isMatrixMultiply(); } + public static boolean isAggUnaryOp(Hop hop, AggOp...op) { + if( !(hop instanceof AggUnaryOp) ) + return false; + AggOp hopOp = ((AggUnaryOp)hop).getOp(); + for( AggOp opi : op ) + if( hopOp == opi ) + return true; + return false; + } + public static boolean isSum(Hop hop) { return (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp()==AggOp.SUM); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java b/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java index f27564b..e1ccdb4 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java @@ -29,9 +29,12 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.functionobjects.Builtin; +import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysml.runtime.functionobjects.KahanFunction; import org.apache.sysml.runtime.functionobjects.KahanPlus; import org.apache.sysml.runtime.functionobjects.KahanPlusSq; 
+import org.apache.sysml.runtime.functionobjects.ValueFunction; import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.cp.KahanObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; @@ -54,6 +57,8 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl public enum AggOp { SUM, SUM_SQ, + MIN, + MAX, } private final CellType _type; @@ -70,14 +75,20 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl return _type; } + public AggOp getAggOp() { + return _aggOp; + } + public boolean isSparseSafe() { return _sparseSafe; } - private KahanFunction getAggFunction() { + private ValueFunction getAggFunction() { switch( _aggOp ) { case SUM: return KahanPlus.getKahanPlusFnObject(); case SUM_SQ: return KahanPlusSq.getKahanPlusSqFnObject(); + case MIN: return Builtin.getBuiltinFnObject(BuiltinCode.MIN); + case MAX: return Builtin.getBuiltinFnObject(BuiltinCode.MAX); default: throw new RuntimeException("Unsupported " + "aggregation type: "+_aggOp.name()); @@ -137,6 +148,12 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl throw new DMLRuntimeException(ex); } } + + //correction for min/max + if( (_aggOp == AggOp.MIN || _aggOp == AggOp.MAX) && sparseSafe + && inputs.get(0).getNonZeros()<inputs.get(0).getNumRows()*inputs.get(0).getNumColumns() ) + sum = getAggFunction().execute(sum, 0); //unseen 0 might be max or min value + return new DoubleObject(sum); } @@ -210,36 +227,51 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl out.examSparsity(); } - /** - * - * @param a - * @param b - * @param c - * @param n - * @param rl - * @param ru - */ - private double executeDenseAndAgg(double[] a, double[][] b, double[] scalars, int m, int n, boolean sparseSafe, int rl, int ru) + private double executeDenseAndAgg(double[] a, double[][] b, double[] scalars, int m, int n, boolean sparseSafe, int rl, int 
ru) throws DMLRuntimeException { - KahanObject kbuff = new KahanObject(0, 0); - KahanFunction kplus = getAggFunction(); - - if( a == null && !sparseSafe ) { //empty - for( int i=rl; i<ru; i++ ) - for( int j=0; j<n; j++ ) - kplus.execute2(kbuff, genexec( 0, b, scalars, m, n, i, j )); + ValueFunction vfun = getAggFunction(); + double ret = 0; + + //numerically stable aggregation for sum/sum_sq + if( vfun instanceof KahanFunction ) { + KahanObject kbuff = new KahanObject(0, 0); + KahanFunction kplus = (KahanFunction) vfun; + + if( a == null && !sparseSafe ) { //empty + for( int i=rl; i<ru; i++ ) + for( int j=0; j<n; j++ ) + kplus.execute2(kbuff, genexec( 0, b, scalars, m, n, i, j )); + } + else if( a != null ) { //general case + for( int i=rl, ix=rl*n; i<ru; i++ ) + for( int j=0; j<n; j++, ix++ ) + if( a[ix] != 0 || !sparseSafe) + kplus.execute2(kbuff, genexec( a[ix], b, scalars, m, n, i, j )); + } + ret = kbuff._sum; } - else if( a != null ) { //general case - for( int i=rl, ix=rl*n; i<ru; i++ ) - for( int j=0; j<n; j++, ix++ ) - if( a[ix] != 0 || !sparseSafe) - kplus.execute2(kbuff, genexec( a[ix], b, scalars, m, n, i, j )); + //safe aggregation for min/max w/ handling of zero entries + //note: sparse safe with zero value as min/max handled outside + else { + ret = (_aggOp==AggOp.MIN) ? 
Double.MAX_VALUE : -Double.MAX_VALUE; + if( a == null && !sparseSafe ) { //empty + for( int i=rl; i<ru; i++ ) + for( int j=0; j<n; j++ ) + ret = vfun.execute(ret, genexec( 0, b, scalars, m, n, i, j )); + } + else if( a != null ) { //general case + for( int i=rl, ix=rl*n; i<ru; i++ ) + for( int j=0; j<n; j++, ix++ ) + if( a[ix] != 0 || !sparseSafe) + ret = vfun.execute(ret, genexec( a[ix], b, scalars, m, n, i, j )); + } } - return kbuff._sum; + return ret; } private long executeDense(double[] a, double[][] b, double[] scalars, double[] c, int m, int n, boolean sparseSafe, int rl, int ru) + throws DMLRuntimeException { long lnnz = 0; @@ -265,26 +297,50 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl } else if( _type == CellType.ROW_AGG ) { - KahanObject kbuff = new KahanObject(0, 0); - KahanFunction kplus = getAggFunction(); - - if( a == null && !sparseSafe ) { //empty - for( int i=rl; i<ru; i++ ) { - kbuff.set(0, 0); - for( int j=0; j<n; j++ ) - kplus.execute2(kbuff, genexec( 0, b, scalars, m, n, i, j )); - c[i] = kbuff._sum; - lnnz += (c[i]!=0) ? 1 : 0; + ValueFunction vfun = getAggFunction(); + + if( vfun instanceof KahanFunction ) { + KahanObject kbuff = new KahanObject(0, 0); + KahanFunction kplus = (KahanFunction) vfun; + + if( a == null && !sparseSafe ) { //empty + for( int i=rl; i<ru; i++ ) { + kbuff.set(0, 0); + for( int j=0; j<n; j++ ) + kplus.execute2(kbuff, genexec( 0, b, scalars, m, n, i, j )); + lnnz += ((c[i] = kbuff._sum)!=0) ? 1 : 0; + } + } + else if( a != null ) { //general case + for( int i=rl, ix=rl*n; i<ru; i++ ) { + kbuff.set(0, 0); + for( int j=0; j<n; j++, ix++ ) + if( a[ix] != 0 || !sparseSafe) + kplus.execute2(kbuff, genexec( a[ix], b, scalars, m, n, i, j )); + lnnz += ((c[i] = kbuff._sum)!=0) ? 
1 : 0; + } } } - else if( a != null ) { //general case - for( int i=rl, ix=rl*n; i<ru; i++ ) { - kbuff.set(0, 0); - for( int j=0; j<n; j++, ix++ ) - if( a[ix] != 0 || !sparseSafe) - kplus.execute2(kbuff, genexec( a[ix], b, scalars, m, n, i, j )); - c[i] = kbuff._sum; - lnnz += (c[i]!=0) ? 1 : 0; + else { + double initialVal = (_aggOp==AggOp.MIN) ? Double.MAX_VALUE : -Double.MAX_VALUE; + if( a == null && !sparseSafe ) { //empty + for( int i=rl; i<ru; i++ ) { + double tmp = initialVal; + for( int j=0; j<n; j++ ) + tmp = vfun.execute(tmp, genexec( 0, b, scalars, m, n, i, j )); + lnnz += ((c[i] = tmp)!=0) ? 1 : 0; + } + } + else if( a != null ) { //general case + for( int i=rl, ix=rl*n; i<ru; i++ ) { + double tmp = initialVal; + for( int j=0; j<n; j++, ix++ ) + if( a[ix] != 0 || !sparseSafe) + tmp = vfun.execute(tmp, genexec( a[ix], b, scalars, m, n, i, j )); + if( sparseSafe && UtilFunctions.containsZero(a, ix-n, n) ) + tmp = vfun.execute(tmp, 0); + lnnz += ((c[i] = tmp)!=0) ? 1 : 0; + } } } } @@ -293,35 +349,63 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl } private double executeSparseAndAgg(SparseBlock sblock, double[][] b, double[] scalars, int m, int n, boolean sparseSafe, int rl, int ru) + throws DMLRuntimeException { - KahanObject kbuff = new KahanObject(0, 0); - KahanFunction kplus = getAggFunction(); + ValueFunction vfun = getAggFunction(); + double ret = 0; - if( sparseSafe ) { - if( sblock != null ) { + //numerically stable aggregation for sum/sum_sq + if( vfun instanceof KahanFunction ) { + KahanObject kbuff = new KahanObject(0, 0); + KahanFunction kplus = (KahanFunction) vfun; + + if( !sparseSafe ) { + for(int i=rl; i<ru; i++) + for(int j=0; j<n; j++) { + double valij = (sblock != null) ? 
sblock.get(i, j) : 0; + kplus.execute2( kbuff, genexec(valij, b, scalars, m, n, i, j)); + } + } + else if( sblock != null ) { for( int i=rl; i<ru; i++ ) if( !sblock.isEmpty(i) ) { int apos = sblock.pos(i); int alen = sblock.size(i); double[] avals = sblock.values(i); - for( int j=apos; j<apos+alen; j++ ) { + for( int j=apos; j<apos+alen; j++ ) kplus.execute2( kbuff, genexec(avals[j], b, scalars, m, n, i, j)); - } } } + ret = kbuff._sum; } - else { //sparse-unsafe - for(int i=rl; i<ru; i++) - for(int j=0; j<n; j++) { - double valij = (sblock != null) ? sblock.get(i, j) : 0; - kplus.execute2( kbuff, genexec(valij, b, scalars, m, n, i, j)); - } - } + //safe aggregation for min/max w/ handling of zero entries + //note: sparse safe with zero value as min/max handled outside + else { + ret = (_aggOp==AggOp.MIN) ? Double.MAX_VALUE : -Double.MAX_VALUE; + if( !sparseSafe ) { + for(int i=rl; i<ru; i++) + for(int j=0; j<n; j++) { + double valij = (sblock != null) ? sblock.get(i, j) : 0; + ret = vfun.execute( ret, genexec(valij, b, scalars, m, n, i, j)); + } + } + else if( sblock != null ) { + for( int i=rl; i<ru; i++ ) + if( !sblock.isEmpty(i) ) { + int apos = sblock.pos(i); + int alen = sblock.size(i); + double[] avals = sblock.values(i); + for( int j=apos; j<apos+alen; j++ ) + ret = vfun.execute( ret, genexec(avals[j], b, scalars, m, n, i, j)); + } + } + } - return kbuff._sum; + return ret; } private long executeSparse(SparseBlock sblock, double[][] b, double[] scalars, double[] c, int m, int n, boolean sparseSafe, int rl, int ru) + throws DMLRuntimeException { long lnnz = 0; if( _type == CellType.NO_AGG ) @@ -352,35 +436,57 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl } else if( _type == CellType.ROW_AGG ) { - KahanObject kbuff = new KahanObject(0, 0); - KahanFunction kplus = getAggFunction(); + ValueFunction vfun = getAggFunction(); - if( sparseSafe ) { - if( sblock != null ) { + if( vfun instanceof KahanFunction ) { + KahanObject 
kbuff = new KahanObject(0, 0); + KahanFunction kplus = (KahanFunction) vfun; + + if( !sparseSafe ) { + for(int i=rl; i<ru; i++) { + kbuff.set(0, 0); + for(int j=0; j<n; j++) + kplus.execute2( kbuff, genexec( (sblock != null) ? + sblock.get(i, j) : 0, b, scalars, m, n, i, j)); + lnnz += ((c[i] = kbuff._sum)!=0) ? 1 : 0; + } + } + else if( sblock != null ) { //general case for( int i=rl; i<ru; i++ ) { if( sblock.isEmpty(i) ) continue; kbuff.set(0, 0); int apos = sblock.pos(i); int alen = sblock.size(i); double[] avals = sblock.values(i); - for( int j=apos; j<apos+alen; j++ ) { + for( int j=apos; j<apos+alen; j++ ) kplus.execute2(kbuff, genexec(avals[j], b, scalars, m, n, i, j)); - } - c[i] = kbuff._sum; - lnnz += (c[i]!=0) ? 1 : 0; + lnnz += ((c[i] = kbuff._sum)!=0) ? 1 : 0; } } } - else { //sparse-unsafe - for(int i=rl; i<ru; i++) { - kbuff.set(0, 0); - for(int j=0; j<n; j++) { - double valij = (sblock != null) ? sblock.get(i, j) : 0; - kplus.execute2( kbuff, genexec(valij, b, scalars, m, n, i, j)); + else { + double initialVal = (_aggOp==AggOp.MIN) ? Double.MAX_VALUE : -Double.MAX_VALUE; + if( !sparseSafe ) { + for(int i=rl; i<ru; i++) { + double tmp = initialVal; + for(int j=0; j<n; j++) + tmp = vfun.execute( tmp, genexec( (sblock != null) ? + sblock.get(i, j) : 0, b, scalars, m, n, i, j)); + lnnz += ((c[i] = tmp)!=0) ? 1 : 0; } - c[i] = kbuff._sum; - lnnz += (c[i]!=0) ? 1 : 0; } + else if( sblock != null ) { //general case + for( int i=rl; i<ru; i++ ) { + if( sblock.isEmpty(i) ) continue; + int apos = sblock.pos(i); + int alen = sblock.size(i); + double[] avals = sblock.values(i); + double tmp = (alen < n) ? 0 : initialVal; + for( int j=apos; j<apos+alen; j++ ) + tmp = vfun.execute(tmp, genexec(avals[j], b, scalars, m, n, i, j)); + lnnz += ((c[i] = tmp)!=0) ? 
1 : 0; + } + } } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java index ca92b9b..a5acfd9 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java @@ -26,10 +26,12 @@ import java.util.List; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; +import org.apache.sysml.lops.PartialAggregate.CorrectionLocationType; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.codegen.CodegenUtils; import org.apache.sysml.runtime.codegen.SpoofCellwise; +import org.apache.sysml.runtime.codegen.SpoofCellwise.AggOp; import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; import org.apache.sysml.runtime.codegen.SpoofOperator; import org.apache.sysml.runtime.codegen.SpoofOuterProduct; @@ -37,6 +39,9 @@ import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType; import org.apache.sysml.runtime.codegen.SpoofRowAggregate; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.functionobjects.Builtin; +import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode; +import org.apache.sysml.runtime.functionobjects.KahanPlus; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; import 
org.apache.sysml.runtime.instructions.cp.DoubleObject; @@ -47,6 +52,7 @@ import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.matrix.operators.AggregateOperator; import scala.Tuple2; @@ -94,7 +100,7 @@ public class SpoofSPInstruction extends SPInstruction //get input rdd and variable name ArrayList<String> bcVars = new ArrayList<String>(); MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(_in[0].getName()); - JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable( _in[0].getName() ); + JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable( _in[0].getName() ); JavaPairRDD<MatrixIndexes, MatrixBlock> out = null; //simple case: map-side only operation (one rdd input, broadcast all) @@ -115,17 +121,17 @@ public class SpoofSPInstruction extends SPInstruction //initialize Spark Operator if(_class.getSuperclass() == SpoofCellwise.class) // cellwise operator { + SpoofCellwise op = (SpoofCellwise) CodegenUtils.createInstance(_class); + AggregateOperator aggop = getAggregateOperator(op.getAggOp()); + if( _out.getDataType()==DataType.MATRIX ) { - SpoofOperator op = (SpoofOperator) CodegenUtils.createInstance(_class); - out = in.mapPartitionsToPair(new CellwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars), true); - if( ((SpoofCellwise)op).getCellType()==CellType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock() ) { - //NOTE: workaround with partition size needed due to potential bug in SPARK + if( op.getCellType()==CellType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock() ) { //TODO investigate if some other side effect of correct blocks if( out.partitions().size() > mcIn.getNumRowBlocks() ) - out = RDDAggregateUtils.sumByKeyStable(out, (int)mcIn.getNumRowBlocks(), 
false); + out = RDDAggregateUtils.aggByKeyStable(out, aggop, (int)mcIn.getNumRowBlocks(), false); else - out = RDDAggregateUtils.sumByKeyStable(out, false); + out = RDDAggregateUtils.aggByKeyStable(out, aggop, false); } sec.setRDDHandleForVariable(_out.getName(), out); @@ -139,7 +145,7 @@ public class SpoofSPInstruction extends SPInstruction } else { //SCALAR out = in.mapPartitionsToPair(new CellwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars), true); - MatrixBlock tmpMB = RDDAggregateUtils.sumStable(out); + MatrixBlock tmpMB = RDDAggregateUtils.aggStable(out, aggop); sec.setVariable(_out.getName(), new DoubleObject(tmpMB.getValue(0, 0))); } } @@ -155,7 +161,6 @@ public class SpoofSPInstruction extends SPInstruction out = in.mapPartitionsToPair(new OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true); if(type == OutProdType.LEFT_OUTER_PRODUCT || type == OutProdType.RIGHT_OUTER_PRODUCT ) { - //NOTE: workaround with partition size needed due to potential bug in SPARK //TODO investigate if some other side effect of correct blocks if( in.partitions().size() > mcOut.getNumRowBlocks()*mcOut.getNumColBlocks() ) out = RDDAggregateUtils.sumByKeyStable(out, (int)(mcOut.getNumRowBlocks()*mcOut.getNumColBlocks()), false); @@ -408,4 +413,14 @@ public class SpoofSPInstruction extends SPInstruction return in; } } + + public static AggregateOperator getAggregateOperator(AggOp aggop) { + if( aggop == AggOp.SUM || aggop == AggOp.SUM_SQ ) + return new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.NONE); + else if( aggop == AggOp.MIN ) + return new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject(BuiltinCode.MIN), false, CorrectionLocationType.NONE); + else if( aggop == AggOp.MAX ) + return new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject(BuiltinCode.MAX), false, CorrectionLocationType.NONE); + return null; + } } 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java index e17eb14..c15ace1 100644 --- a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java +++ b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java @@ -520,4 +520,18 @@ public class UtilFunctions public static ValueType[] copyOf(ValueType[] schema1, ValueType[] schema2) { return (ValueType[]) ArrayUtils.addAll(schema1, schema2); } + + public static int countNonZeros(double[] data, int pos, int len) { + int ret = 0; + for( int i=pos; i<pos+len; i++ ) + ret += (data[i] != 0) ? 1 : 0; + return ret; + } + + public static boolean containsZero(double[] data, int pos, int len) { + for( int i=pos; i<pos+len; i++ ) + if( data[i] == 0 ) + return true; + return false; + } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java index ffb080b..318a1c8 100644 --- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java @@ -1781,9 +1781,10 @@ public abstract class AutomatedTestBase return writeInputFrame(name, data, false, schema, oi); } - protected boolean heavyHittersContainsSubString(String str) { + protected boolean heavyHittersContainsSubString(String... 
str) { for( String opcode : Statistics.getCPHeavyHitterOpCodes()) - if(opcode.contains(str)) + for( String s : str ) + if(opcode.contains(s)) return true; return false; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java index 3be9da8..066b761 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java @@ -44,7 +44,8 @@ public class CellwiseTmplTest extends AutomatedTestBase private static final String TEST_NAME6 = TEST_NAME+6; private static final String TEST_NAME7 = TEST_NAME+7; private static final String TEST_NAME8 = TEST_NAME+8; - private static final String TEST_NAME9 = TEST_NAME+9; //sum((X + 7 * Y)^2) + private static final String TEST_NAME9 = TEST_NAME+9; //sum((X + 7 * Y)^2) + private static final String TEST_NAME10 = TEST_NAME+10; //min/max(X + 7 * Y) private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + CellwiseTmplTest.class.getSimpleName() + "/"; @@ -57,7 +58,7 @@ public class CellwiseTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for( int i=1; i<=9; i++ ) { + for( int i=1; i<=10; i++ ) { addTestConfiguration( TEST_NAME+i, new TestConfiguration( TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) ); } @@ -108,6 +109,11 @@ public class CellwiseTmplTest extends AutomatedTestBase public void testCodegenCellwiseRewrite9() { testCodegenIntegration( TEST_NAME9, true, ExecType.CP ); } + + @Test + public void 
testCodegenCellwiseRewrite10() { + testCodegenIntegration( TEST_NAME10, true, ExecType.CP ); + } @Test public void testCodegenCellwise1() { @@ -154,6 +160,11 @@ public class CellwiseTmplTest extends AutomatedTestBase public void testCodegenCellwise9() { testCodegenIntegration( TEST_NAME9, false, ExecType.CP ); } + + @Test + public void testCodegenCellwise10() { + testCodegenIntegration( TEST_NAME10, false, ExecType.CP ); + } @Test public void testCodegenCellwiseRewrite1_sp() { @@ -175,6 +186,11 @@ public class CellwiseTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME9, true, ExecType.SPARK ); } + @Test + public void testCodegenCellwiseRewrite10_sp() { + testCodegenIntegration( TEST_NAME10, true, ExecType.SPARK ); + } + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { @@ -210,7 +226,8 @@ public class CellwiseTmplTest extends AutomatedTestBase runTest(true, false, null, -1); runRScript(true); - if(testname.equals(TEST_NAME6) || testname.equals(TEST_NAME7) || testname.equals(TEST_NAME9) ) { + if(testname.equals(TEST_NAME6) || testname.equals(TEST_NAME7) + || testname.equals(TEST_NAME9) || testname.equals(TEST_NAME10) ) { //compare scalars HashMap<CellIndex, Double> dmlfile = readDMLScalarFromHDFS("S"); HashMap<CellIndex, Double> rfile = readRScalarFromFS("S"); @@ -228,6 +245,9 @@ public class CellwiseTmplTest extends AutomatedTestBase || heavyHittersContainsSubString("sp_spoofCell")); if( testname.equals(TEST_NAME7) ) //ensure matrix mult is fused Assert.assertTrue(!heavyHittersContainsSubString("tsmm")); + else if( testname.equals(TEST_NAME10) ) //ensure min/max is fused + Assert.assertTrue(!heavyHittersContainsSubString("uamin","uamax")); + } finally { OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrites; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/test/scripts/functions/codegen/cellwisetmpl10.R ---------------------------------------------------------------------- 
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl10.R b/src/test/scripts/functions/codegen/cellwisetmpl10.R new file mode 100644 index 0000000..e7ba81c --- /dev/null +++ b/src/test/scripts/functions/codegen/cellwisetmpl10.R @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(seq(7, 1006), 500, 2); +Y = matrix(seq(6, 1005), 500, 2); + +Z = X + -7 * Y; +R1 = min(Z); +R2 = max(Z); +R = R1 + R2; + +write(R, paste(args[2],"S",sep="")) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/af1a8d85/src/test/scripts/functions/codegen/cellwisetmpl10.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl10.dml b/src/test/scripts/functions/codegen/cellwisetmpl10.dml new file mode 100644 index 0000000..225dd4c --- /dev/null +++ b/src/test/scripts/functions/codegen/cellwisetmpl10.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(7, 1006), 500, 2); +Y = matrix(seq(6, 1005), 500, 2); + +Z = X + -7 * Y; +R1 = min(Z); +R2 = max(Z); +R = R1 + R2; + +write(R, $1)
