Repository: systemml Updated Branches: refs/heads/master 4ec6f0865 -> 7b4a3418a
http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java deleted file mode 100644 index 82fedff..0000000 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java +++ /dev/null @@ -1,1009 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.hops.codegen.template; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.Map.Entry; -import java.util.stream.Collectors; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; - -import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.apache.sysml.hops.AggBinaryOp; -import org.apache.sysml.hops.AggUnaryOp; -import org.apache.sysml.hops.BinaryOp; -import org.apache.sysml.hops.Hop; -import org.apache.sysml.hops.Hop.AggOp; -import org.apache.sysml.hops.Hop.Direction; -import org.apache.sysml.hops.IndexingOp; -import org.apache.sysml.hops.LiteralOp; -import org.apache.sysml.hops.ParameterizedBuiltinOp; -import org.apache.sysml.hops.ReorgOp; -import org.apache.sysml.hops.TernaryOp; -import org.apache.sysml.hops.UnaryOp; -import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; -import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; -import org.apache.sysml.hops.rewrite.HopRewriteUtils; -import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; -import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; - -/** - * This cost-based plan selection algorithm chooses fused operators - * based on the DAG structure and resulting overall costs. This includes - * decisions on materialization points, template types, and composed - * multi output templates. - * - */ -public class PlanSelectionFuseCostBased extends PlanSelection -{ - private static final Log LOG = LogFactory.getLog(PlanSelectionFuseCostBased.class.getName()); - - //common bandwidth characteristics, with a conservative write bandwidth in order - //to cover result allocation, write into main memory, and potential evictions - private static final double WRITE_BANDWIDTH = 2d*1024*1024*1024; //2GB/s - private static final double READ_BANDWIDTH = 32d*1024*1024*1024; //32GB/s - private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 //2GFLOPs/core - * InfrastructureAnalyzer.getLocalParallelism(); - - private static final IDSequence COST_ID = new IDSequence(); - private static final TemplateRow ROW_TPL = new TemplateRow(); - - @Override - public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) - { - //step 1: determine connected sub graphs of plans - Collection<HashSet<Long>> parts = getConnectedSubGraphs(memo, roots); - if( LOG.isTraceEnabled() ) - LOG.trace("Connected sub graphs: "+parts.size()); - - for( HashSet<Long> partition : parts ) { - //step 2: determine materialization points - HashSet<Long> R = getPartitionRootNodes(memo, partition); - if( LOG.isTraceEnabled() ) - LOG.trace("Partition root points: "+Arrays.toString(R.toArray(new Long[0]))); - ArrayList<Long> M = getMaterializationPoints(R, partition, memo); - if( LOG.isTraceEnabled() ) - LOG.trace("Partition materialization points: "+Arrays.toString(M.toArray(new Long[0]))); - - //step 3: create composite templates (within the partition) - createAndAddMultiAggPlans(memo, partition, R); - - //step 4: plan enumeration and plan selection - selectPlans(memo, partition, R, M); - } - - //step 5: add composite templates (across partitions) - createAndAddMultiAggPlans(memo, roots); - - //take all distinct best plans - for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() ) - memo.setDistinct(e.getKey(), e.getValue()); - } - - private static Collection<HashSet<Long>> getConnectedSubGraphs(CPlanMemoTable memo, ArrayList<Hop> roots) - { - //build inverted index for 'referenced by' relationship - HashMap<Long, HashSet<Long>> refBy = new HashMap<Long, HashSet<Long>>(); - for( Entry<Long, List<MemoTableEntry>> e : memo._plans.entrySet() ) - for( MemoTableEntry me : e.getValue() ) - for( int i=0; i<3; i++ ) - if( me.isPlanRef(i) ) { - if( !refBy.containsKey(me.input(i)) ) - refBy.put(me.input(i), new HashSet<Long>()); - refBy.get(me.input(i)).add(e.getKey()); - } - - //create a single partition per root node, if reachable over refBy of - //other root node the resulting partition is empty and can be discarded - ArrayList<HashSet<Long>> parts = new ArrayList<HashSet<Long>>(); - HashSet<Long> visited = new HashSet<Long>(); - for( Entry<Long, List<MemoTableEntry>> e : memo._plans.entrySet() ) - if( !refBy.containsKey(e.getKey()) ) { //root node - HashSet<Long> part = rGetConnectedSubGraphs(e.getKey(), - memo, refBy, visited, new HashSet<Long>()); - if( !part.isEmpty() ) - parts.add(part); - } - - return parts; - } - - private static HashSet<Long> rGetConnectedSubGraphs(long hopID, CPlanMemoTable memo, - HashMap<Long, HashSet<Long>> refBy, HashSet<Long> visited, HashSet<Long> partition) - { - if( visited.contains(hopID) ) - return partition; - - //process node itself w/ memoization - if( memo.contains(hopID) ) { - partition.add(hopID); - visited.add(hopID); - } - - //recursively process parents - if( refBy.containsKey(hopID) ) - for( Long ref : refBy.get(hopID) ) - rGetConnectedSubGraphs(ref, memo, refBy, visited, partition); - - //recursively process children - if( memo.contains(hopID) ) { - long[] refs = memo.getAllRefs(hopID); - for( int i=0; i<3; i++ ) - if( refs[i] != -1 ) - rGetConnectedSubGraphs(refs[i], memo, refBy, visited, partition); - } - - return partition; - } - - private static HashSet<Long> getPartitionRootNodes(CPlanMemoTable memo, HashSet<Long> partition) - { - //build inverted index of references entries - HashSet<Long> ix = new HashSet<Long>(); - for( Long hopID : partition ) - if( memo.contains(hopID) ) - for( MemoTableEntry me : memo.get(hopID) ) { - ix.add(me.input1); - ix.add(me.input2); - ix.add(me.input3); - } - - HashSet<Long> roots = new HashSet<Long>(); - for( Long hopID : partition ) - if( !ix.contains(hopID) ) - roots.add(hopID); - return roots; - } - - private static ArrayList<Long> getMaterializationPoints(HashSet<Long> roots, - HashSet<Long> partition, CPlanMemoTable memo) - { - //collect materialization points bottom-up - ArrayList<Long> ret = new ArrayList<Long>(); - HashSet<Long> visited = new HashSet<Long>(); - for( Long hopID : roots ) - rCollectMaterializationPoints(memo._hopRefs.get(hopID), - visited, partition, ret); - - //remove special-case materialization points - //(root nodes w/ multiple consumers, tsmm input if consumed in partition) - ret.removeIf(hopID -> roots.contains(hopID) - || HopRewriteUtils.isTsmmInput(memo._hopRefs.get(hopID))); - - return ret; - } - - private static void rCollectMaterializationPoints(Hop current, HashSet<Long> visited, - HashSet<Long> partition, ArrayList<Long> M) - { - //memoization (not via hops because in middle of dag) - if( visited.contains(current.getHopID()) ) - return; - - //process children recursively - for( Hop c : current.getInput() ) - rCollectMaterializationPoints(c, visited, partition, M); - - //collect materialization point - if( isMaterializationPointCandidate(current, partition) ) - M.add(current.getHopID()); - - visited.add(current.getHopID()); - } - - private static boolean isMaterializationPointCandidate(Hop hop, HashSet<Long> partition) { - return hop.getParent().size()>=2 - && partition.contains(hop.getHopID()); - } - - //within-partition multi-agg templates - private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) - { - //create index of plans that reference full aggregates to avoid circular dependencies - HashSet<Long> refHops = new HashSet<Long>(); - for( Entry<Long, List<MemoTableEntry>> e : memo._plans.entrySet() ) - if( !e.getValue().isEmpty() ) { - Hop hop = memo._hopRefs.get(e.getKey()); - for( Hop c : hop.getInput() ) - refHops.add(c.getHopID()); - } - - //find all full aggregations (the fact that they are in the same partition guarantees - //that they also have common subexpressions, also full aggregations are by def root nodes) - ArrayList<Long> fullAggs = new ArrayList<Long>(); - for( Long hopID : R ) { - Hop root = memo._hopRefs.get(hopID); - if( !refHops.contains(hopID) && isMultiAggregateRoot(root) ) - fullAggs.add(hopID); - } - if( LOG.isTraceEnabled() ) { - LOG.trace("Found within-partition ua(RC) aggregations: " + - Arrays.toString(fullAggs.toArray(new Long[0]))); - } - - //construct and add multiagg template plans (w/ max 3 aggregations) - for( int i=0; i<fullAggs.size(); i+=3 ) { - int ito = Math.min(i+3, fullAggs.size()); - if( ito-i >= 2 ) { - MemoTableEntry me = new MemoTableEntry(TemplateType.MultiAggTpl, - fullAggs.get(i), fullAggs.get(i+1), ((ito-i)==3)?fullAggs.get(i+2):-1, ito-i); - if( isValidMultiAggregate(memo, me) ) { - for( int j=i; j<ito; j++ ) { - memo.add(memo._hopRefs.get(fullAggs.get(j)), me); - if( LOG.isTraceEnabled() ) - LOG.trace("Added multiagg plan: "+fullAggs.get(j)+" "+me); - } - } - else if( LOG.isTraceEnabled() ) { - LOG.trace("Removed invalid multiagg plan: "+me); - } - } - } - } - - //across-partition multi-agg templates with shared reads - private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots) - { - //collect full aggregations as initial set of candidates - HashSet<Long> fullAggs = new HashSet<Long>(); - Hop.resetVisitStatus(roots); - for( Hop hop : roots ) - rCollectFullAggregates(hop, fullAggs); - Hop.resetVisitStatus(roots); - - //remove operators with assigned multi-agg plans - fullAggs.removeIf(p -> memo.contains(p, TemplateType.MultiAggTpl)); - - //check applicability for further analysis - if( fullAggs.size() <= 1 ) - return; - - if( LOG.isTraceEnabled() ) { - LOG.trace("Found across-partition ua(RC) aggregations: " + - Arrays.toString(fullAggs.toArray(new Long[0]))); - } - - //collect information for all candidates - //(subsumed aggregations, and inputs to fused operators) - List<AggregateInfo> aggInfos = new ArrayList<AggregateInfo>(); - for( Long hopID : fullAggs ) { - Hop aggHop = memo._hopRefs.get(hopID); - AggregateInfo tmp = new AggregateInfo(aggHop); - for( int i=0; i<aggHop.getInput().size(); i++ ) { - Hop c = HopRewriteUtils.isMatrixMultiply(aggHop) && i==0 ? - aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i); - rExtractAggregateInfo(memo, c, tmp, TemplateType.CellTpl); - } - if( tmp._fusedInputs.isEmpty() ) { - if( HopRewriteUtils.isMatrixMultiply(aggHop) ) { - tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID()); - tmp.addFusedInput(aggHop.getInput().get(1).getHopID()); - } - else - tmp.addFusedInput(aggHop.getInput().get(0).getHopID()); - } - aggInfos.add(tmp); - } - - if( LOG.isTraceEnabled() ) { - LOG.trace("Extracted across-partition ua(RC) aggregation info: "); - for( AggregateInfo info : aggInfos ) - LOG.trace(info); - } - - //sort aggregations by num dependencies to simplify merging - //clusters of aggregations with parallel dependencies - aggInfos = aggInfos.stream() - .sorted(Comparator.comparing(a -> a._inputAggs.size())) - .collect(Collectors.toList()); - - //greedy grouping of multi-agg candidates - boolean converged = false; - while( !converged ) { - AggregateInfo merged = null; - for( int i=0; i<aggInfos.size(); i++ ) { - AggregateInfo current = aggInfos.get(i); - for( int j=i+1; j<aggInfos.size(); j++ ) { - AggregateInfo that = aggInfos.get(j); - if( current.isMergable(that) ) { - merged = current.merge(that); - aggInfos.remove(j); j--; - } - } - } - converged = (merged == null); - } - - if( LOG.isTraceEnabled() ) { - LOG.trace("Merged across-partition ua(RC) aggregation info: "); - for( AggregateInfo info : aggInfos ) - LOG.trace(info); - } - - //construct and add multiagg template plans (w/ max 3 aggregations) - for( AggregateInfo info : aggInfos ) { - if( info._aggregates.size()<=1 ) - continue; - Long[] aggs = info._aggregates.keySet().toArray(new Long[0]); - MemoTableEntry me = new MemoTableEntry(TemplateType.MultiAggTpl, - aggs[0], aggs[1], (aggs.length>2)?aggs[2]:-1, aggs.length); - for( int i=0; i<aggs.length; i++ ) { - memo.add(memo._hopRefs.get(aggs[i]), me); - addBestPlan(aggs[i], me); - if( LOG.isTraceEnabled() ) - LOG.trace("Added multiagg* plan: "+aggs[i]+" "+me); - - } - } - } - - private static boolean isMultiAggregateRoot(Hop root) { - return (HopRewriteUtils.isAggUnaryOp(root, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX) - && ((AggUnaryOp)root).getDirection()==Direction.RowCol) - || (root instanceof AggBinaryOp && root.getDim1()==1 && root.getDim2()==1 - && HopRewriteUtils.isTransposeOperation(root.getInput().get(0))); - } - - private static boolean isValidMultiAggregate(CPlanMemoTable memo, MemoTableEntry me) { - //ensure input consistent sizes (otherwise potential for incorrect results) - boolean ret = true; - Hop refSize = memo._hopRefs.get(me.input1).getInput().get(0); - for( int i=1; ret && i<3; i++ ) { - if( me.isPlanRef(i) ) - ret &= HopRewriteUtils.isEqualSize(refSize, - memo._hopRefs.get(me.input(i)).getInput().get(0)); - } - - //ensure that aggregates are independent of each other, i.e., - //they to not have potentially transitive parent child references - for( int i=0; ret && i<3; i++ ) - if( me.isPlanRef(i) ) { - HashSet<Long> probe = new HashSet<Long>(); - for( int j=0; j<3; j++ ) - if( i != j ) - probe.add(me.input(j)); - ret &= rCheckMultiAggregate(memo._hopRefs.get(me.input(i)), probe); - } - return ret; - } - - private static boolean rCheckMultiAggregate(Hop current, HashSet<Long> probe) { - boolean ret = true; - for( Hop c : current.getInput() ) - ret &= rCheckMultiAggregate(c, probe); - ret &= !probe.contains(current.getHopID()); - return ret; - } - - private static void rCollectFullAggregates(Hop current, HashSet<Long> aggs) { - if( current.isVisited() ) - return; - - //collect all applicable full aggregations per read - if( isMultiAggregateRoot(current) ) - aggs.add(current.getHopID()); - - //recursively process children - for( Hop c : current.getInput() ) - rCollectFullAggregates(c, aggs); - - current.setVisited(); - } - - private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop current, AggregateInfo aggInfo, TemplateType type) { - //collect input aggregates (dependents) - if( isMultiAggregateRoot(current) ) - aggInfo.addInputAggregate(current.getHopID()); - - //recursively process children - MemoTableEntry me = (type!=null) ? memo.getBest(current.getHopID()) : null; - for( int i=0; i<current.getInput().size(); i++ ) { - Hop c = current.getInput().get(i); - if( me != null && me.isPlanRef(i) ) - rExtractAggregateInfo(memo, c, aggInfo, type); - else { - if( type != null && c.getDataType().isMatrix() ) //add fused input - aggInfo.addFusedInput(c.getHopID()); - rExtractAggregateInfo(memo, c, aggInfo, null); - } - } - } - - private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, ArrayList<Long> M) - { - //prune row aggregates with pure cellwise operations - for( Long hopID : R ) { - MemoTableEntry me = memo.getBest(hopID, TemplateType.RowTpl); - if( me.type == TemplateType.RowTpl && memo.contains(hopID, TemplateType.CellTpl) - && isRowTemplateWithoutAgg(memo, memo._hopRefs.get(hopID), new HashSet<Long>())) { - List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.RowTpl); - memo.remove(memo._hopRefs.get(hopID), new HashSet<MemoTableEntry>(blacklist)); - if( LOG.isTraceEnabled() ) { - LOG.trace("Removed row memo table entries w/o aggregation: " - + Arrays.toString(blacklist.toArray(new MemoTableEntry[0]))); - } - } - } - - //prune suboptimal outer product plans that are dominated by outer product plans w/ same number of - //references but better fusion properties (e.g., for the patterns Y=X*(U%*%t(V)) and sum(Y*(U2%*%t(V2))), - //we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this would unnecessarily destroy a fusion pattern. - for( Long hopID : partition ) { - if( memo.countEntries(hopID, TemplateType.OuterProdTpl) == 2 ) { - List<MemoTableEntry> entries = memo.get(hopID, TemplateType.OuterProdTpl); - MemoTableEntry me1 = entries.get(0); - MemoTableEntry me2 = entries.get(1); - MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2); - if( rmEntry != null ) { - memo.remove(memo._hopRefs.get(hopID), Collections.singleton(rmEntry)); - memo._plansBlacklist.remove(rmEntry.input(rmEntry.getPlanRefIndex())); - if( LOG.isTraceEnabled() ) - LOG.trace("Removed dominated outer product memo table entry: " + rmEntry); - } - } - } - - //if no materialization points, use basic fuse-all w/ partition awareness - if( M == null || M.isEmpty() ) { - for( Long hopID : R ) - rSelectPlansFuseAll(memo, - memo._hopRefs.get(hopID), null, partition); - } - else { - //TODO branch and bound pruning, right now we use exhaustive enum for early experiments - //via skip ahead in below enumeration algorithm - - //obtain hop compute costs per cell once - HashMap<Long, Double> computeCosts = new HashMap<Long, Double>(); - for( Long hopID : R ) - rGetComputeCosts(memo._hopRefs.get(hopID), partition, computeCosts); - - //scan linearized search space, w/ skips for branch and bound pruning - int len = (int)Math.pow(2, M.size()); - boolean[] bestPlan = null; - double bestC = Double.MAX_VALUE; - - for( int i=0; i<len; i++ ) { - //construct assignment - boolean[] plan = createAssignment(M.size(), i); - - //cost assignment on hops - double C = getPlanCost(memo, partition, R, M, plan, computeCosts); - if( LOG.isTraceEnabled() ) - LOG.trace("Enum: "+Arrays.toString(plan)+" -> "+C); - - //cost comparisons - if( bestPlan == null || C < bestC ) { - bestC = C; - bestPlan = plan; - if( LOG.isTraceEnabled() ) - LOG.trace("Enum: Found new best plan."); - } - } - - //prune memo table wrt best plan and select plans - HashSet<Long> visited = new HashSet<Long>(); - for( Long hopID : R ) - rPruneSuboptimalPlans(memo, memo._hopRefs.get(hopID), - visited, partition, M, bestPlan); - HashSet<Long> visited2 = new HashSet<Long>(); - for( Long hopID : R ) - rPruneInvalidPlans(memo, memo._hopRefs.get(hopID), - visited2, partition, M, bestPlan); - - for( Long hopID : R ) - rSelectPlansFuseAll(memo, - memo._hopRefs.get(hopID), null, partition); - } - } - - private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) { - //consider all aggregations other than root operation - MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.RowTpl); - boolean ret = true; - for(int i=0; i<3; i++) - if( me.isPlanRef(i) ) - ret &= rIsRowTemplateWithoutAgg(memo, - current.getInput().get(i), visited); - return ret; - } - - private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) { - if( visited.contains(current.getHopID()) ) - return true; - - boolean ret = true; - MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.RowTpl); - for(int i=0; i<3; i++) - if( me.isPlanRef(i) ) - ret &= rIsRowTemplateWithoutAgg(memo, current.getInput().get(i), visited); - ret &= !(current instanceof AggUnaryOp || current instanceof AggBinaryOp); - - visited.add(current.getHopID()); - return ret; - } - - private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) { - //memoization (not via hops because in middle of dag) - if( visited.contains(current.getHopID()) ) - return; - - //remove memo table entries if necessary - long hopID = current.getHopID(); - if( partition.contains(hopID) && memo.contains(hopID) ) { - Iterator<MemoTableEntry> iter = memo.get(hopID).iterator(); - while( iter.hasNext() ) { - MemoTableEntry me = iter.next(); - if( !hasNoRefToMaterialization(me, M, plan) && me.type!=TemplateType.OuterProdTpl ){ - iter.remove(); - if( LOG.isTraceEnabled() ) - LOG.trace("Removed memo table entry: "+me); - } - } - } - - //process children recursively - for( Hop c : current.getInput() ) - rPruneSuboptimalPlans(memo, c, visited, partition, M, plan); - - visited.add(current.getHopID()); - } - - private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) { - //memoization (not via hops because in middle of dag) - if( visited.contains(current.getHopID()) ) - return; - - //process children recursively - for( Hop c : current.getInput() ) - rPruneInvalidPlans(memo, c, visited, partition, M, plan); - - //find invalid row aggregate leaf nodes (see TemplateRow.open) w/o matrix inputs, - //i.e., plans that become invalid after the previous pruning step - long hopID = current.getHopID(); - if( partition.contains(hopID) && memo.contains(hopID, TemplateType.RowTpl) ) { - for( MemoTableEntry me : memo.get(hopID) ) { - if( me.type==TemplateType.RowTpl ) { - //convert leaf node with pure vector inputs - if( !me.hasPlanRef() && !TemplateUtils.hasMatrixInput(current) ) { - me.type = TemplateType.CellTpl; - if( LOG.isTraceEnabled() ) - LOG.trace("Converted leaf memo table entry from row to cell: "+me); - } - - //convert inner node without row template input - if( me.hasPlanRef() && !ROW_TPL.open(current) ) { - boolean hasRowInput = false; - for( int i=0; i<3; i++ ) - if( me.isPlanRef(i) ) - hasRowInput |= memo.contains(me.input(i), TemplateType.RowTpl); - if( !hasRowInput ) { - me.type = TemplateType.CellTpl; - if( LOG.isTraceEnabled() ) - LOG.trace("Converted inner memo table entry from row to cell: "+me); - } - } - - } - } - } - - visited.add(current.getHopID()); - } - - private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition) - { - if( isVisited(current.getHopID(), currentType) - || !partition.contains(current.getHopID()) ) - return; - - //step 1: prune subsumed plans of same type - if( memo.contains(current.getHopID()) ) { - HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>(); - List<MemoTableEntry> hopP = memo.get(current.getHopID()); - for( MemoTableEntry e1 : hopP ) - for( MemoTableEntry e2 : hopP ) - if( e1 != e2 && e1.subsumes(e2) ) - rmSet.add(e2); - memo.remove(current, rmSet); - } - - //step 2: select plan for current path - MemoTableEntry best = null; - if( memo.contains(current.getHopID()) ) { - if( currentType == null ) { - best = memo.get(current.getHopID()).stream() - .filter(p -> isValid(p, current)) - .min(new BasicPlanComparator()).orElse(null); - } - else { - best = memo.get(current.getHopID()).stream() - .filter(p -> p.type==currentType || p.type==TemplateType.CellTpl) - .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs())) - .orElse(null); - } - addBestPlan(current.getHopID(), best); - } - - //step 3: recursively process children - for( int i=0; i< current.getInput().size(); i++ ) { - TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null; - rSelectPlansFuseAll(memo, current.getInput().get(i), pref, partition); - } - - setVisited(current.getHopID(), currentType); - } - - private static boolean[] createAssignment(int len, int pos) { - boolean[] ret = new boolean[len]; - int tmp = pos; - for( int i=0; i<len; i++ ) { - ret[i] = (tmp < (int)Math.pow(2, len-i-1)); - tmp %= Math.pow(2, len-i-1); - } - return ret; - } - - ///////////////////////////////////////////////////////// - // Cost model fused operators w/ materialization points - ////////// - - private static double getPlanCost(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, - ArrayList<Long> M, boolean[] plan, HashMap<Long, Double> computeCosts) - { - //high level heuristic: every hop or fused operator has the following cost: - //WRITE + max(COMPUTE, READ), where WRITE costs are given by the output size, - //READ costs by the input sizes, and COMPUTE by operation specific FLOP - //counts times number of cells of main input, disregarding sparsity for now. - - HashSet<Pair<Long,Long>> visited = new HashSet<Pair<Long,Long>>(); - double costs = 0; - for( Long hopID : R ) - costs += rGetPlanCosts(memo, memo._hopRefs.get(hopID), - visited, partition, M, plan, computeCosts, null, null); - return costs; - } - - private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<Pair<Long,Long>> visited, HashSet<Long> partition, - ArrayList<Long> M, boolean[] plan, HashMap<Long, Double> computeCosts, CostVector costsCurrent, TemplateType currentType) - { - //memoization per hop id and cost vector to account for redundant - //computation without double counting materialized results or compute - //costs of complex operation DAGs within a single fused operator - Pair<Long,Long> tag = Pair.of(current.getHopID(), - (costsCurrent==null)?0:costsCurrent.ID); - if( visited.contains(tag) ) - return 0; - visited.add(tag); - - //open template if necessary, including memoization - //under awareness of current plan choice - MemoTableEntry best = null; - boolean opened = false; - if( memo.contains(current.getHopID()) ) { - if( currentType == null ) { - best = memo.get(current.getHopID()).stream() - .filter(p -> isValid(p, current)) - .filter(p -> hasNoRefToMaterialization(p, M, plan)) - .min(new BasicPlanComparator()).orElse(null); - opened = true; - } - else { - best = memo.get(current.getHopID()).stream() - .filter(p -> p.type==currentType || p.type==TemplateType.CellTpl) - .filter(p -> hasNoRefToMaterialization(p, M, plan)) - .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs())) - .orElse(null); - } - } - - //create new cost vector if opened, initialized with write costs - CostVector costVect = !opened ? costsCurrent : - new CostVector(Math.max(current.getDim1(),1)*Math.max(current.getDim2(),1)); - - //add compute costs of current operator to costs vector - if( partition.contains(current.getHopID()) ) - costVect.computeCosts += computeCosts.get(current.getHopID()); - - //process children recursively - double costs = 0; - for( int i=0; i< current.getInput().size(); i++ ) { - Hop c = current.getInput().get(i); - if( best!=null && best.isPlanRef(i) ) - costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, costVect, best.type); - else if( best!=null && isImplicitlyFused(current, i, best.type) ) - costVect.addInputSize(c.getInput().get(0).getHopID(), Math.max(c.getDim1(),1)*Math.max(c.getDim2(),1)); - else { //include children and I/O costs - costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, null, null); - if( costVect != null && c.getDataType().isMatrix() ) - costVect.addInputSize(c.getHopID(), Math.max(c.getDim1(),1)*Math.max(c.getDim2(),1)); - } - } - - //add costs for opened fused operator - if( partition.contains(current.getHopID()) ) { - if( opened ) { - if( LOG.isTraceEnabled() ) - LOG.trace("Cost vector for fused operator (hop "+current.getHopID()+"): "+costVect); - costs += costVect.outSize * 8 / WRITE_BANDWIDTH; //time for output write - costs += Math.max( - costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH, - costVect.getSumInputSizes() * 8 / READ_BANDWIDTH); - } - //add costs for non-partition read in the middle of fused operator - else if( hasNonPartitionConsumer(current, partition) ) { - costs += rGetPlanCosts(memo, current, visited, partition, M, plan, computeCosts, null, null); - } - } - - //sanity check non-negative costs - if( costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs) ) - throw new RuntimeException("Wrong cost estimate: "+costs); - - return costs; - } - - private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) - { - if( computeCosts.containsKey(current.getHopID()) ) - return; - - //recursively process children - for( Hop c : current.getInput() ) - rGetComputeCosts(c, partition, computeCosts); - - //get costs for given hop - double costs = 1; - if( current instanceof UnaryOp ) { - switch( ((UnaryOp)current).getOp() ) { - case ABS: - case ROUND: - case CEIL: - case FLOOR: - case SIGN: - case SELP: costs = 1; break; - case SPROP: - case SQRT: costs = 2; break; - case EXP: costs = 18; break; - case SIGMOID: costs = 21; break; - case LOG: - case LOG_NZ: costs = 32; break; - case NCOL: - case NROW: - case PRINT: - case CAST_AS_BOOLEAN: - case CAST_AS_DOUBLE: - case CAST_AS_INT: - case CAST_AS_MATRIX: - case CAST_AS_SCALAR: costs = 1; break; - case SIN: costs = 18; break; - case COS: costs = 22; break; - case TAN: costs = 42; break; - case ASIN: costs = 93; break; - case ACOS: costs = 103; break; - case ATAN: costs = 40; break; - case CUMSUM: - case CUMMIN: - case CUMMAX: - case CUMPROD: costs = 1; break; - default: - LOG.warn("Cost model not " - + "implemented yet for: "+((UnaryOp)current).getOp()); - } - } - else if( current instanceof BinaryOp ) { - switch( ((BinaryOp)current).getOp() ) { - case MULT: - case PLUS: - case MINUS: - case MIN: - case MAX: - case AND: - case OR: - case EQUAL: - case NOTEQUAL: - case LESS: - case LESSEQUAL: - case GREATER: - case GREATEREQUAL: - case CBIND: - case RBIND: costs = 1; break; - case INTDIV: costs = 6; break; - case MODULUS: costs = 8; break; - case DIV: costs = 22; break; - case LOG: - case LOG_NZ: costs = 32; break; - case POW: costs = (HopRewriteUtils.isLiteralOfValue( - current.getInput().get(1), 2) ? 1 : 16); break; - case MINUS_NZ: - case MINUS1_MULT: costs = 2; break; - case CENTRALMOMENT: - int type = (int) (current.getInput().get(1) instanceof LiteralOp ? - HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2); - switch( type ) { - case 0: costs = 1; break; //count - case 1: costs = 8; break; //mean - case 2: costs = 16; break; //cm2 - case 3: costs = 31; break; //cm3 - case 4: costs = 51; break; //cm4 - case 5: costs = 16; break; //variance - } - break; - case COVARIANCE: costs = 23; break; - default: - LOG.warn("Cost model not " - + "implemented yet for: "+((BinaryOp)current).getOp()); - } - } - else if( current instanceof TernaryOp ) { - switch( ((TernaryOp)current).getOp() ) { - case PLUS_MULT: - case MINUS_MULT: costs = 2; break; - case CTABLE: costs = 3; break; - case CENTRALMOMENT: - int type = (int) (current.getInput().get(1) instanceof LiteralOp ? - HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2); - switch( type ) { - case 0: costs = 2; break; //count - case 1: costs = 9; break; //mean - case 2: costs = 17; break; //cm2 - case 3: costs = 32; break; //cm3 - case 4: costs = 52; break; //cm4 - case 5: costs = 17; break; //variance - } - break; - case COVARIANCE: costs = 23; break; - default: - LOG.warn("Cost model not " - + "implemented yet for: "+((TernaryOp)current).getOp()); - } - } - else if( current instanceof ParameterizedBuiltinOp ) { - costs = 1; - } - else if( current instanceof IndexingOp ) { - costs = 1; - } - else if( current instanceof ReorgOp ) { - costs = 1; - } - else if( current instanceof AggBinaryOp ) { - costs = 2; //matrix vector - } - else if( current instanceof AggUnaryOp) { - switch(((AggUnaryOp)current).getOp()) { - case SUM: costs = 4; break; - case SUM_SQ: costs = 5; break; - case MIN: - case MAX: costs = 1; break; - default: - LOG.warn("Cost model not " - + "implemented yet for: "+((AggUnaryOp)current).getOp()); - } - } - - computeCosts.put(current.getHopID(), costs); - } - - private static boolean hasNoRefToMaterialization(MemoTableEntry me, ArrayList<Long> M, boolean[] plan) { - boolean ret = true; - for( int i=0; ret && i<3; i++ ) - ret &= (!M.contains(me.input(i)) || !plan[M.indexOf(me.input(i))]); - return ret; - } - - private static boolean hasNonPartitionConsumer(Hop hop, HashSet<Long> partition) { - boolean ret = false; - for( Hop p : hop.getParent() ) - ret |= !partition.contains(p.getHopID()); - return ret; - } - - private static boolean isImplicitlyFused(Hop hop, int index, TemplateType type) { - return type == TemplateType.RowTpl - && HopRewriteUtils.isMatrixMultiply(hop) && index==0 - && HopRewriteUtils.isTransposeOperation(hop.getInput().get(index)); - } - - private static class CostVector { - public final long ID; - public final double outSize; - public double computeCosts = 0; - public final HashMap<Long, Double> inSizes = new HashMap<Long, Double>(); - - public CostVector(double outputSize) { - ID = COST_ID.getNextID(); - outSize = outputSize; - } - public void addInputSize(long hopID, double inputSize) { - //ensures that input sizes are not double counted - inSizes.put(hopID, inputSize); - } - public double getSumInputSizes() { - return inSizes.values().stream() - .mapToDouble(d -> d.doubleValue()).sum(); - } - public double getMaxInputSize() { - return inSizes.values().stream() - .mapToDouble(d -> d.doubleValue()).max().orElse(0); - } - @Override - public String toString() { - return "["+outSize+", "+computeCosts+", {" - +Arrays.toString(inSizes.keySet().toArray(new Long[0]))+", " - +Arrays.toString(inSizes.values().toArray(new Double[0]))+"}]"; - } - } - - private static class AggregateInfo { - public final HashMap<Long,Hop> _aggregates; - public final HashSet<Long> _inputAggs = new HashSet<Long>(); - public final HashSet<Long> _fusedInputs = new HashSet<Long>(); - public AggregateInfo(Hop aggregate) { - _aggregates = new HashMap<Long, Hop>(); - _aggregates.put(aggregate.getHopID(), aggregate); - } - public void addInputAggregate(long hopID) { - _inputAggs.add(hopID); - } - public void addFusedInput(long hopID) { - _fusedInputs.add(hopID); - } - public boolean isMergable(AggregateInfo that) { - //check independence - boolean ret = _aggregates.size()<3 - && _aggregates.size()+that._aggregates.size()<=3; - for( Long hopID : that._aggregates.keySet() ) - ret &= !_inputAggs.contains(hopID); - for( Long hopID : _aggregates.keySet() ) - ret &= !that._inputAggs.contains(hopID); - //check partial shared reads - ret &= !CollectionUtils.intersection( - _fusedInputs, that._fusedInputs).isEmpty(); - //check consistent sizes (result correctness) - Hop in1 = _aggregates.values().iterator().next(); - Hop in2 = that._aggregates.values().iterator().next(); - return ret && HopRewriteUtils.isEqualSize( - in1.getInput().get(HopRewriteUtils.isMatrixMultiply(in1)?1:0), - in2.getInput().get(HopRewriteUtils.isMatrixMultiply(in2)?1:0)); - } - public AggregateInfo merge(AggregateInfo that) { - _aggregates.putAll(that._aggregates); - _inputAggs.addAll(that._inputAggs); - _fusedInputs.addAll(that._fusedInputs); - return this; - } - @Override - public String toString() { - return "["+Arrays.toString(_aggregates.keySet().toArray(new Long[0]))+": " - +"{"+Arrays.toString(_inputAggs.toArray(new Long[0]))+"}," - +"{"+Arrays.toString(_fusedInputs.toArray(new Long[0]))+"}]"; - } - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseNoRedundancy.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseNoRedundancy.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseNoRedundancy.java deleted file mode 100644 index fa8b25a..0000000 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseNoRedundancy.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.hops.codegen.template; - -import java.util.ArrayList; -import java.util.Comparator; -import java.util.Map.Entry; -import java.util.HashSet; -import java.util.List; - -import org.apache.sysml.hops.Hop; -import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; -import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; - -/** - * This plan selection heuristic aims for fusion without any redundant - * computation, which, however, potentially leads to more materialized - * intermediates than the fuse all heuristic. - * <p> - * NOTE: This heuristic is essentially the same as FuseAll, except that - * any plans that refer to a hop with multiple consumers are removed in - * a pre-processing step. - * - */ -public class PlanSelectionFuseNoRedundancy extends PlanSelection -{ - @Override - public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) { - //pruning and collection pass - for( Hop hop : roots ) - rSelectPlans(memo, hop, null); - - //take all distinct best plans - for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() ) - memo.setDistinct(e.getKey(), e.getValue()); - } - - private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType) - { - if( isVisited(current.getHopID(), currentType) ) - return; - - //step 0: remove plans that refer to a common partial plan - if( memo.contains(current.getHopID()) ) { - HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>(); - List<MemoTableEntry> hopP = memo.get(current.getHopID()); - for( MemoTableEntry e1 : hopP ) - for( int i=0; i<3; i++ ) - if( e1.isPlanRef(i) && current.getInput().get(i).getParent().size()>1 ) - rmSet.add(e1); //remove references to hops w/ multiple consumers - memo.remove(current, rmSet); - } - - //step 1: prune subsumed plans of same type - if( memo.contains(current.getHopID()) ) { - HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>(); - List<MemoTableEntry> hopP = memo.get(current.getHopID()); - for( MemoTableEntry e1 : hopP ) - for( MemoTableEntry e2 : hopP ) - if( e1 != e2 && e1.subsumes(e2) ) - rmSet.add(e2); - memo.remove(current, rmSet); - } - - //step 2: select plan for current path - MemoTableEntry best = null; - if( memo.contains(current.getHopID()) ) { - if( currentType == null ) { - best = memo.get(current.getHopID()).stream() - .filter(p -> isValid(p, current)) - .min(new BasicPlanComparator()).orElse(null); - } - else { - best = memo.get(current.getHopID()).stream() - .filter(p -> p.type==currentType || p.type==TemplateType.CellTpl) - .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs())) - .orElse(null); - } - addBestPlan(current.getHopID(), best); - } - - //step 3: recursively process children - for( int i=0; i< current.getInput().size(); i++ ) { - TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null; - rSelectPlans(memo, current.getInput().get(i), pref); - } - - setVisited(current.getHopID(), currentType); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java index a4f2b91..b42eecf 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java @@ -19,6 +19,7 @@ package org.apache.sysml.hops.codegen.template; +import org.apache.commons.lang3.ArrayUtils; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; import org.apache.sysml.runtime.matrix.data.Pair; @@ -28,13 +29,16 @@ public abstract class TemplateBase { public enum TemplateType { //ordering specifies type preferences - MultiAggTpl, - OuterProdTpl, - RowTpl, - CellTpl; + MAGG, + OUTER, + ROW, + CELL; public int getRank() { return this.ordinal(); } + public boolean isIn(TemplateType... types) { + return ArrayUtils.contains(types, this); + } } public enum CloseType { http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index 68f7412..b781fd8 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -58,11 +58,11 @@ public class TemplateCell extends TemplateBase new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX}; public TemplateCell() { - super(TemplateType.CellTpl); + super(TemplateType.CELL); } public TemplateCell(boolean closed) { - super(TemplateType.CellTpl, closed); + super(TemplateType.CELL, closed); } public TemplateCell(TemplateType type, boolean closed) { @@ -149,10 +149,10 @@ public class TemplateCell extends TemplateBase if( tmp.containsKey(hop.getHopID()) ) return; - MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CellTpl); + MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CELL); //recursively process required childs - if( me!=null && (me.type == TemplateType.RowTpl || me.type == TemplateType.OuterProdTpl) ) { + if( me!=null && me.type.isIn(TemplateType.ROW, TemplateType.OUTER) ) { CNodeData cdata = TemplateUtils.createCNodeData(hop, compileLiterals); tmp.put(hop.getHopID(), cdata); inHops.add(hop); @@ -161,9 +161,9 @@ public class TemplateCell extends TemplateBase for( int i=0; i<hop.getInput().size(); i++ ) { Hop c = hop.getInput().get(i); if( me!=null && me.isPlanRef(i) && !(c instanceof DataOp) - && (me.type!=TemplateType.MultiAggTpl || memo.contains(c.getHopID(), TemplateType.CellTpl))) + && (me.type!=TemplateType.MAGG || memo.contains(c.getHopID(), TemplateType.CELL))) rConstructCplan(c, memo, tmp, inHops, compileLiterals); - else if( me!=null && (me.type==TemplateType.MultiAggTpl || me.type==TemplateType.CellTpl) + else if( me!=null && (me.type==TemplateType.MAGG || me.type==TemplateType.CELL) && HopRewriteUtils.isMatrixMultiply(hop) && i==0 ) //skip transpose rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals); else { http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java index 2d53e7c..bc51cf0 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java @@ -37,11 +37,11 @@ import org.apache.sysml.runtime.matrix.data.Pair; public class TemplateMultiAgg extends TemplateCell { public TemplateMultiAgg() { - super(TemplateType.MultiAggTpl, false); + super(TemplateType.MAGG, false); } public TemplateMultiAgg(boolean closed) { - super(TemplateType.MultiAggTpl, closed); + super(TemplateType.MAGG, closed); } @Override @@ -69,7 +69,7 @@ public class TemplateMultiAgg extends TemplateCell public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) { //get all root nodes for multi aggregation - MemoTableEntry multiAgg = memo.getBest(hop.getHopID(), TemplateType.MultiAggTpl); + MemoTableEntry multiAgg = memo.getBest(hop.getHopID(), TemplateType.MAGG); ArrayList<Hop> roots = new ArrayList<Hop>(); for( int i=0; i<3; i++ ) if( multiAgg.isPlanRef(i) ) http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java index 9f5b191..f001a81 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java @@ -48,11 +48,11 @@ import org.apache.sysml.runtime.matrix.data.Pair; public class TemplateOuterProduct extends TemplateBase { public TemplateOuterProduct() { - super(TemplateType.OuterProdTpl); + super(TemplateType.OUTER); } public TemplateOuterProduct(boolean closed) { - super(TemplateType.OuterProdTpl, closed); + super(TemplateType.OUTER, closed); } @Override @@ -141,7 +141,7 @@ public class TemplateOuterProduct extends TemplateBase { return; //recursively process required childs - MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.OuterProdTpl); + MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.OUTER); for( int i=0; i<hop.getInput().size(); i++ ) { Hop c = hop.getInput().get(i); if( me.isPlanRef(i) ) @@ -220,7 +220,7 @@ public class TemplateOuterProduct extends TemplateBase { tmp.put(hop.getHopID(), out); } - protected static MemoTableEntry dropAlternativePlan(CPlanMemoTable memo, MemoTableEntry me1, MemoTableEntry me2) { + public static MemoTableEntry dropAlternativePlan(CPlanMemoTable memo, MemoTableEntry me1, MemoTableEntry me2) { //if there are two alternative sub plans with references to disjoint outer product plans //drop the one that would render the other invalid if( me1.countPlanRefs()==1 && me2.countPlanRefs()==1 @@ -229,8 +229,8 @@ public class TemplateOuterProduct extends TemplateBase { Hop c1 = memo._hopRefs.get(me1.input(me1.getPlanRefIndex())); Hop c2 = memo._hopRefs.get(me2.input(me2.getPlanRefIndex())); - if( memo.contains(c1.getHopID(), TemplateType.OuterProdTpl) - && memo.contains(c2.getHopID(), TemplateType.OuterProdTpl) ) + if( memo.contains(c1.getHopID(), TemplateType.OUTER) + && memo.contains(c2.getHopID(), TemplateType.OUTER) ) { if( HopRewriteUtils.isBinaryMatrixMatrixOperation(c1) && HopRewriteUtils.isBinary(c1, OpOp2.MULT, OpOp2.DIV) ) http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index 659b528..a3037ec 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -64,11 +64,11 @@ public class TemplateRow extends TemplateBase OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS, OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL}; public TemplateRow() { - super(TemplateType.RowTpl); + super(TemplateType.ROW); } public TemplateRow(boolean closed) { - super(TemplateType.RowTpl, closed); + super(TemplateType.ROW, closed); } @Override @@ -208,7 +208,7 @@ public class TemplateRow extends TemplateBase return; //recursively process required childs - MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.RowTpl); + MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.ROW); for( int i=0; i<hop.getInput().size(); i++ ) { Hop c = hop.getInput().get(i); if( me!=null && me.isPlanRef(i) ) http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index b461c5e..21061f2 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -149,10 +149,10 @@ public class TemplateUtils public static TemplateBase createTemplate(TemplateType type, boolean closed) { TemplateBase tpl = null; switch( type ) { - case CellTpl: tpl = new TemplateCell(closed); break; - case RowTpl: tpl = new TemplateRow(closed); break; - case MultiAggTpl: tpl = new TemplateMultiAgg(closed); break; - case OuterProdTpl: tpl = new TemplateOuterProduct(closed); break; + case CELL: tpl = new TemplateCell(closed); break; + case ROW: tpl = new TemplateRow(closed); break; + case MAGG: tpl = new TemplateMultiAgg(closed); break; + case OUTER: tpl = new TemplateOuterProduct(closed); break; } return tpl; } @@ -160,10 +160,10 @@ public class TemplateUtils public static TemplateBase[] createCompatibleTemplates(TemplateType type, boolean closed) { TemplateBase[] tpl = null; switch( type ) { - case CellTpl: tpl = new TemplateBase[]{new TemplateCell(closed), new TemplateRow(closed)}; break; - case RowTpl: tpl = new TemplateBase[]{new TemplateRow(closed)}; break; - case MultiAggTpl: tpl = new TemplateBase[]{new TemplateMultiAgg(closed)}; break; - case OuterProdTpl: tpl = new TemplateBase[]{new TemplateOuterProduct(closed)}; break; + case CELL: tpl = new TemplateBase[]{new TemplateCell(closed), new TemplateRow(closed)}; break; + case ROW: tpl = new TemplateBase[]{new TemplateRow(closed)}; break; + case MAGG: tpl = new TemplateBase[]{new TemplateMultiAgg(closed)}; break; + case OUTER: tpl = new TemplateBase[]{new TemplateOuterProduct(closed)}; break; } return tpl; } @@ -183,9 +183,9 @@ public class TemplateUtils public static RowType getRowType(Hop output, Hop... inputs) { Hop X = inputs[0]; Hop B1 = (inputs.length>1) ? inputs[1] : null; - if( X!=null && HopRewriteUtils.isEqualSize(output, X) ) + if( (X!=null && HopRewriteUtils.isEqualSize(output, X)) || X==null ) return RowType.NO_AGG; - else if( (B1 != null && output.getDim1()==X.getDim1() && output.getDim2()==B1.getDim2()) + else if( (B1!=null && output.getDim1()==X.getDim1() && output.getDim2()==B1.getDim2()) || (output instanceof IndexingOp && HopRewriteUtils.isColumnRangeIndexing((IndexingOp)output))) return RowType.NO_AGG_B1; else if( output.getDim1()==X.getDim1() && (output.getDim2()==1 @@ -372,7 +372,7 @@ public class TemplateUtils public static boolean hasCommonRowTemplateMatrixInput(Hop input1, Hop input2, CPlanMemoTable memo) { //if second input has no row template, it's always true - if( !memo.contains(input2.getHopID(), TemplateType.RowTpl) ) + if( !memo.contains(input2.getHopID(), TemplateType.ROW) ) return true; //check for common row template input long tmp1 = getRowTemplateMatrixInput(input1, memo); @@ -381,11 +381,11 @@ public class TemplateUtils } public static long getRowTemplateMatrixInput(Hop current, CPlanMemoTable memo) { - MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.RowTpl); + MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW); long ret = -1; for( int i=0; ret<0 && i<current.getInput().size(); i++ ) { Hop input = current.getInput().get(i); - if( me.isPlanRef(i) && memo.contains(input.getHopID(), TemplateType.RowTpl) ) + if( me.isPlanRef(i) && memo.contains(input.getHopID(), TemplateType.ROW) ) ret = getRowTemplateMatrixInput(input, memo); else if( !me.isPlanRef(i) && isMatrix(input) ) ret = input.getHopID(); http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java index 5b3b193..83ee425 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java @@ -171,6 +171,12 @@ public class LibSpoofPrimitives if( a == null ) return; System.arraycopy(a, 0, c, ci, len); } + + public static void vectWrite(boolean[] a, boolean[] c, int[] aix) { + if( a == null ) return; + for( int i=0; i<aix.length; i++ ) + c[aix[i]] = a[i]; + } // custom vector sums, mins, maxs http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java b/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java index 08032af..25cafe7 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java @@ -1060,7 +1060,10 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl @Override public Long call() throws DMLRuntimeException { - _c = (_type==CellType.COL_AGG)? new MatrixBlock(1,_clen, false) : _c; + if( _type==CellType.COL_AGG ) { + _c = new MatrixBlock(1,_clen, false); + _c.allocateDenseBlock(); + } if( _a instanceof CompressedMatrixBlock ) return executeCompressed((CompressedMatrixBlock)_a, _b, _scalars, _c, _rlen, _clen, _safe, _rl, _ru); else if( !_a.isInSparseFormat() ) http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java index dc6baff..abe2e78 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java @@ -26,6 +26,7 @@ import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.stream.IntStream; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.compress.CompressedMatrixBlock; @@ -121,7 +122,8 @@ public abstract class SpoofRowwise extends SpoofOperator //result allocation and preparations final int m = inputs.get(0).getNumRows(); final int n = inputs.get(0).getNumColumns(); - final int n2 = _type.isRowTypeB1() ? inputs.get(1).getNumColumns() : -1; + final int n2 = _type.isRowTypeB1() ? + getMinColsMatrixSideInputs(inputs) : -1; if( !aggIncr || !out.isAllocated() ) allocateOutputMatrix(m, n, n2, out); double[] c = out.getDenseBlock(); @@ -168,7 +170,8 @@ public abstract class SpoofRowwise extends SpoofOperator //result allocation and preparations final int m = inputs.get(0).getNumRows(); final int n = inputs.get(0).getNumColumns(); - final int n2 = _type.isRowTypeB1() ? inputs.get(1).getNumColumns() : -1; + final int n2 = _type.isRowTypeB1() ? + getMinColsMatrixSideInputs(inputs) : -1; allocateOutputMatrix(m, n, n2, out); //input preparation @@ -214,6 +217,14 @@ public abstract class SpoofRowwise extends SpoofOperator } } + private static int getMinColsMatrixSideInputs(ArrayList<MatrixBlock> inputs) { + //For B1 types, get the output number of columns as the minimum + //number of columns of side input matrices other than vectors. + return IntStream.range(1, inputs.size()) + .map(i -> inputs.get(i).getNumColumns()) + .filter(ncol -> ncol > 1).min().orElse(1); + } + private void allocateOutputMatrix(int m, int n, int n2, MatrixBlock out) { switch( _type ) { case NO_AGG: out.reset(m, n, false); break; http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/utils/Statistics.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java index e7e8ecc..8891a6c 100644 --- a/src/main/java/org/apache/sysml/utils/Statistics.java +++ b/src/main/java/org/apache/sysml/utils/Statistics.java @@ -76,6 +76,7 @@ public class Statistics private static final LongAdder codegenCompileTime = new LongAdder(); //in nano private static final LongAdder codegenClassCompileTime = new LongAdder(); //in nano private static final LongAdder codegenHopCompile = new LongAdder(); //count + private static final LongAdder codegenFPlanCompile = new LongAdder(); //count private static final LongAdder codegenCPlanCompile = new LongAdder(); //count private static final LongAdder codegenClassCompile = new LongAdder(); //count private static final LongAdder codegenPlanCacheHits = new LongAdder(); //count @@ -256,6 +257,10 @@ public class Statistics codegenCPlanCompile.add(delta); } + public static void incrementCodegenFPlanCompile(long delta) { + codegenFPlanCompile.add(delta); + } + public static void incrementCodegenClassCompile() { codegenClassCompile.increment(); } @@ -284,6 +289,10 @@ public class Statistics return codegenCPlanCompile.longValue(); } + public static long getCodegenFPlanCompile() { + return codegenFPlanCompile.longValue(); + } + public static long getCodegenClassCompile() { return codegenClassCompile.longValue(); } @@ -376,6 +385,13 @@ public class Statistics funRecompiles.reset(); funRecompileTime.reset(); + codegenHopCompile.reset(); + codegenFPlanCompile.reset(); + codegenCPlanCompile.reset(); + codegenClassCompile.reset(); + codegenCompileTime.reset(); + codegenClassCompileTime.reset(); + parforOptCount = 0; parforOptTime = 0; parforInitTime = 0; @@ -757,7 +773,8 @@ public class Statistics sb.append("Functions recompile time:\t" + String.format("%.3f", ((double)getFunRecompileTime())/1000000000) + " sec.\n"); } if( ConfigurationManager.isCodegenEnabled() ) { - sb.append("Codegen compile (DAG, CP, JC):\t" + getCodegenDAGCompile() + "/" + getCodegenCPlanCompile() + "/" + getCodegenClassCompile() + ".\n"); + sb.append("Codegen compile (DAG,FP,CP,JC):\t" + getCodegenDAGCompile() + "/" + getCodegenFPlanCompile() + + "/" + getCodegenCPlanCompile() + "/" + getCodegenClassCompile() + ".\n"); sb.append("Codegen compile times (DAG,JC):\t" + String.format("%.3f", (double)getCodegenCompileTime()/1000000000) + "/" + String.format("%.3f", (double)getCodegenClassCompileTime()/1000000000) + " sec.\n"); sb.append("Codegen plan cache hits:\t" + getCodegenPlanCacheHits() + "/" + getCodegenPlanCacheTotal() + ".\n"); http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/test/java/org/apache/sysml/test/integration/functions/codegen/MiscPatternTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/MiscPatternTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/MiscPatternTest.java new file mode 100644 index 0000000..75c28eb --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/MiscPatternTest.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.codegen; + +import java.io.File; +import java.util.HashMap; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +public class MiscPatternTest extends AutomatedTestBase +{ + private static final String TEST_NAME = "miscPattern"; + private static final String TEST_NAME1 = TEST_NAME+"1"; //Y + (X * U%*%t(V)) overlapping cell-outer + private static final String TEST_NAME2 = TEST_NAME+"2"; //multi-agg w/ large common subexpression + + private static final String TEST_DIR = "functions/codegen/"; + private static final String TEST_CLASS_DIR = TEST_DIR + MiscPatternTest.class.getSimpleName() + "/"; + private final static String TEST_CONF = "SystemML-config-codegen.xml"; + private final static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, TEST_CONF); + + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + for(int i=1; i<=2; i++) + addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) ); + } + + @Test + public void testCodegenMiscRewrite1CP() { + testCodegenIntegration( TEST_NAME1, true, ExecType.CP ); + } + + @Test + public void testCodegenMisc1CP() { + testCodegenIntegration( TEST_NAME1, false, ExecType.CP ); + } + + @Test + public void testCodegenMisc1SP() { + testCodegenIntegration( TEST_NAME1, false, ExecType.SPARK ); + } + + @Test + public void testCodegenMiscRewrite2CP() { + testCodegenIntegration( TEST_NAME2, true, ExecType.CP ); + } + + @Test + public void testCodegenMisc2CP() { + testCodegenIntegration( TEST_NAME2, false, ExecType.CP ); + } + + @Test + public void testCodegenMisc2SP() { + testCodegenIntegration( TEST_NAME2, false, ExecType.SPARK ); + } + + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) + { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + RUNTIME_PLATFORM platformOld = rtplatform; + switch( instType ) { + case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK || rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + try + { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{"-explain", "recompile_runtime", "-stats", "-args", output("S") }; + + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("S"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("S"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + Assert.assertTrue(heavyHittersContainsSubString("spoof") + || heavyHittersContainsSubString("sp_spoof")); + + //ensure correct optimizer decisions + if( testname.equals(TEST_NAME1) ) + Assert.assertTrue(!heavyHittersContainsSubString("spoofCell") + && !heavyHittersContainsSubString("sp_spoofCell")); + else if( testname.equals(TEST_NAME2) ) + Assert.assertTrue(!heavyHittersContainsSubString("spoof", 2) + && !heavyHittersContainsSubString("sp_spoof", 2)); + } + finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true; + OptimizerUtils.ALLOW_OPERATOR_FUSION = true; + } + } + + /** + * Override default configuration with custom test configuration to ensure + * scratch space and local temporary directory locations are also updated. + */ + @Override + protected File getConfigTemplateFile() { + // Instrumentation in this test's output log to show custom configuration file used for template. + System.out.println("This test case overrides default configuration with " + TEST_CONF_FILE.getPath()); + return TEST_CONF_FILE; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/test/scripts/functions/codegen/miscPattern1.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/miscPattern1.R b/src/test/scripts/functions/codegen/miscPattern1.R new file mode 100644 index 0000000..d393a2c --- /dev/null +++ b/src/test/scripts/functions/codegen/miscPattern1.R @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(1, 1100, 2200); +Y = matrix(2, 1100, 2200); +U = matrix(3, 1100, 10); +V = matrix(4, 2200, 10) +X[4:900,3:1000] = matrix(0, 897, 998); + +R1 = Y + (X * U%*%t(V)); +R2 = as.matrix(sum(R1)); + +writeMM(as(R2, "CsparseMatrix"), paste(args[1], "S", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/test/scripts/functions/codegen/miscPattern1.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/miscPattern1.dml b/src/test/scripts/functions/codegen/miscPattern1.dml new file mode 100644 index 0000000..6c8dbd9 --- /dev/null +++ b/src/test/scripts/functions/codegen/miscPattern1.dml @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(1, 1100, 2200); +Y = matrix(2, 1100, 2200); +U = matrix(3, 1100, 10); +V = matrix(4, 2200, 10) +X[4:900,3:1000] = matrix(0, 897, 998); + +if(1==1){} + +R1 = Y + (X * U%*%t(V)); + +if(1==1){} + +R2 = as.matrix(sum(R1)); +write(R2, $1) http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/test/scripts/functions/codegen/miscPattern2.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/miscPattern2.R b/src/test/scripts/functions/codegen/miscPattern2.R new file mode 100644 index 0000000..e8caf2a --- /dev/null +++ b/src/test/scripts/functions/codegen/miscPattern2.R @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(1, 2340, 7); +Y = matrix(2, 2340, 7); + +Z = abs(exp((log(exp(X + Y))+7)/7)); + +R1 = sum((Z+1)^2/10000); +R2 = sum((Z+2)^3/10000); + +R = as.matrix(R1+R2); + +writeMM(as(R, "CsparseMatrix"), paste(args[1], "S", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/test/scripts/functions/codegen/miscPattern2.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/miscPattern2.dml b/src/test/scripts/functions/codegen/miscPattern2.dml new file mode 100644 index 0000000..8c2a35b --- /dev/null +++ b/src/test/scripts/functions/codegen/miscPattern2.dml @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(1, 2340, 7); +Y = matrix(2, 2340, 7); + +if(1==1){} + +Z = abs(exp((log(exp(X + Y))+7)/7)); + +R1 = sum((Z+1)^2/10000); +R2 = sum((Z+2)^3/10000); + +R = as.matrix(R1+R2); +write(R, $1) http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/test_suites/java/org/apache/sysml/test/integration/functions/codegen/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/codegen/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/codegen/ZPackageSuite.java index 63be419..a9310b3 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/codegen/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/codegen/ZPackageSuite.java @@ -42,6 +42,7 @@ import org.junit.runners.Suite; CPlanComparisonTest.class, CPlanVectorPrimitivesTest.class, DAGCellwiseTmplTest.class, + MiscPatternTest.class, MultiAggTmplTest.class, OuterProdTmplTest.class, RowAggTmplTest.class,