http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java new file mode 100644 index 0000000..2fa0de7 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -0,0 +1,1100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.codegen.opt; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Map.Entry; +import java.util.stream.Collectors; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; + +import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.hops.AggBinaryOp; +import org.apache.sysml.hops.AggUnaryOp; +import org.apache.sysml.hops.BinaryOp; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.AggOp; +import org.apache.sysml.hops.Hop.Direction; +import org.apache.sysml.hops.IndexingOp; +import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.ParameterizedBuiltinOp; +import org.apache.sysml.hops.ReorgOp; +import org.apache.sysml.hops.TernaryOp; +import org.apache.sysml.hops.UnaryOp; +import org.apache.sysml.hops.codegen.opt.ReachabilityGraph.SubProblem; +import org.apache.sysml.hops.codegen.template.CPlanMemoTable; +import org.apache.sysml.hops.codegen.template.TemplateOuterProduct; +import org.apache.sysml.hops.codegen.template.TemplateRow; +import org.apache.sysml.hops.codegen.template.TemplateUtils; +import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; +import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; +import org.apache.sysml.hops.rewrite.HopRewriteUtils; +import org.apache.sysml.runtime.codegen.LibSpoofPrimitives; +import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; +import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; +import org.apache.sysml.utils.Statistics; + +/** + * This cost-based plan selection algorithm chooses fused operators + * based on the DAG structure and resulting overall costs. This includes + * holistic decisions on + * <ul> + * <li>Materialization points per consumer</li> + * <li>Sparsity exploitation and operator ordering</li> + * <li>Decisions on overlapping template types</li> + * <li>Decisions on multi-aggregates with shared reads</li> + * <li>Constraints (e.g., memory budgets and block sizes)</li> + * </ul> + * + */ +public class PlanSelectionFuseCostBasedV2 extends PlanSelection +{ + private static final Log LOG = LogFactory.getLog(PlanSelectionFuseCostBasedV2.class.getName()); + + //common bandwidth characteristics, with a conservative write bandwidth in order + //to cover result allocation, write into main memory, and potential evictions + private static final double WRITE_BANDWIDTH = 2d*1024*1024*1024; //2GB/s + private static final double READ_BANDWIDTH = 32d*1024*1024*1024; //32GB/s + private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 //2GFLOPs/core + * InfrastructureAnalyzer.getLocalParallelism(); + + //sparsity estimate for unknown sparsity to prefer sparse-safe fusion plans + private static final double SPARSE_SAFE_SPARSITY_EST = 0.1; + + //optimizer configuration + private static final boolean USE_COST_PRUNING = true; + private static final boolean USE_STRUCTURAL_PRUNING = true; + + private static final IDSequence COST_ID = new IDSequence(); + private static final TemplateRow ROW_TPL = new TemplateRow(); + private static final BasicPlanComparator BASE_COMPARE = new BasicPlanComparator(); + private final TypedPlanComparator _typedCompare = new TypedPlanComparator(); + + @Override + public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) + { + //step 1: analyze connected partitions (nodes, roots, mat points) + Collection<PlanPartition> parts = PlanAnalyzer.analyzePlanPartitions(memo, roots, true); + + //step 2: optimize individual plan partitions + for( PlanPartition part : parts ) { + //create composite templates (within the partition) + createAndAddMultiAggPlans(memo, part.getPartition(), part.getRoots()); + + //plan enumeration and plan selection + selectPlans(memo, part); + } + + //step 3: add composite templates (across partitions) + createAndAddMultiAggPlans(memo, roots); + + //take all distinct best plans + for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() ) + memo.setDistinct(e.getKey(), e.getValue()); + } + + private void selectPlans(CPlanMemoTable memo, PlanPartition part) + { + //prune row aggregates with pure cellwise operations + for( Long hopID : part.getRoots() ) { + MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW); + if( me.type == TemplateType.ROW && memo.contains(hopID, TemplateType.CELL) + && isRowTemplateWithoutAgg(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) { + List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW); + memo.remove(memo.getHopRefs().get(hopID), new HashSet<MemoTableEntry>(blacklist)); + if( LOG.isTraceEnabled() ) { + LOG.trace("Removed row memo table entries w/o aggregation: " + + Arrays.toString(blacklist.toArray(new MemoTableEntry[0]))); + } + } + } + + //prune suboptimal outer product plans that are dominated by outer product plans w/ same number of + //references but better fusion properties (e.g., for the patterns Y=X*(U%*%t(V)) and sum(Y*(U2%*%t(V2))), + //we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this would unnecessarily destroy a fusion pattern. + for( Long hopID : part.getPartition() ) { + if( memo.countEntries(hopID, TemplateType.OUTER) == 2 ) { + List<MemoTableEntry> entries = memo.get(hopID, TemplateType.OUTER); + MemoTableEntry me1 = entries.get(0); + MemoTableEntry me2 = entries.get(1); + MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2); + if( rmEntry != null ) { + memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry)); + memo.getPlansBlacklisted().remove(rmEntry.input(rmEntry.getPlanRefIndex())); + if( LOG.isTraceEnabled() ) + LOG.trace("Removed dominated outer product memo table entry: " + rmEntry); + } + } + } + + //if no materialization points, use basic fuse-all w/ partition awareness + if( part.getMatPointsExt() == null || part.getMatPointsExt().length==0 ) { + for( Long hopID : part.getRoots() ) + rSelectPlansFuseAll(memo, + memo.getHopRefs().get(hopID), null, part.getPartition()); + } + else { + //obtain hop compute costs per cell once + HashMap<Long, Double> computeCosts = new HashMap<Long, Double>(); + for( Long hopID : part.getRoots() ) + rGetComputeCosts(memo.getHopRefs().get(hopID), part.getPartition(), computeCosts); + + //prepare pruning helpers and prune memo table w/ determined mat points + StaticCosts costs = new StaticCosts(computeCosts, getComputeCost(computeCosts, memo), + getReadCost(part, memo), getWriteCost(part.getRoots(), memo)); + ReachabilityGraph rgraph = USE_STRUCTURAL_PRUNING ? new ReachabilityGraph(part, memo) : null; + if( USE_STRUCTURAL_PRUNING ) { + part.setMatPointsExt(rgraph.getSortedSearchSpace()); + for( Long hopID : part.getPartition() ) + memo.pruneRedundant(hopID, true, part.getMatPointsExt()); + } + + //enumerate and cost plans, returns optional plan + boolean[] bestPlan = enumPlans(memo, part, costs, rgraph, + part.getMatPointsExt(), 0, Double.MAX_VALUE); + + //prune memo table wrt best plan and select plans + HashSet<Long> visited = new HashSet<Long>(); + for( Long hopID : part.getRoots() ) + rPruneSuboptimalPlans(memo, memo.getHopRefs().get(hopID), + visited, part, part.getMatPointsExt(), bestPlan); + HashSet<Long> visited2 = new HashSet<Long>(); + for( Long hopID : part.getRoots() ) + rPruneInvalidPlans(memo, memo.getHopRefs().get(hopID), + visited2, part, bestPlan); + + for( Long hopID : part.getRoots() ) + rSelectPlansFuseAll(memo, + memo.getHopRefs().get(hopID), null, part.getPartition()); + } + } + + /** + * Core plan enumeration algorithm, invoked recursively for conditionally independent + * subproblems. This algorithm fully explores the exponential search space of 2^m, + * where m is the number of interesting materialization points. We iterate over + * a linearized search space without every instantiating the search tree. Furthermore, + * in order to reduce the enumeration overhead, we apply two high-impact pruning + * techniques (1) pruning by evolving lower/upper cost bounds, and (2) pruning by + * conditional structural properties (so-called cutsets of interesting points). + * + * @param memo memoization table of partial fusion plans + * @param part connected component (partition) of partial fusion plans with all necessary meta data + * @param costs summary of static costs (e.g., partition reads, writes, and compute costs per operator) + * @param rgraph reachability graph of interesting materialization points + * @param matPoints sorted materialization points (defined the search space) + * @param off offset for recursive invocation, indicating the fixed plan part + * @param bestC currently known best plan costs (used of upper bound) + * @return optimal assignment of materialization points + */ + private static boolean[] enumPlans(CPlanMemoTable memo, PlanPartition part, StaticCosts costs, + ReachabilityGraph rgraph, InterestingPoint[] matPoints, int off, double bestC) + { + //scan linearized search space, w/ skips for branch and bound pruning + //and structural pruning (where we solve conditionally independent problems) + //bestC is monotonically non-increasing and serves as the upper bound + long len = (long)Math.pow(2, matPoints.length-off); + boolean[] bestPlan = null; + int numEvalPlans = 0; + + for( long i=0; i<len; i++ ) { + //construct assignment + boolean[] plan = createAssignment(matPoints.length-off, off, i); + long pskip = 0; //skip after costing + + //skip plans with structural pruning + if( USE_STRUCTURAL_PRUNING && (rgraph!=null) && rgraph.isCutSet(plan) ) { + //compute skip (which also acts as boundary for subproblems) + pskip = rgraph.getNumSkipPlans(plan); + + //start increment rgraph get subproblems + SubProblem[] prob = rgraph.getSubproblems(plan); + + //solve subproblems independently and combine into best plan + for( int j=0; j<prob.length; j++ ) { + boolean[] bestTmp = enumPlans(memo, part, + costs, null, prob[j].freeMat, prob[j].offset, bestC); + LibSpoofPrimitives.vectWrite(bestTmp, plan, prob[j].freePos); + } + + //note: the overall plan costs are evaluated in full, which reused + //the default code path; hence we postpone the skip after costing + } + //skip plans with branch and bound pruning (cost) + else if( USE_COST_PRUNING ) { + double lbC = Math.max(costs._read, costs._compute) + costs._write + + getMaterializationCost(part, matPoints, memo, plan); + if( lbC >= bestC ) { + long skip = getNumSkipPlans(plan); + if( LOG.isTraceEnabled() ) + LOG.trace("Enum: Skip "+skip+" plans (by cost)."); + i += skip - 1; + continue; + } + } + + //cost assignment on hops + double C = getPlanCost(memo, part, matPoints, plan, costs._computeCosts); + numEvalPlans ++; + if( LOG.isTraceEnabled() ) + LOG.trace("Enum: "+Arrays.toString(plan)+" -> "+C); + + //cost comparisons + if( bestPlan == null || C < bestC ) { + bestC = C; + bestPlan = plan; + if( LOG.isTraceEnabled() ) + LOG.trace("Enum: Found new best plan."); + } + + //post skipping + i += pskip; + if( pskip !=0 && LOG.isTraceEnabled() ) + LOG.trace("Enum: Skip "+pskip+" plans (by structure)."); + } + + if( DMLScript.STATISTICS ) + Statistics.incrementCodegenFPlanCompile(numEvalPlans); + if( LOG.isTraceEnabled() ) + LOG.trace("Enum: Optimal plan: "+Arrays.toString(bestPlan)); + + //copy best plan w/o fixed offset plan + return Arrays.copyOfRange(bestPlan, off, bestPlan.length); + } + + private static boolean[] createAssignment(int len, int off, long pos) { + boolean[] ret = new boolean[off+len]; + Arrays.fill(ret, 0, off, true); + long tmp = pos; + for( int i=0; i<len; i++ ) { + ret[off+i] = (tmp >= Math.pow(2, len-i-1)); + tmp %= Math.pow(2, len-i-1); + } + return ret; + } + + private static long getNumSkipPlans(boolean[] plan) { + int pos = ArrayUtils.lastIndexOf(plan, true); + return (long) Math.pow(2, plan.length-pos-1); + } + + private static double getMaterializationCost(PlanPartition part, InterestingPoint[] M, CPlanMemoTable memo, boolean[] plan) { + double costs = 0; + //currently active materialization points + HashSet<Long> matTargets = new HashSet<>(); + for( int i=0; i<plan.length; i++ ) { + long hopID = M[i].getToHopID(); + if( plan[i] && !matTargets.contains(hopID) ) { + matTargets.add(hopID); + Hop hop = memo.getHopRefs().get(hopID); + long size = getSize(hop); + costs += size * 8 / WRITE_BANDWIDTH + + size * 8 / READ_BANDWIDTH; + } + } + //points with non-partition consumers + for( Long hopID : part.getExtConsumed() ) + if( !matTargets.contains(hopID) ) { + matTargets.add(hopID); + Hop hop = memo.getHopRefs().get(hopID); + costs += getSize(hop) * 8 / WRITE_BANDWIDTH; + } + + return costs; + } + + private static double getReadCost(PlanPartition part, CPlanMemoTable memo) { + double costs = 0; + //get partition input reads (at least read once) + for( Long hopID : part.getInputs() ) { + Hop hop = memo.getHopRefs().get(hopID); + costs += getSize(hop) * 8 / READ_BANDWIDTH; + } + return costs; + } + + private static double getWriteCost(Collection<Long> R, CPlanMemoTable memo) { + double costs = 0; + for( Long hopID : R ) { + Hop hop = memo.getHopRefs().get(hopID); + costs += getSize(hop) * 8 / WRITE_BANDWIDTH; + } + return costs; + } + + private static double getComputeCost(HashMap<Long, Double> computeCosts, CPlanMemoTable memo) { + double costs = 0; + for( Entry<Long,Double> e : computeCosts.entrySet() ) { + Hop mainInput = memo.getHopRefs() + .get(e.getKey()).getInput().get(0); + costs += getSize(mainInput) * e.getValue() / COMPUTE_BANDWIDTH; + } + return costs; + } + + private static long getSize(Hop hop) { + return Math.max(hop.getDim1(),1) + * Math.max(hop.getDim2(),1); + } + + //within-partition multi-agg templates + private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) + { + //create index of plans that reference full aggregates to avoid circular dependencies + HashSet<Long> refHops = new HashSet<Long>(); + for( Entry<Long, List<MemoTableEntry>> e : memo.getPlans().entrySet() ) + if( !e.getValue().isEmpty() ) { + Hop hop = memo.getHopRefs().get(e.getKey()); + for( Hop c : hop.getInput() ) + refHops.add(c.getHopID()); + } + + //find all full aggregations (the fact that they are in the same partition guarantees + //that they also have common subexpressions, also full aggregations are by def root nodes) + ArrayList<Long> fullAggs = new ArrayList<Long>(); + for( Long hopID : R ) { + Hop root = memo.getHopRefs().get(hopID); + if( !refHops.contains(hopID) && isMultiAggregateRoot(root) ) + fullAggs.add(hopID); + } + if( LOG.isTraceEnabled() ) { + LOG.trace("Found within-partition ua(RC) aggregations: " + + Arrays.toString(fullAggs.toArray(new Long[0]))); + } + + //construct and add multiagg template plans (w/ max 3 aggregations) + for( int i=0; i<fullAggs.size(); i+=3 ) { + int ito = Math.min(i+3, fullAggs.size()); + if( ito-i >= 2 ) { + MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG, + fullAggs.get(i), fullAggs.get(i+1), ((ito-i)==3)?fullAggs.get(i+2):-1, ito-i); + if( isValidMultiAggregate(memo, me) ) { + for( int j=i; j<ito; j++ ) { + memo.add(memo.getHopRefs().get(fullAggs.get(j)), me); + if( LOG.isTraceEnabled() ) + LOG.trace("Added multiagg plan: "+fullAggs.get(j)+" "+me); + } + } + else if( LOG.isTraceEnabled() ) { + LOG.trace("Removed invalid multiagg plan: "+me); + } + } + } + } + + //across-partition multi-agg templates with shared reads + private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots) + { + //collect full aggregations as initial set of candidates + HashSet<Long> fullAggs = new HashSet<Long>(); + Hop.resetVisitStatus(roots); + for( Hop hop : roots ) + rCollectFullAggregates(hop, fullAggs); + Hop.resetVisitStatus(roots); + + //remove operators with assigned multi-agg plans + fullAggs.removeIf(p -> memo.contains(p, TemplateType.MAGG)); + + //check applicability for further analysis + if( fullAggs.size() <= 1 ) + return; + + if( LOG.isTraceEnabled() ) { + LOG.trace("Found across-partition ua(RC) aggregations: " + + Arrays.toString(fullAggs.toArray(new Long[0]))); + } + + //collect information for all candidates + //(subsumed aggregations, and inputs to fused operators) + List<AggregateInfo> aggInfos = new ArrayList<AggregateInfo>(); + for( Long hopID : fullAggs ) { + Hop aggHop = memo.getHopRefs().get(hopID); + AggregateInfo tmp = new AggregateInfo(aggHop); + for( int i=0; i<aggHop.getInput().size(); i++ ) { + Hop c = HopRewriteUtils.isMatrixMultiply(aggHop) && i==0 ? + aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i); + rExtractAggregateInfo(memo, c, tmp, TemplateType.CELL); + } + if( tmp._fusedInputs.isEmpty() ) { + if( HopRewriteUtils.isMatrixMultiply(aggHop) ) { + tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID()); + tmp.addFusedInput(aggHop.getInput().get(1).getHopID()); + } + else + tmp.addFusedInput(aggHop.getInput().get(0).getHopID()); + } + aggInfos.add(tmp); + } + + if( LOG.isTraceEnabled() ) { + LOG.trace("Extracted across-partition ua(RC) aggregation info: "); + for( AggregateInfo info : aggInfos ) + LOG.trace(info); + } + + //sort aggregations by num dependencies to simplify merging + //clusters of aggregations with parallel dependencies + aggInfos = aggInfos.stream() + .sorted(Comparator.comparing(a -> a._inputAggs.size())) + .collect(Collectors.toList()); + + //greedy grouping of multi-agg candidates + boolean converged = false; + while( !converged ) { + AggregateInfo merged = null; + for( int i=0; i<aggInfos.size(); i++ ) { + AggregateInfo current = aggInfos.get(i); + for( int j=i+1; j<aggInfos.size(); j++ ) { + AggregateInfo that = aggInfos.get(j); + if( current.isMergable(that) ) { + merged = current.merge(that); + aggInfos.remove(j); j--; + } + } + } + converged = (merged == null); + } + + if( LOG.isTraceEnabled() ) { + LOG.trace("Merged across-partition ua(RC) aggregation info: "); + for( AggregateInfo info : aggInfos ) + LOG.trace(info); + } + + //construct and add multiagg template plans (w/ max 3 aggregations) + for( AggregateInfo info : aggInfos ) { + if( info._aggregates.size()<=1 ) + continue; + Long[] aggs = info._aggregates.keySet().toArray(new Long[0]); + MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG, + aggs[0], aggs[1], (aggs.length>2)?aggs[2]:-1, aggs.length); + for( int i=0; i<aggs.length; i++ ) { + memo.add(memo.getHopRefs().get(aggs[i]), me); + addBestPlan(aggs[i], me); + if( LOG.isTraceEnabled() ) + LOG.trace("Added multiagg* plan: "+aggs[i]+" "+me); + + } + } + } + + private static boolean isMultiAggregateRoot(Hop root) { + return (HopRewriteUtils.isAggUnaryOp(root, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX) + && ((AggUnaryOp)root).getDirection()==Direction.RowCol) + || (root instanceof AggBinaryOp && root.getDim1()==1 && root.getDim2()==1 + && HopRewriteUtils.isTransposeOperation(root.getInput().get(0))); + } + + private static boolean isValidMultiAggregate(CPlanMemoTable memo, MemoTableEntry me) { + //ensure input consistent sizes (otherwise potential for incorrect results) + boolean ret = true; + Hop refSize = memo.getHopRefs().get(me.input1).getInput().get(0); + for( int i=1; ret && i<3; i++ ) { + if( me.isPlanRef(i) ) + ret &= HopRewriteUtils.isEqualSize(refSize, + memo.getHopRefs().get(me.input(i)).getInput().get(0)); + } + + //ensure that aggregates are independent of each other, i.e., + //they to not have potentially transitive parent child references + for( int i=0; ret && i<3; i++ ) + if( me.isPlanRef(i) ) { + HashSet<Long> probe = new HashSet<Long>(); + for( int j=0; j<3; j++ ) + if( i != j ) + probe.add(me.input(j)); + ret &= rCheckMultiAggregate(memo.getHopRefs().get(me.input(i)), probe); + } + return ret; + } + + private static boolean rCheckMultiAggregate(Hop current, HashSet<Long> probe) { + boolean ret = true; + for( Hop c : current.getInput() ) + ret &= rCheckMultiAggregate(c, probe); + ret &= !probe.contains(current.getHopID()); + return ret; + } + + private static void rCollectFullAggregates(Hop current, HashSet<Long> aggs) { + if( current.isVisited() ) + return; + + //collect all applicable full aggregations per read + if( isMultiAggregateRoot(current) ) + aggs.add(current.getHopID()); + + //recursively process children + for( Hop c : current.getInput() ) + rCollectFullAggregates(c, aggs); + + current.setVisited(); + } + + private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop current, AggregateInfo aggInfo, TemplateType type) { + //collect input aggregates (dependents) + if( isMultiAggregateRoot(current) ) + aggInfo.addInputAggregate(current.getHopID()); + + //recursively process children + MemoTableEntry me = (type!=null) ? memo.getBest(current.getHopID()) : null; + for( int i=0; i<current.getInput().size(); i++ ) { + Hop c = current.getInput().get(i); + if( me != null && me.isPlanRef(i) ) + rExtractAggregateInfo(memo, c, aggInfo, type); + else { + if( type != null && c.getDataType().isMatrix() ) //add fused input + aggInfo.addFusedInput(c.getHopID()); + rExtractAggregateInfo(memo, c, aggInfo, null); + } + } + } + + private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) { + //consider all aggregations other than root operation + MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW); + boolean ret = true; + for(int i=0; i<3; i++) + if( me.isPlanRef(i) ) + ret &= rIsRowTemplateWithoutAgg(memo, + current.getInput().get(i), visited); + return ret; + } + + private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) { + if( visited.contains(current.getHopID()) ) + return true; + + boolean ret = true; + MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW); + for(int i=0; i<3; i++) + if( me.isPlanRef(i) ) + ret &= rIsRowTemplateWithoutAgg(memo, current.getInput().get(i), visited); + ret &= !(current instanceof AggUnaryOp || current instanceof AggBinaryOp); + + visited.add(current.getHopID()); + return ret; + } + + private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, + PlanPartition part, InterestingPoint[] matPoints, boolean[] plan) + { + //memoization (not via hops because in middle of dag) + if( visited.contains(current.getHopID()) ) + return; + + //remove memo table entries if necessary + long hopID = current.getHopID(); + if( part.getPartition().contains(hopID) && memo.contains(hopID) ) { + Iterator<MemoTableEntry> iter = memo.get(hopID).iterator(); + while( iter.hasNext() ) { + MemoTableEntry me = iter.next(); + if( !hasNoRefToMatPoint(hopID, me, matPoints, plan) && me.type!=TemplateType.OUTER ) { + iter.remove(); + if( LOG.isTraceEnabled() ) + LOG.trace("Removed memo table entry: "+me); + } + } + } + + //process children recursively + for( Hop c : current.getInput() ) + rPruneSuboptimalPlans(memo, c, visited, part, matPoints, plan); + + visited.add(current.getHopID()); + } + + private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, PlanPartition part, boolean[] plan) { + //memoization (not via hops because in middle of dag) + if( visited.contains(current.getHopID()) ) + return; + + //process children recursively + for( Hop c : current.getInput() ) + rPruneInvalidPlans(memo, c, visited, part, plan); + + //find invalid row aggregate leaf nodes (see TemplateRow.open) w/o matrix inputs, + //i.e., plans that become invalid after the previous pruning step + long hopID = current.getHopID(); + if( part.getPartition().contains(hopID) && memo.contains(hopID, TemplateType.ROW) ) { + for( MemoTableEntry me : memo.get(hopID) ) { + if( me.type==TemplateType.ROW ) { + //convert leaf node with pure vector inputs + if( !me.hasPlanRef() && !TemplateUtils.hasMatrixInput(current) ) { + me.type = TemplateType.CELL; + if( LOG.isTraceEnabled() ) + LOG.trace("Converted leaf memo table entry from row to cell: "+me); + } + + //convert inner node without row template input + if( me.hasPlanRef() && !ROW_TPL.open(current) ) { + boolean hasRowInput = false; + for( int i=0; i<3; i++ ) + if( me.isPlanRef(i) ) + hasRowInput |= memo.contains(me.input(i), TemplateType.ROW); + if( !hasRowInput ) { + me.type = TemplateType.CELL; + if( LOG.isTraceEnabled() ) + LOG.trace("Converted inner memo table entry from row to cell: "+me); + } + } + + } + } + } + + visited.add(current.getHopID()); + } + + private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition) + { + if( isVisited(current.getHopID(), currentType) + || !partition.contains(current.getHopID()) ) + return; + + //step 1: prune subsumed plans of same type + if( memo.contains(current.getHopID()) ) { + HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>(); + List<MemoTableEntry> hopP = memo.get(current.getHopID()); + for( MemoTableEntry e1 : hopP ) + for( MemoTableEntry e2 : hopP ) + if( e1 != e2 && e1.subsumes(e2) ) + rmSet.add(e2); + memo.remove(current, rmSet); + } + + //step 2: select plan for current path + MemoTableEntry best = null; + if( memo.contains(current.getHopID()) ) { + if( currentType == null ) { + best = memo.get(current.getHopID()).stream() + .filter(p -> isValid(p, current)) + .min(BASE_COMPARE).orElse(null); + } + else { + _typedCompare.setType(currentType); + best = memo.get(current.getHopID()).stream() + .filter(p -> p.type==currentType || p.type==TemplateType.CELL) + .min(_typedCompare).orElse(null); + } + addBestPlan(current.getHopID(), best); + } + + //step 3: recursively process children + for( int i=0; i< current.getInput().size(); i++ ) { + TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null; + rSelectPlansFuseAll(memo, current.getInput().get(i), pref, partition); + } + + setVisited(current.getHopID(), currentType); + } + + ///////////////////////////////////////////////////////// + // Cost model fused operators w/ materialization points + ////////// + + private static double getPlanCost(CPlanMemoTable memo, PlanPartition part, + InterestingPoint[] matPoints,boolean[] plan, HashMap<Long, Double> computeCosts) + { + //high level heuristic: every hop or fused operator has the following cost: + //WRITE + max(COMPUTE, READ), where WRITE costs are given by the output size, + //READ costs by the input sizes, and COMPUTE by operation specific FLOP + //counts times number of cells of main input, disregarding sparsity for now. + + HashSet<VisitMarkCost> visited = new HashSet<>(); + double costs = 0; + for( Long hopID : part.getRoots() ) { + costs += rGetPlanCosts(memo, memo.getHopRefs().get(hopID), + visited, part, matPoints, plan, computeCosts, null, null); + } + return costs; + } + + private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<VisitMarkCost> visited, + PlanPartition part, InterestingPoint[] matPoints, boolean[] plan, HashMap<Long, Double> computeCosts, + CostVector costsCurrent, TemplateType currentType) + { + //memoization per hop id and cost vector to account for redundant + //computation without double counting materialized results or compute + //costs of complex operation DAGs within a single fused operator + VisitMarkCost tag = new VisitMarkCost(current.getHopID(), + (costsCurrent==null || currentType==TemplateType.MAGG)?0:costsCurrent.ID); + if( visited.contains(tag) ) + return 0; + visited.add(tag); + + //open template if necessary, including memoization + //under awareness of current plan choice + MemoTableEntry best = null; + boolean opened = false; + if( memo.contains(current.getHopID()) ) { + //note: this is the inner loop of plan enumeration and hence, we do not + //use streams, lambda expressions, etc to avoid unnecessary overhead + long hopID = current.getHopID(); + if( currentType == null ) { + for( MemoTableEntry me : memo.get(hopID) ) + best = isValid(me, current) + && hasNoRefToMatPoint(hopID, me, matPoints, plan) + && BasicPlanComparator.icompare(me, best)<0 ? me : best; + opened = true; + } + else { + for( MemoTableEntry me : memo.get(hopID) ) + best = (me.type == currentType || me.type==TemplateType.CELL) + && hasNoRefToMatPoint(hopID, me, matPoints, plan) + && TypedPlanComparator.icompare(me, best, currentType)<0 ? me : best; + } + } + + //create new cost vector if opened, initialized with write costs + CostVector costVect = !opened ? costsCurrent : new CostVector(getSize(current)); + double costs = 0; + + //add other roots for multi-agg template to account for shared costs + if( opened && best != null && best.type == TemplateType.MAGG ) { + //account costs to first multi-agg root + if( best.input1 == current.getHopID() ) + for( int i=1; i<3; i++ ) { + if( !best.isPlanRef(i) ) continue; + costs += rGetPlanCosts(memo, memo.getHopRefs().get(best.input(i)), visited, + part, matPoints, plan, computeCosts, costVect, TemplateType.MAGG); + } + //skip other multi-agg roots + else + return 0; + } + + //add compute costs of current operator to costs vector + if( part.getPartition().contains(current.getHopID()) ) + costVect.computeCosts += computeCosts.get(current.getHopID()); + + //process children recursively + for( int i=0; i< current.getInput().size(); i++ ) { + Hop c = current.getInput().get(i); + if( best!=null && best.isPlanRef(i) ) + costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, costVect, best.type); + else if( best!=null && isImplicitlyFused(current, i, best.type) ) + costVect.addInputSize(c.getInput().get(0).getHopID(), getSize(c)); + else { //include children and I/O costs + costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, null, null); + if( costVect != null && c.getDataType().isMatrix() ) + costVect.addInputSize(c.getHopID(), getSize(c)); + } + } + + //add costs for opened fused operator + if( part.getPartition().contains(current.getHopID()) ) { + if( opened ) { + if( LOG.isTraceEnabled() ) { + String type = (best !=null) ? best.type.name() : "HOP"; + LOG.trace("Cost vector ("+type+" "+current.getHopID()+"): "+costVect); + } + double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH //time for output write + + Math.max(costVect.getSumInputSizes() * 8 / READ_BANDWIDTH, + costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH); + //sparsity correction for outer-product template (and sparse-safe cell) + if( best != null && best.type == TemplateType.OUTER ) { + Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID()); + tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST; + } + costs += tmpCosts; + } + //add costs for non-partition read in the middle of fused operator + else if( part.getExtConsumed().contains(current.getHopID()) ) { + costs += rGetPlanCosts(memo, current, visited, + part, matPoints, plan, computeCosts, null, null); + } + } + + //sanity check non-negative costs + if( costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs) ) + throw new RuntimeException("Wrong cost estimate: "+costs); + + return costs; + } + + private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) + { + if( computeCosts.containsKey(current.getHopID()) + || !partition.contains(current.getHopID()) ) + return; + + //recursively process children + for( Hop c : current.getInput() ) + rGetComputeCosts(c, partition, computeCosts); + + //get costs for given hop + double costs = 1; + if( current instanceof UnaryOp ) { + switch( ((UnaryOp)current).getOp() ) { + case ABS: + case ROUND: + case CEIL: + case FLOOR: + case SIGN: + case SELP: costs = 1; break; + case SPROP: + case SQRT: costs = 2; break; + case EXP: costs = 18; break; + case SIGMOID: costs = 21; break; + case LOG: + case LOG_NZ: costs = 32; break; + case NCOL: + case NROW: + case PRINT: + case CAST_AS_BOOLEAN: + case CAST_AS_DOUBLE: + case CAST_AS_INT: + case CAST_AS_MATRIX: + case CAST_AS_SCALAR: costs = 1; break; + case SIN: costs = 18; break; + case COS: costs = 22; break; + case TAN: costs = 42; break; + case ASIN: costs = 93; break; + case ACOS: costs = 103; break; + case ATAN: costs = 40; break; + case CUMSUM: + case CUMMIN: + case CUMMAX: + case CUMPROD: costs = 1; break; + default: + LOG.warn("Cost model not " + + "implemented yet for: "+((UnaryOp)current).getOp()); + } + } + else if( current instanceof BinaryOp ) { + switch( ((BinaryOp)current).getOp() ) { + case MULT: + case PLUS: + case MINUS: + case MIN: + case MAX: + case AND: + case OR: + case EQUAL: + case NOTEQUAL: + case LESS: + case LESSEQUAL: + case GREATER: + case GREATEREQUAL: + case CBIND: + case RBIND: costs = 1; break; + case INTDIV: costs = 6; break; + case MODULUS: costs = 8; break; + case DIV: costs = 22; break; + case LOG: + case LOG_NZ: costs = 32; break; + case POW: costs = (HopRewriteUtils.isLiteralOfValue( + current.getInput().get(1), 2) ? 1 : 16); break; + case MINUS_NZ: + case MINUS1_MULT: costs = 2; break; + case CENTRALMOMENT: + int type = (int) (current.getInput().get(1) instanceof LiteralOp ? + HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2); + switch( type ) { + case 0: costs = 1; break; //count + case 1: costs = 8; break; //mean + case 2: costs = 16; break; //cm2 + case 3: costs = 31; break; //cm3 + case 4: costs = 51; break; //cm4 + case 5: costs = 16; break; //variance + } + break; + case COVARIANCE: costs = 23; break; + default: + LOG.warn("Cost model not " + + "implemented yet for: "+((BinaryOp)current).getOp()); + } + } + else if( current instanceof TernaryOp ) { + switch( ((TernaryOp)current).getOp() ) { + case PLUS_MULT: + case MINUS_MULT: costs = 2; break; + case CTABLE: costs = 3; break; + case CENTRALMOMENT: + int type = (int) (current.getInput().get(1) instanceof LiteralOp ? + HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2); + switch( type ) { + case 0: costs = 2; break; //count + case 1: costs = 9; break; //mean + case 2: costs = 17; break; //cm2 + case 3: costs = 32; break; //cm3 + case 4: costs = 52; break; //cm4 + case 5: costs = 17; break; //variance + } + break; + case COVARIANCE: costs = 23; break; + default: + LOG.warn("Cost model not " + + "implemented yet for: "+((TernaryOp)current).getOp()); + } + } + else if( current instanceof ParameterizedBuiltinOp ) { + costs = 1; + } + else if( current instanceof IndexingOp ) { + costs = 1; + } + else if( current instanceof ReorgOp ) { + costs = 1; + } + else if( current instanceof AggBinaryOp ) { + //outer product template + if( HopRewriteUtils.isOuterProductLikeMM(current) ) + costs = 2 * current.getInput().get(0).getDim2(); + //row template w/ matrix-vector or matrix-matrix + else + costs = 2 * current .getDim2(); + } + else if( current instanceof AggUnaryOp) { + switch(((AggUnaryOp)current).getOp()) { + case SUM: costs = 4; break; + case SUM_SQ: costs = 5; break; + case MIN: + case MAX: costs = 1; break; + default: + LOG.warn("Cost model not " + + "implemented yet for: "+((AggUnaryOp)current).getOp()); + } + } + + computeCosts.put(current.getHopID(), costs); + } + + private static boolean hasNoRefToMatPoint(long hopID, + MemoTableEntry me, InterestingPoint[] M, boolean[] plan) { + return !InterestingPoint.isMatPoint(M, hopID, me, plan); + } + + private static boolean isImplicitlyFused(Hop hop, int index, TemplateType type) { + return type == TemplateType.ROW + && HopRewriteUtils.isMatrixMultiply(hop) && index==0 + && HopRewriteUtils.isTransposeOperation(hop.getInput().get(index)); + } + + private static class CostVector { + public final long ID; + public final double outSize; + public double computeCosts = 0; + public final HashMap<Long, Double> inSizes = new HashMap<Long, Double>(); + + public CostVector(double outputSize) { + ID = COST_ID.getNextID(); + outSize = outputSize; + } + public void addInputSize(long hopID, double inputSize) { + //ensures that input sizes are not double counted + inSizes.put(hopID, inputSize); + } + public double getSumInputSizes() { + return inSizes.values().stream() + .mapToDouble(d -> d.doubleValue()).sum(); + } + public double getMaxInputSize() { + return inSizes.values().stream() + .mapToDouble(d -> d.doubleValue()).max().orElse(0); + } + public long getMaxInputSizeHopID() { + long id = -1; double max = 0; + for( Entry<Long,Double> e : inSizes.entrySet() ) + if( max < e.getValue() ) { + id = e.getKey(); + max = e.getValue(); + } + return id; + } + @Override + public String toString() { + return "["+outSize+", "+computeCosts+", {" + +Arrays.toString(inSizes.keySet().toArray(new Long[0]))+", " + +Arrays.toString(inSizes.values().toArray(new Double[0]))+"}]"; + } + } + + private static class StaticCosts { + public final HashMap<Long, Double> _computeCosts; + public final double _compute; + public final double _read; + public final double _write; + + public StaticCosts(HashMap<Long,Double> allComputeCosts, double computeCost, double readCost, double writeCost) { + _computeCosts = allComputeCosts; + _compute = computeCost; + _read = readCost; + _write = writeCost; + } + } + + private static class AggregateInfo { + public final HashMap<Long,Hop> _aggregates; + public final HashSet<Long> _inputAggs = new HashSet<Long>(); + public final HashSet<Long> _fusedInputs = new HashSet<Long>(); + public AggregateInfo(Hop aggregate) { + _aggregates = new HashMap<Long, Hop>(); + _aggregates.put(aggregate.getHopID(), aggregate); + } + public void addInputAggregate(long hopID) { + _inputAggs.add(hopID); + } + public void addFusedInput(long hopID) { + _fusedInputs.add(hopID); + } + public boolean isMergable(AggregateInfo that) { + //check independence + boolean ret = _aggregates.size()<3 + && _aggregates.size()+that._aggregates.size()<=3; + for( Long hopID : that._aggregates.keySet() ) + ret &= !_inputAggs.contains(hopID); + for( Long hopID : _aggregates.keySet() ) + ret &= !that._inputAggs.contains(hopID); + //check partial shared reads + ret &= !CollectionUtils.intersection( + _fusedInputs, that._fusedInputs).isEmpty(); + //check consistent sizes (result correctness) + Hop in1 = _aggregates.values().iterator().next(); + Hop in2 = that._aggregates.values().iterator().next(); + return ret && HopRewriteUtils.isEqualSize( + in1.getInput().get(HopRewriteUtils.isMatrixMultiply(in1)?1:0), + in2.getInput().get(HopRewriteUtils.isMatrixMultiply(in2)?1:0)); + } + public AggregateInfo merge(AggregateInfo that) { + _aggregates.putAll(that._aggregates); + _inputAggs.addAll(that._inputAggs); + _fusedInputs.addAll(that._fusedInputs); + return this; + } + @Override + public String toString() { + return "["+Arrays.toString(_aggregates.keySet().toArray(new Long[0]))+": " + +"{"+Arrays.toString(_inputAggs.toArray(new Long[0]))+"}," + +"{"+Arrays.toString(_fusedInputs.toArray(new Long[0]))+"}]"; + } + } +}
http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java new file mode 100644 index 0000000..759a903 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.codegen.opt; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Map.Entry; +import java.util.HashSet; +import java.util.List; + +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.codegen.template.CPlanMemoTable; +import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; +import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; + +/** + * This plan selection heuristic aims for fusion without any redundant + * computation, which, however, potentially leads to more materialized + * intermediates than the fuse all heuristic. + * <p> + * NOTE: This heuristic is essentially the same as FuseAll, except that + * any plans that refer to a hop with multiple consumers are removed in + * a pre-processing step. + * + */ +public class PlanSelectionFuseNoRedundancy extends PlanSelection +{ + @Override + public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) { + //pruning and collection pass + for( Hop hop : roots ) + rSelectPlans(memo, hop, null); + + //take all distinct best plans + for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() ) + memo.setDistinct(e.getKey(), e.getValue()); + } + + private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType) + { + if( isVisited(current.getHopID(), currentType) ) + return; + + //step 0: remove plans that refer to a common partial plan + if( memo.contains(current.getHopID()) ) { + HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>(); + List<MemoTableEntry> hopP = memo.get(current.getHopID()); + for( MemoTableEntry e1 : hopP ) + for( int i=0; i<3; i++ ) + if( e1.isPlanRef(i) && current.getInput().get(i).getParent().size()>1 ) + rmSet.add(e1); //remove references to hops w/ multiple consumers + memo.remove(current, rmSet); + } + + //step 1: prune subsumed plans of same type + if( memo.contains(current.getHopID()) ) { + HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>(); + List<MemoTableEntry> hopP = memo.get(current.getHopID()); + for( MemoTableEntry e1 : hopP ) + for( MemoTableEntry e2 : hopP ) + if( e1 != e2 && e1.subsumes(e2) ) + rmSet.add(e2); + memo.remove(current, rmSet); + } + + //step 2: select plan for current path + MemoTableEntry best = null; + if( memo.contains(current.getHopID()) ) { + if( currentType == null ) { + best = memo.get(current.getHopID()).stream() + .filter(p -> isValid(p, current)) + .min(new BasicPlanComparator()).orElse(null); + } + else { + best = memo.get(current.getHopID()).stream() + .filter(p -> p.type==currentType || p.type==TemplateType.CELL) + .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs())) + .orElse(null); + } + addBestPlan(current.getHopID(), best); + } + + //step 3: recursively process children + for( int i=0; i< current.getInput().size(); i++ ) { + TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null; + rSelectPlans(memo, current.getInput().get(i), pref); + } + + setVisited(current.getHopID(), currentType); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java b/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java new file mode 100644 index 0000000..de1ed92 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.codegen.opt; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.lang.ArrayUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.codegen.template.CPlanMemoTable; +import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; +import org.apache.sysml.hops.codegen.opt.PlanSelection.VisitMarkCost; + +/** + * + */ +public class ReachabilityGraph +{ + private HashMap<Pair<Long,Long>,NodeLink> _matPoints = null; + private NodeLink _root = null; + + private InterestingPoint[] _searchSpace; + private CutSet[] _cutSets; + + public ReachabilityGraph(PlanPartition part, CPlanMemoTable memo) { + //create repository of materialization points + _matPoints = new HashMap<>(); + for( InterestingPoint p : part.getMatPointsExt() ) + _matPoints.put(Pair.of(p._fromHopID, p._toHopID), new NodeLink(p)); + + //create reachability graph + _root = new NodeLink(null); + HashSet<VisitMarkCost> visited = new HashSet<>(); + for( Long hopID : part.getRoots() ) { + Hop rootHop = memo.getHopRefs().get(hopID); + addInputNodeLinks(rootHop, _root, part, memo, visited); + } + + //create candidate cutsets + List<NodeLink> tmpCS = _matPoints.values().stream() + .filter(p -> p._inputs.size() > 0 && p._p != null) + .sorted().collect(Collectors.toList()); + + //short-cut for partitions without cutsets + if( tmpCS.isEmpty() ) { + _cutSets = new CutSet[0]; + _searchSpace = part.getMatPointsExt(); + return; + } + + //create composite cutsets + ArrayList<ArrayList<NodeLink>> candCS = new ArrayList<>(); + ArrayList<NodeLink> current = new ArrayList<>(); + for( NodeLink node : tmpCS ) { + if( current.isEmpty() ) + current.add(node); + else if( current.get(0).equals(node) ) + current.add(node); + else { + candCS.add(current); + current = new ArrayList<>(); + current.add(node); + } + } + if( !current.isEmpty() ) + candCS.add(current); + + //evaluate cutsets (single, and duplicate pairs) + ArrayList<ArrayList<NodeLink>> remain = new ArrayList<>(); + ArrayList<Pair<CutSet,Double>> cutSets = evaluateCutSets(candCS, remain); + if( !remain.isEmpty() && remain.size() < 5 ) { + //second chance: for pairs for remaining candidates + ArrayList<ArrayList<NodeLink>> candCS2 = new ArrayList<>(); + for( int i=0; i<remain.size()-1; i++) + for( int j=i+1; j<remain.size(); j++) { + ArrayList<NodeLink> tmp = new ArrayList<>(); + tmp.addAll(remain.get(i)); + tmp.addAll(remain.get(j)); + candCS2.add(tmp); + } + ArrayList<Pair<CutSet,Double>> cutSets2 = evaluateCutSets(candCS2, remain); + //ensure constructed cutsets are disjoint + HashSet<InterestingPoint> testDisjoint = new HashSet<>(); + for( Pair<CutSet,Double> cs : cutSets2 ) { + if( !CollectionUtils.containsAny(testDisjoint, Arrays.asList(cs.getLeft().cut)) ) { + cutSets.add(cs); + CollectionUtils.addAll(testDisjoint, cs.getLeft().cut); + } + } + } + + //sort and linearize search space according to scores + _cutSets = cutSets.stream() + .sorted(Comparator.comparing(p -> p.getRight())) + .map(p -> p.getLeft()).toArray(CutSet[]::new); + + HashMap<InterestingPoint, Integer> probe = new HashMap<>(); + ArrayList<InterestingPoint> lsearchSpace = new ArrayList<>(); + for( CutSet cs : _cutSets ) { + cs.updatePos(lsearchSpace.size()); + cs.updatePartitions(probe); + CollectionUtils.addAll(lsearchSpace, cs.cut); + for( InterestingPoint p: cs.cut ) + probe.put(p, probe.size()-1); + } + for( InterestingPoint p : part.getMatPointsExt() ) + if( !probe.containsKey(p) ) { + lsearchSpace.add(p); + probe.put(p, probe.size()-1); + } + _searchSpace = lsearchSpace.toArray(new InterestingPoint[0]); + + //materialize partition indices + for( CutSet cs : _cutSets ) { + cs.updatePartitionIndexes(probe); + cs.finalizePartition(); + } + + //final sanity check of interesting points + if( _searchSpace.length != part.getMatPointsExt().length ) + throw new RuntimeException("Corrupt linearized search space: " + + _searchSpace.length+" vs "+part.getMatPointsExt().length); + } + + public InterestingPoint[] getSortedSearchSpace() { + return _searchSpace; + } + + public boolean isCutSet(boolean[] plan) { + for( CutSet cs : _cutSets ) + if( isCutSet(cs, plan) ) + return true; + return false; + } + + public boolean isCutSet(CutSet cs, boolean[] plan) { + boolean ret = true; + for(int i=0; i<cs.posCut.length && ret; i++) + ret &= plan[cs.posCut[i]]; + return ret; + } + + public CutSet getCutSet(boolean[] plan) { + for( CutSet cs : _cutSets ) + if( isCutSet(cs, plan) ) + return cs; + throw new RuntimeException("No valid cut set found."); + } + + public long getNumSkipPlans(boolean[] plan) { + for( CutSet cs : _cutSets ) + if( isCutSet(cs, plan) ) { + int pos = cs.posCut[cs.posCut.length-1]; + return (long) Math.pow(2, plan.length-pos-1); + } + throw new RuntimeException("Failed to compute " + + "number of skip plans for plan without cutset."); + } + + + public SubProblem[] getSubproblems(boolean[] plan) { + CutSet cs = getCutSet(plan); + return new SubProblem[] { + new SubProblem(cs.cut.length, cs.posLeft, cs.left), + new SubProblem(cs.cut.length, cs.posRight, cs.right)}; + } + + @Override + public String toString() { + return "ReachabilityGraph("+_matPoints.size()+"):\n" + + _root.explain(new HashSet<>()); + } + + private void addInputNodeLinks(Hop current, NodeLink parent, PlanPartition part, + CPlanMemoTable memo, HashSet<VisitMarkCost> visited) + { + if( visited.contains(new VisitMarkCost(current.getHopID(), parent._ID)) ) + return; + + //process children + for( Hop in : current.getInput() ) { + if( InterestingPoint.isMatPoint(part.getMatPointsExt(), current.getHopID(), in.getHopID()) ) { + NodeLink tmp = _matPoints.get(Pair.of(current.getHopID(), in.getHopID())); + parent.addInput(tmp); + addInputNodeLinks(in, tmp, part, memo, visited); + } + else + addInputNodeLinks(in, parent, part, memo, visited); + } + + visited.add(new VisitMarkCost(current.getHopID(), parent._ID)); + } + + private void rCollectInputs(NodeLink current, HashSet<NodeLink> probe, HashSet<NodeLink> inputs) { + for( NodeLink c : current._inputs ) + if( !probe.contains(c) ) { + rCollectInputs(c, probe, inputs); + inputs.add(c); + } + } + + private ArrayList<Pair<CutSet,Double>> evaluateCutSets(ArrayList<ArrayList<NodeLink>> candCS, ArrayList<ArrayList<NodeLink>> remain) { + ArrayList<Pair<CutSet,Double>> cutSets = new ArrayList<>(); + + for( ArrayList<NodeLink> cand : candCS ) { + HashSet<NodeLink> probe = new HashSet<>(cand); + + //determine subproblems for cutset candidates + HashSet<NodeLink> part1 = new HashSet<>(); + rCollectInputs(_root, probe, part1); + HashSet<NodeLink> part2 = new HashSet<>(); + for( NodeLink rNode : cand ) + rCollectInputs(rNode, probe, part2); + + //select, score and create cutsets + if( !CollectionUtils.containsAny(part1, part2) + && !part1.isEmpty() && !part2.isEmpty()) { + //score cutsets (smaller is better) + double base = Math.pow(2, _matPoints.size()); + double numComb = Math.pow(2, cand.size()); + double score = (numComb-1)/numComb * base + + 1/numComb * Math.pow(2, part1.size()) + + 1/numComb * Math.pow(2, part2.size()); + + //construct cutset + cutSets.add(Pair.of(new CutSet( + cand.stream().map(p->p._p).toArray(InterestingPoint[]::new), + part1.stream().map(p->p._p).toArray(InterestingPoint[]::new), + part2.stream().map(p->p._p).toArray(InterestingPoint[]::new)), score)); + } + else { + remain.add(cand); + } + } + + return cutSets; + } + + public static class SubProblem { + public int offset; + public int[] freePos; + public InterestingPoint[] freeMat; + + public SubProblem(int off, int[] pos, InterestingPoint[] mat) { + offset = off; + freePos = pos; + freeMat = mat; + } + } + + public static class CutSet { + public InterestingPoint[] cut; + public InterestingPoint[] left; + public InterestingPoint[] right; + public int[] posCut; + public int[] posLeft; + public int[] posRight; + + public CutSet(InterestingPoint[] cutPoints, + InterestingPoint[] l, InterestingPoint[] r) { + cut = cutPoints; + left = l; + right = r; + } + + public void updatePos(int index) { + posCut = new int[cut.length]; + for(int i=0; i<posCut.length; i++) + posCut[i] = index + i; + } + + public void updatePartitions(HashMap<InterestingPoint,Integer> blacklist) { + left = Arrays.stream(left).filter(p -> !blacklist.containsKey(p)) + .toArray(InterestingPoint[]::new); + right = Arrays.stream(right).filter(p -> !blacklist.containsKey(p)) + .toArray(InterestingPoint[]::new); + } + + public void updatePartitionIndexes(HashMap<InterestingPoint,Integer> probe) { + posLeft = new int[left.length]; + for(int i=0; i<left.length; i++) + posLeft[i] = probe.get(left[i]); + posRight = new int[right.length]; + for(int i=0; i<right.length; i++) + posRight[i] = probe.get(right[i]); + } + + public void finalizePartition() { + left = (InterestingPoint[]) ArrayUtils.addAll(cut, left); + right = (InterestingPoint[]) ArrayUtils.addAll(cut, right); + } + + @Override + public String toString() { + return "Cut : "+Arrays.toString(cut); + } + } + + private static class NodeLink implements Comparable<NodeLink> + { + private static final IDSequence _seqID = new IDSequence(); + + private ArrayList<NodeLink> _inputs = new ArrayList<>(); + private long _ID; + private InterestingPoint _p; + + public NodeLink(InterestingPoint p) { + _ID = _seqID.getNextID(); + _p = p; + } + + public void addInput(NodeLink in) { + _inputs.add(in); + } + + @Override + public boolean equals(Object o) { + if( !(o instanceof NodeLink) ) + return false; + NodeLink that = (NodeLink) o; + boolean ret = (_inputs.size() == that._inputs.size()); + for( int i=0; i<_inputs.size() && ret; i++ ) + ret &= (_inputs.get(i)._ID == that._inputs.get(i)._ID); + return ret; + } + + @Override + public int compareTo(NodeLink that) { + if( _inputs.size() > that._inputs.size() ) + return -1; + else if( _inputs.size() < that._inputs.size() ) + return 1; + for( int i=0; i<_inputs.size(); i++ ) { + int comp = Long.compare(_inputs.get(i)._ID, + that._inputs.get(i)._ID); + if( comp != 0 ) + return comp; + } + return 0; + } + + @Override + public String toString() { + StringBuilder inputs = new StringBuilder(); + for(NodeLink in : _inputs) { + if( inputs.length() > 0 ) + inputs.append(","); + inputs.append(in._ID); + } + return _ID+" ("+inputs.toString()+") "+((_p!=null)?_p:"null"); + } + + private String explain(HashSet<Long> visited) { + if( visited.contains(_ID) ) + return ""; + //add children + StringBuilder sb = new StringBuilder(); + StringBuilder inputs = new StringBuilder(); + for(NodeLink in : _inputs) { + String tmp = in.explain(visited); + if( !tmp.isEmpty() ) + sb.append(tmp + "\n"); + if( inputs.length() > 0 ) + inputs.append(","); + inputs.append(in._ID); + } + //add node itself + sb.append(_ID+" ("+inputs+") "+((_p!=null)?_p:"null")); + visited.add(_ID); + + return sb.toString(); + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java index edbcdf9..4078060 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java @@ -21,6 +21,7 @@ package org.apache.sysml.hops.codegen.template; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; @@ -36,6 +37,8 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.codegen.SpoofCompiler; +import org.apache.sysml.hops.codegen.opt.InterestingPoint; +import org.apache.sysml.hops.codegen.opt.PlanSelection; import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; import org.apache.sysml.runtime.util.UtilFunctions; @@ -53,6 +56,18 @@ public class CPlanMemoTable _plansBlacklist = new HashSet<Long>(); } + public HashMap<Long, List<MemoTableEntry>> getPlans() { + return _plans; + } + + public HashSet<Long> getPlansBlacklisted() { + return _plansBlacklist; + } + + public HashMap<Long, Hop> getHopRefs() { + return _hopRefs; + } + public void addHop(Hop hop) { _hopRefs.put(hop.getHopID(), hop); } @@ -78,6 +93,14 @@ public class CPlanMemoTable .anyMatch(p -> (!checkClose||!p.closed) && probe.contains(p.type)); } + public boolean containsNotIn(long hopID, Collection<TemplateType> types, + boolean checkChildRefs, boolean excludeCell) { + return contains(hopID) && get(hopID).stream() + .anyMatch(p -> (!checkChildRefs || p.hasPlanRef()) + && (!excludeCell || p.type!=TemplateType.CELL) + && !types.contains(p.type)); + } + public int countEntries(long hopID) { return get(hopID).size(); } @@ -85,7 +108,7 @@ public class CPlanMemoTable public int countEntries(long hopID, TemplateType type) { return (int) get(hopID).stream() .filter(p -> p.type==type).count(); - } + } public boolean containsTopLevel(long hopID) { return !_plansBlacklist.contains(hopID) @@ -133,7 +156,7 @@ public class CPlanMemoTable .distinct().collect(Collectors.toList())); } - public void pruneRedundant(long hopID) { + public void pruneRedundant(long hopID, boolean pruneDominated, InterestingPoint[] matPoints) { if( !contains(hopID) ) return; @@ -146,7 +169,7 @@ public class CPlanMemoTable //prune dominated plans (e.g., opened plan subsumed by fused plan //if single consumer of input; however this only applies to fusion //heuristic that only consider materialization points) - if( SpoofCompiler.PLAN_SEL_POLICY.isHeuristic() ) { + if( pruneDominated ) { HashSet<MemoTableEntry> rmList = new HashSet<MemoTableEntry>(); List<MemoTableEntry> list = _plans.get(hopID); Hop hop = _hopRefs.get(hopID); @@ -155,9 +178,12 @@ public class CPlanMemoTable if( e1 != e2 && e1.subsumes(e2) ) { //check that childs don't have multiple consumers boolean rmSafe = true; - for( int i=0; i<=2; i++ ) + for( int i=0; i<=2; i++ ) { rmSafe &= (e1.isPlanRef(i) && !e2.isPlanRef(i)) ? - hop.getInput().get(i).getParent().size()==1 : true; + (matPoints!=null && !InterestingPoint.isMatPoint( + matPoints, hopID, e1.input(i))) + || hop.getInput().get(i).getParent().size()==1 : true; + } if( rmSafe ) rmList.add(e2); } @@ -194,12 +220,14 @@ public class CPlanMemoTable //prune dominated plans (e.g., plan referenced by other plan and this //other plan is single consumer) by marking it as blacklisted because //the chain of entries is still required for cplan construction - for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() ) - for( MemoTableEntry me : e.getValue() ) { - for( int i=0; i<=2; i++ ) - if( me.isPlanRef(i) && _hopRefs.get(me.input(i)).getParent().size()==1 ) - _plansBlacklist.add(me.input(i)); - } + if( SpoofCompiler.PLAN_SEL_POLICY.isHeuristic() ) { + for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() ) + for( MemoTableEntry me : e.getValue() ) { + for( int i=0; i<=2; i++ ) + if( me.isPlanRef(i) && _hopRefs.get(me.input(i)).getParent().size()==1 ) + _plansBlacklist.add(me.input(i)); + } + } //core plan selection PlanSelection selector = SpoofCompiler.createPlanSelector(); @@ -232,6 +260,16 @@ public class CPlanMemoTable .distinct().collect(Collectors.toList()); } + public List<TemplateType> getDistinctTemplateTypes(long hopID, int refAt) { + if(!contains(hopID)) + return Collections.emptyList(); + //return distinct template types with reference at given position + return _plans.get(hopID).stream() + .filter(p -> p.isPlanRef(refAt)) + .map(p -> p.type) //extract type + .distinct().collect(Collectors.toList()); + } + public MemoTableEntry getBest(long hopID) { List<MemoTableEntry> tmp = get(hopID); if( tmp == null || tmp.isEmpty() ) http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java deleted file mode 100644 index f8a12fd..0000000 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.hops.codegen.template; - -import java.util.ArrayList; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; - -import org.apache.sysml.hops.Hop; -import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; -import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; -import org.apache.sysml.hops.rewrite.HopRewriteUtils; -import org.apache.sysml.runtime.util.UtilFunctions; - -public abstract class PlanSelection -{ - private final HashMap<Long, List<MemoTableEntry>> _bestPlans = - new HashMap<Long, List<MemoTableEntry>>(); - private final HashSet<VisitMark> _visited = new HashSet<VisitMark>(); - - /** - * Given a HOP DAG G, and a set of partial fusions plans P, find the set of optimal, - * non-conflicting fusion plans P' that applied to G minimizes costs C with - * P' = \argmin_{p \subseteq P} C(G, p) s.t. Z \vDash p, where Z is a set of - * constraints such as memory budgets and block size restrictions per fused operator. - * - * @param memo partial fusion plans P - * @param roots entry points of HOP DAG G - */ - public abstract void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots); - - /** - * Determines if the given partial fusion plan is valid. - * - * @param me memo table entry - * @param hop current hop - * @return true if entry is valid as top-level plan - */ - protected static boolean isValid(MemoTableEntry me, Hop hop) { - return (me.type == TemplateType.OuterProdTpl - && (me.closed || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop))) - || (me.type == TemplateType.RowTpl) - || (me.type == TemplateType.CellTpl) - || (me.type == TemplateType.MultiAggTpl); - } - - protected void addBestPlan(long hopID, MemoTableEntry me) { - if( me == null ) return; - if( !_bestPlans.containsKey(hopID) ) - _bestPlans.put(hopID, new ArrayList<MemoTableEntry>()); - _bestPlans.get(hopID).add(me); - } - - protected HashMap<Long, List<MemoTableEntry>> getBestPlans() { - return _bestPlans; - } - - protected boolean isVisited(long hopID, TemplateType type) { - return _visited.contains(new VisitMark(hopID, type)); - } - - protected void setVisited(long hopID, TemplateType type) { - _visited.add(new VisitMark(hopID, type)); - } - - /** - * Basic plan comparator to compare memo table entries with regard to - * a pre-defined template preference order and the number of references. - */ - protected static class BasicPlanComparator implements Comparator<MemoTableEntry> { - @Override - public int compare(MemoTableEntry o1, MemoTableEntry o2) { - //for different types, select preferred type - if( o1.type != o2.type ) - return Integer.compare(o1.type.getRank(), o2.type.getRank()); - - //for same type, prefer plan with more refs - return Integer.compare( - 3-o1.countPlanRefs(), 3-o2.countPlanRefs()); - } - } - - private static class VisitMark { - private final long _hopID; - private final TemplateType _type; - - public VisitMark(long hopID, TemplateType type) { - _hopID = hopID; - _type = type; - } - @Override - public int hashCode() { - return UtilFunctions.longHashCode( - _hopID, (_type!=null)?_type.hashCode():0); - } - @Override - public boolean equals(Object o) { - return (o instanceof VisitMark - && _hopID == ((VisitMark)o)._hopID - && _type == ((VisitMark)o)._type); - } - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java deleted file mode 100644 index a455302..0000000 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.hops.codegen.template; - -import java.util.ArrayList; -import java.util.Comparator; -import java.util.Map.Entry; -import java.util.HashSet; -import java.util.List; - -import org.apache.sysml.hops.Hop; -import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; -import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; - -/** - * This plan selection heuristic aims for maximal fusion, which - * potentially leads to overlapping fused operators and thus, - * redundant computation but with a minimal number of materialized - * intermediate results. - * - */ -public class PlanSelectionFuseAll extends PlanSelection -{ - @Override - public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) { - //pruning and collection pass - for( Hop hop : roots ) - rSelectPlans(memo, hop, null); - - //take all distinct best plans - for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() ) - memo.setDistinct(e.getKey(), e.getValue()); - } - - private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType) - { - if( isVisited(current.getHopID(), currentType) ) - return; - - //step 1: prune subsumed plans of same type - if( memo.contains(current.getHopID()) ) { - HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>(); - List<MemoTableEntry> hopP = memo.get(current.getHopID()); - for( MemoTableEntry e1 : hopP ) - for( MemoTableEntry e2 : hopP ) - if( e1 != e2 && e1.subsumes(e2) ) - rmSet.add(e2); - memo.remove(current, rmSet); - } - - //step 2: select plan for current path - MemoTableEntry best = null; - if( memo.contains(current.getHopID()) ) { - if( currentType == null ) { - best = memo.get(current.getHopID()).stream() - .filter(p -> isValid(p, current)) - .min(new BasicPlanComparator()).orElse(null); - } - else { - best = memo.get(current.getHopID()).stream() - .filter(p -> p.type==currentType || p.type==TemplateType.CellTpl) - .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs())) - .orElse(null); - } - addBestPlan(current.getHopID(), best); - } - - //step 3: recursively process children - for( int i=0; i< current.getInput().size(); i++ ) { - TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null; - rSelectPlans(memo, current.getInput().get(i), pref); - } - - setVisited(current.getHopID(), currentType); - } -}
