[SYSTEMML-1288] New multi-aggregate codegen template (compiler/runtime) This patch introduces a new multi-aggregate codegen template that allows fusing multiple full aggregation roots over shared inputs or common subexpressions into a single-pass operator.
Furthermore, this also includes a fix of the plan selection cost model and enables the cost-based plan selector by default. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/174bf7db Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/174bf7db Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/174bf7db Branch: refs/heads/master Commit: 174bf7db2158481ba054eb3f41b83738eb5f70d3 Parents: 5937df9 Author: Matthias Boehm <[email protected]> Authored: Thu Apr 6 01:13:35 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu Apr 6 16:01:20 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 42 ++- .../sysml/hops/codegen/cplan/CNodeMultiAgg.java | 190 ++++++++++++++ .../sysml/hops/codegen/cplan/CNodeTpl.java | 35 ++- .../hops/codegen/template/CPlanMemoTable.java | 6 +- .../hops/codegen/template/PlanSelection.java | 3 +- .../template/PlanSelectionFuseCostBased.java | 78 +++++- .../hops/codegen/template/TemplateBase.java | 4 +- .../hops/codegen/template/TemplateCell.java | 9 +- .../hops/codegen/template/TemplateMultiAgg.java | 109 ++++++++ .../hops/codegen/template/TemplateUtils.java | 2 + .../sysml/hops/rewrite/HopRewriteUtils.java | 11 + .../runtime/codegen/SpoofMultiAggregate.java | 257 +++++++++++++++++++ .../instructions/spark/SpoofSPInstruction.java | 97 ++++++- .../java/org/apache/sysml/utils/Explain.java | 11 +- .../functions/codegen/AlgorithmLinregCG.java | 2 +- .../functions/codegen/CellwiseTmplTest.java | 4 +- .../functions/codegen/MultiAggTmplTest.java | 143 +++++++++++ .../functions/codegen/multiAggPattern1.R | 32 +++ .../functions/codegen/multiAggPattern1.dml | 28 ++ .../functions/codegen/multiAggPattern2.R | 32 +++ .../functions/codegen/multiAggPattern2.dml | 28 ++ .../functions/codegen/ZPackageSuite.java | 1 + 22 files changed, 1102 
insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index 179a06a..2e60732 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -36,6 +36,7 @@ import org.apache.sysml.api.DMLScript; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeCell; import org.apache.sysml.hops.codegen.cplan.CNodeData; +import org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg; import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct; import org.apache.sysml.hops.codegen.cplan.CNodeTernary; import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType; @@ -73,6 +74,7 @@ import org.apache.sysml.parser.LanguageException; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.parser.WhileStatement; import org.apache.sysml.parser.WhileStatementBlock; +import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.codegen.CodegenUtils; import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; @@ -91,7 +93,7 @@ public class SpoofCompiler public static final boolean PRUNE_REDUNDANT_PLANS = true; public static PlanCachePolicy PLAN_CACHE_POLICY = PlanCachePolicy.CSLH; public static final int PLAN_CACHE_SIZE = 1024; //max 1K classes - public static final PlanSelector PLAN_SEL_POLICY = PlanSelector.FUSE_ALL; + public static final PlanSelector PLAN_SEL_POLICY = PlanSelector.FUSE_COST_BASED; public enum CompilerType { JAVAC, @@ -449,7 +451,7 @@ public class SpoofCompiler 
throws DMLException { //top-down memoization of processed dag nodes - if( hop.isVisited() ) + if( hop == null || hop.isVisited() ) return; //generate cplan for existing memo table entry @@ -518,6 +520,17 @@ public class SpoofCompiler hop.getRowsInBlock(), hop.getColsInBlock(), hop.getNnz()); if(tmpCNode instanceof CNodeOuterProduct && ((CNodeOuterProduct)tmpCNode).isTransposeOutput() ) hnew = HopRewriteUtils.createTranspose(hnew); + else if( tmpCNode instanceof CNodeMultiAgg ) { + ArrayList<Hop> roots = ((CNodeMultiAgg)tmpCNode).getRootNodes(); + hnew.setDataType(DataType.MATRIX); + HopRewriteUtils.setOutputParameters(hnew, 1, roots.size(), + inHops[0].getRowsInBlock(), inHops[0].getColsInBlock(), -1); + //inject artificial right indexing operations for all parents of all nodes + for( int i=0; i<roots.size(); i++ ) { + Hop hnewi = HopRewriteUtils.createScalarIndexing(hnew, 1, i+1); + HopRewriteUtils.rewireAllParentChildReferences(roots.get(i), hnewi); + } + } else if( tmpCNode instanceof CNodeCell && ((CNodeCell)tmpCNode).requiredCastDtm() ) { HopRewriteUtils.setOutputParametersForScalar(hnew); hnew = HopRewriteUtils.createUnary(hnew, OpOp1.CAST_AS_MATRIX); @@ -550,7 +563,11 @@ public class SpoofCompiler //collect cplan leaf node names HashSet<Long> leafs = new HashSet<Long>(); - rCollectLeafIDs(tpl.getOutput(), leafs); + if( tpl instanceof CNodeMultiAgg ) + for( CNode out : ((CNodeMultiAgg)tpl).getOutputs() ) + rCollectLeafIDs(out, leafs); + else + rCollectLeafIDs(tpl.getOutput(), leafs); //create clean cplan w/ minimal inputs if( inHops.length == leafs.size() ) @@ -571,12 +588,29 @@ public class SpoofCompiler CNodeData in1 = (CNodeData)tpl.getInput().get(0); rFindAndRemoveLookup(tpl.getOutput(), in1); } + else if( tpl instanceof CNodeMultiAgg ) { + CNodeData in1 = (CNodeData)tpl.getInput().get(0); + for( CNode output : ((CNodeMultiAgg)tpl).getOutputs() ) + rFindAndRemoveLookup(output, in1); + } //remove invalid plans with column indexing on main input if( tpl 
instanceof CNodeCell ) { CNodeData in1 = (CNodeData)tpl.getInput().get(0); - if( rHasLookupRC1(tpl.getOutput(), in1) ) + if( rHasLookupRC1(tpl.getOutput(), in1) ) { cplans2.remove(e.getKey()); + if( LOG.isTraceEnabled() ) + LOG.trace("Removed cplan due to invalid rc1 indexing on main input."); + } + } + else if( tpl instanceof CNodeMultiAgg ) { + CNodeData in1 = (CNodeData)tpl.getInput().get(0); + for( CNode output : ((CNodeMultiAgg)tpl).getOutputs() ) + if( rHasLookupRC1(output, in1) ) { + cplans2.remove(e.getKey()); + if( LOG.isTraceEnabled() ) + LOG.trace("Removed cplan due to invalid rc1 indexing on main input."); + } } //remove cplan w/ single op and w/o agg http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java new file mode 100644 index 0000000..7ec07a6 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.codegen.cplan; + +import java.util.ArrayList; +import java.util.Arrays; + +import org.apache.commons.collections.CollectionUtils; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.AggOp; +import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; + +public class CNodeMultiAgg extends CNodeTpl +{ + private static final String TEMPLATE = + "package codegen;\n" + + "import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;\n" + + "import org.apache.sysml.runtime.codegen.SpoofMultiAggregate;\n" + + "import org.apache.sysml.runtime.codegen.SpoofCellwise;\n" + + "import org.apache.sysml.runtime.codegen.SpoofCellwise.AggOp;\n" + + "\n" + + "public final class %TMP% extends SpoofMultiAggregate { \n" + + " public %TMP%() {\n" + + " super(%AGG_OP%);\n" + + " }\n" + + " protected void genexec( double a, double[][] b, double[] scalars, double[] c, " + + "int m, int n, int rowIndex, int colIndex) { \n" + + "%BODY_dense%" + + " }\n" + + "}\n"; + private static final String TEMPLATE_OUT_SUM = " c[%IX%] += %IN%;\n"; + private static final String TEMPLATE_OUT_SUMSQ = " c[%IX%] += %IN% * %IN%;\n"; + private static final String TEMPLATE_OUT_MIN = " c[%IX%] = Math.min(c[%IX%], %IN%);\n"; + private static final String TEMPLATE_OUT_MAX = " c[%IX%] = Math.max(c[%IX%], %IN%);\n"; + + private ArrayList<CNode> _outputs = null; + private ArrayList<AggOp> _aggOps = null; + private ArrayList<Hop> _roots = null; + + public CNodeMultiAgg(ArrayList<CNode> inputs, ArrayList<CNode> outputs) { + super(inputs, null); + _outputs = outputs; + } + + public ArrayList<CNode> getOutputs() { + return _outputs; + } + + @Override + public void resetVisitStatusOutputs() { + for( CNode output : _outputs ) + output.resetVisitStatus(); + } + + public void setAggOps(ArrayList<AggOp> aggOps) { + _aggOps = aggOps; + _hash = 0; + } + + public 
ArrayList<AggOp> getAggOps() { + return _aggOps; + } + + public void setRootNodes(ArrayList<Hop> roots) { + _roots = roots; + } + + public ArrayList<Hop> getRootNodes() { + return _roots; + } + + @Override + public String codegen(boolean sparse) { + // note: ignore sparse flag, generate both + String tmp = TEMPLATE; + + //rename inputs + rReplaceDataNode(_outputs, _inputs.get(0), "a"); // input matrix + renameInputs(_outputs, _inputs, 1); + + //generate dense/sparse bodies + StringBuilder sb = new StringBuilder(); + for( CNode out : _outputs ) + sb.append(out.codegen(false)); + for( CNode out : _outputs ) + out.resetGenerated(); + + //append output assignments + for( int i=0; i<_outputs.size(); i++ ) { + CNode out = _outputs.get(i); + String tmpOut = getAggTemplate(i); + tmpOut = tmpOut.replace("%IN%", out.getVarname()); + tmpOut = tmpOut.replace("%IX%", String.valueOf(i)); + sb.append(tmpOut); + } + + //replace class name and body + tmp = tmp.replaceAll("%TMP%", createVarname()); + tmp = tmp.replaceAll("%BODY_dense%", sb.toString()); + + //replace meta data information + String aggList = ""; + for( AggOp aggOp : _aggOps ) { + aggList += !aggList.isEmpty() ? 
"," : ""; + aggList += "AggOp."+aggOp.name(); + } + tmp = tmp.replaceAll("%AGG_OP%", aggList); + + return tmp; + } + + @Override + public void setOutputDims() { + + } + + @Override + public SpoofOutputDimsType getOutputDimType() { + return SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector + } + + @Override + public CNodeTpl clone() { + CNodeMultiAgg ret = new CNodeMultiAgg(_inputs, _outputs); + ret.setAggOps(getAggOps()); + return ret; + } + + @Override + public int hashCode() { + if( _hash == 0 ) { + int[] tmp = new int[2*_outputs.size()+1]; + tmp[0] = super.hashCode(); + for( int i=0; i<_outputs.size(); i++ ) { + tmp[1+2*i] = _outputs.get(i).hashCode(); + tmp[1+2*i+1] = _aggOps.get(i).hashCode(); + } + _hash = Arrays.hashCode(tmp); + } + return _hash; + } + + @Override + public boolean equals(Object o) { + if(!(o instanceof CNodeMultiAgg)) + return false; + CNodeMultiAgg that = (CNodeMultiAgg)o; + return super.equals(o) + && CollectionUtils.isEqualCollection(_aggOps, that._aggOps) + && equalInputReferences( + _outputs, that._outputs, _inputs, that._inputs); + } + + @Override + public String getTemplateInfo() { + StringBuilder sb = new StringBuilder(); + sb.append("SPOOF MULTIAGG [aggOps="); + sb.append(Arrays.toString(_aggOps.toArray(new AggOp[0]))); + sb.append("]"); + return sb.toString(); + } + + private String getAggTemplate(int pos) { + switch( _aggOps.get(pos) ) { + case SUM: return TEMPLATE_OUT_SUM; + case SUM_SQ: return TEMPLATE_OUT_SUMSQ; + case MIN: return TEMPLATE_OUT_MIN; + case MAX: return TEMPLATE_OUT_MAX; + default: + return null; + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java index 7d0ae8d..e6da944 100644 --- 
a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java @@ -20,8 +20,10 @@ package org.apache.sysml.hops.codegen.cplan; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; @@ -62,6 +64,10 @@ public abstract class CNodeTpl extends CNode implements Cloneable return ret; } + public void resetVisitStatusOutputs() { + getOutput().resetVisitStatus(); + } + public String codegen() { return codegen(false); } @@ -73,6 +79,10 @@ public abstract class CNodeTpl extends CNode implements Cloneable public abstract String getTemplateInfo(); protected void renameInputs(ArrayList<CNode> inputs, int startIndex) { + renameInputs(Collections.singletonList(_output), inputs, startIndex); + } + + protected void renameInputs(List<CNode> outputs, ArrayList<CNode> inputs, int startIndex) { //create map of hopID to data nodes with new names, used for CSE HashMap<Long, CNode> nodes = new HashMap<Long, CNode>(); for(int i=startIndex, sPos=0, mPos=0; i < inputs.size(); i++) { @@ -87,7 +97,8 @@ public abstract class CNodeTpl extends CNode implements Cloneable } //single pass to replace all names - rReplaceDataNode(_output, nodes, new HashMap<Long, CNode>()); + for( CNode output : outputs ) + rReplaceDataNode(output, nodes, new HashMap<Long, CNode>()); } protected void rReplaceDataNode( CNode root, CNode input, String newName ) { @@ -102,6 +113,19 @@ public abstract class CNodeTpl extends CNode implements Cloneable rReplaceDataNode(root, names, new HashMap<Long,CNode>()); } + protected void rReplaceDataNode( ArrayList<CNode> roots, CNode input, String newName ) { + if( !(input instanceof CNodeData) ) + return; + + //create temporary name mapping + HashMap<Long, CNode> names = new HashMap<Long, CNode>(); 
+ CNodeData tmp = (CNodeData)input; + names.put(tmp.getHopID(), new CNodeData(tmp, newName)); + + for( CNode root : roots ) + rReplaceDataNode(root, names, new HashMap<Long,CNode>()); + } + /** * Recursively searches for data nodes and replaces them if found. * @@ -216,7 +240,7 @@ public abstract class CNodeTpl extends CNode implements Cloneable } protected static boolean equalInputReferences(CNode current1, CNode current2, ArrayList<CNode> input1, ArrayList<CNode> input2) { - boolean ret = true; + boolean ret = (input1.size() == input2.size()); //process childs recursively for( int i=0; ret && i<current1.getInput().size(); i++ ) @@ -231,6 +255,13 @@ public abstract class CNodeTpl extends CNode implements Cloneable return ret; } + protected static boolean equalInputReferences(ArrayList<CNode> current1, ArrayList<CNode> current2, ArrayList<CNode> input1, ArrayList<CNode> input2) { + boolean ret = (current1.size() == current2.size()); + for( int i=0; ret && i<current1.size(); i++ ) + ret &= equalInputReferences(current1.get(i), current2.get(i), input1, input2); + return ret; + } + private static int indexOf(ArrayList<CNode> inputs, CNodeData probe) { for( int i=0; i<inputs.size(); i++ ) { CNodeData cd = ((CNodeData)inputs.get(i)); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java index 1a82f7d..3aa94d5 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java @@ -80,10 +80,14 @@ public class CPlanMemoTable } public void add(Hop hop, TemplateType type, long in1, long in2, long in3) { + add(hop, new MemoTableEntry(type, in1, in2, in3)); + } + 
+ public void add(Hop hop, MemoTableEntry me) { _hopRefs.put(hop.getHopID(), hop); if( !_plans.containsKey(hop.getHopID()) ) _plans.put(hop.getHopID(), new ArrayList<MemoTableEntry>()); - _plans.get(hop.getHopID()).add(new MemoTableEntry(type, in1, in2, in3)); + _plans.get(hop.getHopID()).add(me); } public void addAll(Hop hop, MemoTableEntrySet P) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java index e7ae824..142040b 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java @@ -59,7 +59,8 @@ public abstract class PlanSelection return (me.type == TemplateType.OuterProdTpl && (me.closed || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop))) || (me.type == TemplateType.RowAggTpl && me.closed) - || (me.type == TemplateType.CellTpl); + || (me.type == TemplateType.CellTpl) + || (me.type == TemplateType.MultiAggTpl); } protected void addBestPlan(long hopID, MemoTableEntry me) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java index 47717c2..50d6ff1 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java @@ -35,6 +35,7 @@ import 
org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.ReorgOp; @@ -81,7 +82,10 @@ public class PlanSelectionFuseCostBased extends PlanSelection if( LOG.isTraceEnabled() ) LOG.trace("Partition materialization points: "+Arrays.toString(M.toArray(new Long[0]))); - //step 3: plan enumeration and plan selection + //step 3: create composite templates entries + createAndAddMultiAggPlans(memo, partition, R); + + //step 4: plan enumeration and plan selection selectPlans(memo, partition, R, M); } @@ -213,6 +217,74 @@ public class PlanSelectionFuseCostBased extends PlanSelection && partition.contains(hop.getHopID()); } + private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) + { + //create index of plans that reference full aggregates to avoid circular dependencies + HashSet<Long> refHops = new HashSet<Long>(); + for( Entry<Long, List<MemoTableEntry>> e : memo._plans.entrySet() ) + if( !e.getValue().isEmpty() ) { + Hop hop = memo._hopRefs.get(e.getKey()); + for( Hop c : hop.getInput() ) + refHops.add(c.getHopID()); + } + + //find all full aggregations (the fact that they are in the same partition guarantees + //that they also have common subexpressions, also full aggregations are by def root nodes) + ArrayList<Long> fullAggs = new ArrayList<Long>(); + for( Long hopID : R ) { + Hop root = memo._hopRefs.get(hopID); + if( !refHops.contains(hopID) && root instanceof AggUnaryOp + && ((AggUnaryOp)root).getDirection()==Direction.RowCol) + fullAggs.add(hopID); + } + if( LOG.isTraceEnabled() ) { + LOG.trace("Found ua(RC) aggregations: " + + Arrays.toString(fullAggs.toArray(new Long[0]))); + } + + //construct and add multiagg template plans (w/ max 3 aggregations) + for( int 
i=0; i<fullAggs.size(); i+=3 ) { + int ito = Math.min(i+3, fullAggs.size()); + if( ito-i >= 2 ) { + MemoTableEntry me = new MemoTableEntry(TemplateType.MultiAggTpl, + fullAggs.get(i), fullAggs.get(i+1), ((ito-i)==3)?fullAggs.get(i+2):-1); + if( isValidMultiAggregate(memo, me) ) { + for( int j=i; j<ito; j++ ) { + memo.add(memo._hopRefs.get(fullAggs.get(j)), me); + if( LOG.isTraceEnabled() ) + LOG.trace("Added multiagg plan: "+fullAggs.get(j)+" "+me); + } + } + else if( LOG.isTraceEnabled() ) { + LOG.trace("Removed invalid multiagg plan: "+me); + } + } + } + } + + private static boolean isValidMultiAggregate(CPlanMemoTable memo, MemoTableEntry me) { + //ensure that aggregates are independent of each other, i.e., + //they to not have potentially transitive parent child references + boolean ret = true; + for( int i=0; i<3; i++ ) + if( me.isPlanRef(i) ) { + HashSet<Long> probe = new HashSet<Long>(); + for( int j=0; j<3; j++ ) + if( i != j ) + probe.add(me.input(j)); + ret &= rCheckMultiAggregate(memo._hopRefs.get(me.input(i)), probe); + } + return ret; + } + + private static boolean rCheckMultiAggregate(Hop current, HashSet<Long> probe) { + boolean ret = true; + for( Hop c : current.getInput() ) + ret &= rCheckMultiAggregate(c, probe); + ret &= !probe.contains(current.getHopID()); + return ret; + } + private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, ArrayList<Long> M) { //if no materialization points, use basic fuse-all w/ partition awareness @@ -275,7 +347,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection Iterator<MemoTableEntry> iter = memo.get(hopID).iterator(); while( iter.hasNext() ) { MemoTableEntry me = iter.next(); - if( !hasNoRefToMaterialization(me, M, plan) ){ + if( !hasNoRefToMaterialization(me, M, plan) && me.type!=TemplateType.OuterProdTpl ){ iter.remove(); if( LOG.isTraceEnabled() ) LOG.trace("Removed memo table entry: "+me); @@ -418,7 +490,7 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection if( LOG.isTraceEnabled() ) LOG.trace("Cost vector for fused operator: "+costVect); costs += costVect.outSize * 8 / WRITE_BANDWIDTH; //time for output write - costs += Math.min( + costs += Math.max( costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH, costVect.getSumInputSizes() * 8 / READ_BANDWIDTH); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java index 9d2466e..4fceb8a 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java @@ -26,11 +26,11 @@ import org.apache.sysml.runtime.matrix.data.Pair; public abstract class TemplateBase { public enum TemplateType { + //ordering specifies type preferences + MultiAggTpl, RowAggTpl, OuterProdTpl, CellTpl; - - //rank in preferred order public int getRank() { return this.ordinal(); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index 58df56f..e3c12d5 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -66,6 +66,11 @@ public class TemplateCell extends TemplateBase public TemplateCell(boolean closed) { super(TemplateType.CellTpl, closed); } + + public TemplateCell(TemplateType type, boolean closed) { + super(type, closed); + } + @Override public 
boolean open(Hop hop) { @@ -134,7 +139,7 @@ public class TemplateCell extends TemplateBase return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl); } - private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) + protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) { //memoization for common subexpression elimination and to avoid redundant work if( tmp.containsKey(hop.getHopID()) ) @@ -269,7 +274,7 @@ public class TemplateCell extends TemplateBase tmp.put(hop.getHopID(), out); } - public static boolean isValidOperation(Hop hop) + protected static boolean isValidOperation(Hop hop) { //prepare indicators for binary operations boolean isBinaryMatrixScalar = false; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java new file mode 100644 index 0000000..aaf00e7 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.codegen.template; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.AggOp; +import org.apache.sysml.hops.codegen.cplan.CNode; +import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; +import org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg; +import org.apache.sysml.hops.codegen.cplan.CNodeTpl; +import org.apache.sysml.runtime.matrix.data.Pair; + +public class TemplateMultiAgg extends TemplateCell +{ + public TemplateMultiAgg() { + super(TemplateType.MultiAggTpl, false); + } + + public TemplateMultiAgg(boolean closed) { + super(TemplateType.MultiAggTpl, closed); + } + + @Override + public boolean open(Hop hop) { + //multiagg is a composite templates, which is not + //created via open-fuse-merge-close + return false; + } + + @Override + public boolean fuse(Hop hop, Hop input) { + return false; + } + + @Override + public boolean merge(Hop hop, Hop input) { + return false; + } + + @Override + public CloseType close(Hop hop) { + return CloseType.CLOSED_INVALID; + } + + public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) + { + //get all root nodes for multi aggregation + MemoTableEntry multiAgg = memo.getBest(hop.getHopID(), TemplateType.MultiAggTpl); + ArrayList<Hop> roots = new ArrayList<Hop>(); + for( int i=0; i<3; i++ ) + if( multiAgg.isPlanRef(i) ) + 
roots.add(memo._hopRefs.get(multiAgg.input(i))); + Hop.resetVisitStatus(roots); + + //recursively process required cplan outputs + HashSet<Hop> inHops = new HashSet<Hop>(); + HashMap<Long, CNode> tmp = new HashMap<Long, CNode>(); + for( Hop root : roots ) //use celltpl cplan construction + super.rConstructCplan(root, memo, tmp, inHops, compileLiterals); + Hop.resetVisitStatus(roots); + + //reorder inputs (ensure matrices/vectors come first) and prune literals + //note: we order by number of cells and subsequently sparsity to ensure + //that sparse inputs are used as the main input w/o unnecessary conversion + List<Hop> sinHops = inHops.stream() + .filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral())) + .sorted(new HopInputComparator()).collect(Collectors.toList()); + + //construct template node + ArrayList<CNode> inputs = new ArrayList<CNode>(); + for( Hop in : sinHops ) + inputs.add(tmp.get(in.getHopID())); + ArrayList<CNode> outputs = new ArrayList<CNode>(); + ArrayList<AggOp> aggOps = new ArrayList<AggOp>(); + for( Hop root : roots ) { + outputs.add(tmp.get(root.getHopID())); + aggOps.add(TemplateUtils.getAggOp(root)); + } + CNodeMultiAgg tpl = new CNodeMultiAgg(inputs, outputs); + tpl.setAggOps(aggOps); + tpl.setRootNodes(roots); + + // return cplan instance + return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl); + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index b959638..c6a259f 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -187,6 +187,7 @@ public class TemplateUtils switch( 
type ) { case CellTpl: tpl = new TemplateCell(closed); break; case RowAggTpl: tpl = new TemplateRowAgg(closed); break; + case MultiAggTpl: tpl = new TemplateMultiAgg(closed); break; case OuterProdTpl: tpl = new TemplateOuterProduct(closed); break; } return tpl; @@ -197,6 +198,7 @@ public class TemplateUtils switch( type ) { case CellTpl: tpl = new TemplateBase[]{new TemplateCell(closed), new TemplateRowAgg(closed)}; break; case RowAggTpl: tpl = new TemplateBase[]{new TemplateRowAgg(closed)}; break; + case MultiAggTpl: tpl = new TemplateBase[]{new TemplateMultiAgg(closed)}; break; case OuterProdTpl: tpl = new TemplateBase[]{new TemplateOuterProduct(closed)}; break; } return tpl; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 3857ca2..fcfc14b 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -41,6 +41,7 @@ import org.apache.sysml.hops.Hop.OpOp3; import org.apache.sysml.hops.Hop.ParamBuiltinOp; import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LeftIndexingOp; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.MemoTable; @@ -547,6 +548,16 @@ public class HopRewriteUtils return pbop; } + public static Hop createScalarIndexing(Hop input, long rix, long cix) { + LiteralOp row = new LiteralOp(rix); + LiteralOp col = new LiteralOp(cix); + IndexingOp ix = new IndexingOp("tmp", DataType.MATRIX, ValueType.DOUBLE, input, row, row, col, col, true, true); + ix.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock()); + 
copyLineNumbers(input, ix); + ix.refreshSizeInformation(); + return createUnary(ix, OpOp1.CAST_AS_SCALAR); + } + public static Hop createValueHop( Hop hop, boolean row ) throws HopsException { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/runtime/codegen/SpoofMultiAggregate.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofMultiAggregate.java b/src/main/java/org/apache/sysml/runtime/codegen/SpoofMultiAggregate.java new file mode 100644 index 0000000..ea76ac5 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofMultiAggregate.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.codegen; + +import java.io.Serializable; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.codegen.SpoofCellwise.AggOp; +import org.apache.sysml.runtime.functionobjects.Builtin; +import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode; +import org.apache.sysml.runtime.functionobjects.KahanFunction; +import org.apache.sysml.runtime.functionobjects.KahanPlus; +import org.apache.sysml.runtime.functionobjects.KahanPlusSq; +import org.apache.sysml.runtime.functionobjects.ValueFunction; +import org.apache.sysml.runtime.instructions.cp.KahanObject; +import org.apache.sysml.runtime.instructions.cp.ScalarObject; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.SparseBlock; +import org.apache.sysml.runtime.util.UtilFunctions; + +public abstract class SpoofMultiAggregate extends SpoofOperator implements Serializable +{ + private static final long serialVersionUID = -6164871955591089349L; + private static final long PAR_NUMCELL_THRESHOLD = 1024*1024; //Min 1M elements + + private final AggOp[] _aggOps; + + public SpoofMultiAggregate(AggOp... 
aggOps) { + _aggOps = aggOps; + } + + public AggOp[] getAggOps() { + return _aggOps; + } + + @Override + public String getSpoofType() { + return "MA" + getClass().getName().split("\\.")[1]; + } + + @Override + public void execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out) + throws DMLRuntimeException + { + execute(inputs, scalarObjects, out, 1); + } + + @Override + public void execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, int k) + throws DMLRuntimeException + { + //sanity check + if( inputs==null || inputs.size() < 1 ) + throw new RuntimeException("Invalid input arguments."); + + if( inputs.get(0).getNumRows()*inputs.get(0).getNumColumns()<PAR_NUMCELL_THRESHOLD ) { + k = 1; //serial execution + } + + //result allocation and preparations + out.reset(1, _aggOps.length, false); + out.allocateDenseBlock(); + double[] c = out.getDenseBlock(); + + //input preparation + double[][] b = prepInputMatrices(inputs); + double[] scalars = prepInputScalars(scalarObjects); + final int m = inputs.get(0).getNumRows(); + final int n = inputs.get(0).getNumColumns(); + + if( k <= 1 ) //SINGLE-THREADED + { + setInitialOutputValues(c); + if( !inputs.get(0).isInSparseFormat() ) + executeDense(inputs.get(0).getDenseBlock(), b, scalars, c, m, n, 0, m); + else + executeSparse(inputs.get(0).getSparseBlock(), b, scalars, c, m, n, 0, m); + } + else //MULTI-THREADED + { + try { + ExecutorService pool = Executors.newFixedThreadPool( k ); + ArrayList<ParAggTask> tasks = new ArrayList<ParAggTask>(); + int nk = UtilFunctions.roundToNext(Math.min(8*k,m/32), k); + int blklen = (int)(Math.ceil((double)m/nk)); + for( int i=0; i<nk & i*blklen<m; i++ ) + tasks.add(new ParAggTask(inputs.get(0), b, scalars, m, n, i*blklen, Math.min((i+1)*blklen, m))); + //execute tasks + List<Future<double[]>> taskret = pool.invokeAll(tasks); + pool.shutdown(); + + //aggregate partial results + ArrayList<double[]> pret = new 
ArrayList<double[]>(); + for( Future<double[]> task : taskret ) + pret.add(task.get()); + aggregatePartialResults(c, pret); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + } + + //post-processing + out.recomputeNonZeros(); + out.examSparsity(); + } + + private void executeDense(double[] a, double[][] b, double[] scalars, double[] c, int m, int n, int rl, int ru) throws DMLRuntimeException + { + //core dense aggregation operation + for( int i=rl, ix=rl*n; i<ru; i++ ) { + for( int j=0; j<n; j++, ix++ ) { + double in = (a != null) ? a[ix] : 0; + genexec( in, b, scalars, c, m, n, i, j ); + } + } + } + + private void executeSparse(SparseBlock sblock, double[][] b, double[] scalars, double[] c, int m, int n, int rl, int ru) + throws DMLRuntimeException + { + //core dense aggregation operation + for( int i=rl; i<ru; i++ ) + for( int j=0; j<n; j++ ) { + double in = (sblock != null) ? sblock.get(i, j) : 0; + genexec( in, b, scalars, c, m, n, i, j ); + } + } + + + protected abstract void genexec( double a, double[][] b, double[] scalars, double[] c, int m, int n, int rowIndex, int colIndex); + + + private void setInitialOutputValues(double[] c) { + for( int k=0; k<_aggOps.length; k++ ) { + switch(_aggOps[k]) { + case SUM: + case SUM_SQ: c[k] = 0; break; + case MIN: c[k] = Double.MAX_VALUE; break; + case MAX: c[k] = -Double.MAX_VALUE; break; + } + } + } + + + private void aggregatePartialResults(double[] c, ArrayList<double[]> pret) + throws DMLRuntimeException + { + ValueFunction[] vfun = getAggFunctions(_aggOps); + for( int k=0; k<_aggOps.length; k++ ) { + if( vfun[k] instanceof KahanFunction ) { + KahanObject kbuff = new KahanObject(0, 0); + KahanPlus kplus = KahanPlus.getKahanPlusFnObject(); + for(double[] tmp : pret) + kplus.execute2(kbuff, tmp[k]); + c[k] = kbuff._sum; + } + else { + for(double[] tmp : pret) + c[k] = vfun[k].execute(c[k], tmp[k]); + } + } + } + + public static void aggregatePartialResults(AggOp[] aggOps, MatrixBlock c, MatrixBlock 
b) + throws DMLRuntimeException + { + ValueFunction[] vfun = getAggFunctions(aggOps); + + for( int k=0; k< aggOps.length; k++ ) { + if( vfun[k] instanceof KahanFunction ) { + KahanObject kbuff = new KahanObject(c.quickGetValue(0, k), 0); + KahanPlus kplus = KahanPlus.getKahanPlusFnObject(); + kplus.execute2(kbuff, b.quickGetValue(0, k)); + c.quickSetValue(0, k, kbuff._sum); + } + else { + double cval = c.quickGetValue(0, k); + double bval = b.quickGetValue(0, k); + c.quickSetValue(0, k, vfun[k].execute(cval, bval)); + } + } + } + + public static ValueFunction[] getAggFunctions(AggOp[] aggOps) { + ValueFunction[] fun = new ValueFunction[aggOps.length]; + for( int i=0; i<aggOps.length; i++ ) { + switch( aggOps[i] ) { + case SUM: fun[i] = KahanPlus.getKahanPlusFnObject(); break; + case SUM_SQ: fun[i] = KahanPlusSq.getKahanPlusSqFnObject(); break; + case MIN: fun[i] = Builtin.getBuiltinFnObject(BuiltinCode.MIN); break; + case MAX: fun[i] = Builtin.getBuiltinFnObject(BuiltinCode.MAX); break; + default: + throw new RuntimeException("Unsupported " + + "aggregation type: "+aggOps[i].name()); + } + } + return fun; + } + + private class ParAggTask implements Callable<double[]> + { + private final MatrixBlock _a; + private final double[][] _b; + private final double[] _scalars; + private final int _rlen; + private final int _clen; + private final int _rl; + private final int _ru; + + protected ParAggTask( MatrixBlock a, double[][] b, double[] scalars, + int rlen, int clen, int rl, int ru ) { + _a = a; + _b = b; + _scalars = scalars; + _rlen = rlen; + _clen = clen; + _rl = rl; + _ru = ru; + } + + @Override + public double[] call() throws DMLRuntimeException { + double[] c = new double[_aggOps.length]; + setInitialOutputValues(c); + if( !_a.isInSparseFormat() ) + executeDense(_a.getDenseBlock(), _b, _scalars, c, _rlen, _clen, _rl, _ru); + else + executeSparse(_a.getSparseBlock(), _b, _scalars, c, _rlen, _clen, _rl, _ru); + return c; + } + } +} 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java index f203913..3b067dd 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java @@ -24,6 +24,7 @@ import java.util.Iterator; import java.util.List; import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; import org.apache.sysml.lops.PartialAggregate.CorrectionLocationType; @@ -31,6 +32,7 @@ import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.codegen.CodegenUtils; import org.apache.sysml.runtime.codegen.SpoofCellwise; +import org.apache.sysml.runtime.codegen.SpoofMultiAggregate; import org.apache.sysml.runtime.codegen.SpoofCellwise.AggOp; import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; import org.apache.sysml.runtime.codegen.SpoofOperator; @@ -149,6 +151,18 @@ public class SpoofSPInstruction extends SPInstruction sec.setVariable(_out.getName(), new DoubleObject(tmpMB.getValue(0, 0))); } } + else if(_class.getSuperclass() == SpoofMultiAggregate.class) + { + SpoofMultiAggregate op = (SpoofMultiAggregate) CodegenUtils.createInstance(_class); + AggOp[] aggOps = op.getAggOps(); + + MatrixBlock tmpMB = in + .mapToPair(new MultiAggregateFunction(_class.getName(), _classBytes, bcMatrices, scalars)) + .values().fold(new MatrixBlock(), new MultiAggAggregateFunction(aggOps) ); + + 
sec.setMatrixOutput(_out.getName(), tmpMB); + return; + } else if(_class.getSuperclass() == SpoofOuterProduct.class) // outer product operator { if( _out.getDataType()==DataType.MATRIX ) { @@ -344,7 +358,88 @@ public class SpoofSPInstruction extends SPInstruction } return ret; } - } + } + + private static class MultiAggregateFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> + { + private static final long serialVersionUID = -5224519291577332734L; + + private ArrayList<PartitionedBroadcast<MatrixBlock>> _vectors = null; + private ArrayList<ScalarObject> _scalars = null; + private byte[] _classBytes = null; + private String _className = null; + private SpoofOperator _op = null; + + public MultiAggregateFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) + throws DMLRuntimeException + { + _className = className; + _classBytes = classBytes; + _vectors = bcMatrices; + _scalars = scalars; + } + + @Override + public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg) + throws Exception + { + //lazy load of shipped class + if( _op == null ) { + Class<?> loadedClass = CodegenUtils.getClass(_className, _classBytes); + _op = (SpoofOperator) CodegenUtils.createInstance(loadedClass); + } + + //execute core operation + ArrayList<MatrixBlock> inputs = getVectorInputsFromBroadcast(arg._2(), arg._1()); + MatrixBlock blkOut = new MatrixBlock(); + _op.execute(inputs, _scalars, blkOut); + + return new Tuple2<MatrixIndexes,MatrixBlock>(arg._1(), blkOut); + } + + private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, MatrixIndexes ixIn) + throws DMLRuntimeException + { + ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>(); + ret.add(blkIn); + for( PartitionedBroadcast<MatrixBlock> in : _vectors ) { + int rowIndex = (int)((in.getNumRowBlocks()>=ixIn.getRowIndex())?ixIn.getRowIndex():1); + int colIndex 
= (int)((in.getNumColumnBlocks()>=ixIn.getColumnIndex())?ixIn.getColumnIndex():1); + ret.add(in.getBlock(rowIndex, colIndex)); + } + return ret; + } + } + + private static class MultiAggAggregateFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> + { + private static final long serialVersionUID = 5978731867787952513L; + + private AggOp[] _ops = null; + + public MultiAggAggregateFunction( AggOp[] ops ) { + _ops = ops; + } + + @Override + public MatrixBlock call(MatrixBlock arg0, MatrixBlock arg1) + throws Exception + { + //prepare combiner block + if( arg0.getNumRows() <= 0 || arg0.getNumColumns() <= 0) { + arg0.copy(arg1); + return arg0; + } + else if( arg1.getNumRows() <= 0 || arg1.getNumColumns() <= 0 ) { + return arg0; + } + + //aggregate second input (in-place) + SpoofMultiAggregate.aggregatePartialResults(_ops, arg0, arg1); + + return arg0; + } + } private static class OuterProductFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/main/java/org/apache/sysml/utils/Explain.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java index b78029a..b6e7b6f 100644 --- a/src/main/java/org/apache/sysml/utils/Explain.java +++ b/src/main/java/org/apache/sysml/utils/Explain.java @@ -32,6 +32,7 @@ import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.codegen.cplan.CNode; +import org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; import org.apache.sysml.hops.globalopt.gdfgraph.GDFLoopNode; import org.apache.sysml.hops.globalopt.gdfgraph.GDFNode; @@ -380,9 +381,13 @@ public class Explain 
sb.append("----------------------------------------\n"); //explain body dag - cplan.getOutput().resetVisitStatus(); - sb.append(explainCNode(cplan.getOutput(), 1)); - cplan.getOutput().resetVisitStatus(); + cplan.resetVisitStatusOutputs(); + if( cplan instanceof CNodeMultiAgg ) + for( CNode output : ((CNodeMultiAgg)cplan).getOutputs() ) + sb.append(explainCNode(output, 1)); + else + sb.append(explainCNode(cplan.getOutput(), 1)); + cplan.resetVisitStatusOutputs(); sb.append("----------------------------------------\n"); return sb.toString(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java index c6897cf..1a2ecfe 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java @@ -43,7 +43,7 @@ public class AlgorithmLinregCG extends AutomatedTestBase //TODO Investigate numerical stability issues: on certain platforms this test, occasionally fails, //for 1e-5 (specifically testLinregCGSparseRewritesCP); apparently due to the -(-(X)) -> X rewrite. 
- private final static double eps = 1e-2; + private final static double eps = 1e-1; private final static int rows = 2468; private final static int cols = 507; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java index 521ea58..6468003 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java @@ -274,8 +274,8 @@ public class CellwiseTmplTest extends AutomatedTestBase } if( !(rewrites && testname.equals(TEST_NAME2)) ) //sigmoid - Assert.assertTrue(heavyHittersContainsSubString("spoofCell") - || heavyHittersContainsSubString("sp_spoofCell")); + Assert.assertTrue(heavyHittersContainsSubString( + "spoofCell", "sp_spoofCell", "spoofMA", "sp_spoofMA")); if( testname.equals(TEST_NAME7) ) //ensure matrix mult is fused Assert.assertTrue(!heavyHittersContainsSubString("tsmm")); else if( testname.equals(TEST_NAME10) ) //ensure min/max is fused http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java new file mode 100644 index 0000000..21aabb8 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.codegen; + +import java.io.File; +import java.util.HashMap; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +public class MultiAggTmplTest extends AutomatedTestBase +{ + private static final String TEST_NAME = "multiAggPattern"; + private static final String TEST_NAME1 = TEST_NAME+"1"; //min(X>7), max(X>7) + private static final String TEST_NAME2 = TEST_NAME+"2"; //sum(X>7), sum((X>7)^2) + + private static final String TEST_DIR = "functions/codegen/"; + private static final String TEST_CLASS_DIR = TEST_DIR + MultiAggTmplTest.class.getSimpleName() + "/"; + private final static String TEST_CONF = "SystemML-config-codegen.xml"; + private final static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, TEST_CONF); + + private static 
final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + for(int i=1; i<=2; i++) + addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) ); + } + + @Test + public void testCodegenMultiAggRewrite1CP() { + testCodegenIntegration( TEST_NAME1, true, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg1CP() { + testCodegenIntegration( TEST_NAME1, false, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg1Spark() { + testCodegenIntegration( TEST_NAME1, false, ExecType.SPARK ); + } + + @Test + public void testCodegenMultiAggRewrite2CP() { + testCodegenIntegration( TEST_NAME2, true, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg2CP() { + testCodegenIntegration( TEST_NAME2, false, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg2Spark() { + testCodegenIntegration( TEST_NAME2, false, ExecType.SPARK ); + } + + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) + { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + RUNTIME_PLATFORM platformOld = rtplatform; + switch( instType ) { + case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK || rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + try + { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{"-explain", "-stats", "-args", output("S") }; + + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("S"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("S"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + Assert.assertTrue(heavyHittersContainsSubString("spoofMA") + || heavyHittersContainsSubString("sp_spoofMA")); + } + finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true; + OptimizerUtils.ALLOW_OPERATOR_FUSION = true; + } + } + + /** + * Override default configuration with custom test configuration to ensure + * scratch space and local temporary directory locations are also updated. + */ + @Override + protected File getConfigTemplateFile() { + // Instrumentation in this test's output log to show custom configuration file used for template. + System.out.println("This test case overrides default configuration with " + TEST_CONF_FILE.getPath()); + return TEST_CONF_FILE; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/test/scripts/functions/codegen/multiAggPattern1.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/multiAggPattern1.R b/src/test/scripts/functions/codegen/multiAggPattern1.R new file mode 100644 index 0000000..4e0a22e --- /dev/null +++ b/src/test/scripts/functions/codegen/multiAggPattern1.R @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(seq(1,15), 5, 3, byrow=TRUE); + +r1 = min(X>7); +r2 = max(X>7); +S = as.matrix(r1+r2); + +writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/test/scripts/functions/codegen/multiAggPattern1.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/multiAggPattern1.dml b/src/test/scripts/functions/codegen/multiAggPattern1.dml new file mode 100644 index 0000000..1f5dd96 --- /dev/null +++ b/src/test/scripts/functions/codegen/multiAggPattern1.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,15), rows=5, cols=3); + +r1 = min(X>7); +r2 = max(X>7); +S = as.matrix(r1+r2); + +write(S,$1) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/test/scripts/functions/codegen/multiAggPattern2.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/multiAggPattern2.R b/src/test/scripts/functions/codegen/multiAggPattern2.R new file mode 100644 index 0000000..379084a --- /dev/null +++ b/src/test/scripts/functions/codegen/multiAggPattern2.R @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(seq(1,15), 5, 3, byrow=TRUE); + +r1 = sum(X>7); +r2 = sum((X>7)^2); +S = as.matrix(r1+r2); + +writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/test/scripts/functions/codegen/multiAggPattern2.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/multiAggPattern2.dml b/src/test/scripts/functions/codegen/multiAggPattern2.dml new file mode 100644 index 0000000..4d4fb6d --- /dev/null +++ b/src/test/scripts/functions/codegen/multiAggPattern2.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = matrix(seq(1,15), rows=5, cols=3); + +r1 = sum(X>7); +r2 = sum((X>7)^2); +S = as.matrix(r1+r2); + +write(S,$1) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/174bf7db/src/test_suites/java/org/apache/sysml/test/integration/functions/codegen/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/codegen/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/codegen/ZPackageSuite.java index abba49e..d063728 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/codegen/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/codegen/ZPackageSuite.java @@ -34,6 +34,7 @@ import org.junit.runners.Suite; AlgorithmPNMF.class, CellwiseTmplTest.class, DAGCellwiseTmplTest.class, + MultiAggTmplTest.class, OuterProdTmplTest.class, RowAggTmplTest.class, RowVectorComparisonTest.class,
