Repository: incubator-systemml Updated Branches: refs/heads/master 30f72e83f -> 2893e1aed
[SYSTEMML-1374] Improved candidate exploration of code generation plans This patch improves the codegen candidate exploration algorithm by (1) better memoization (which now also includes unsupported operators) and (2) a simplified creation of merge plans (which now also applies to ternary operators). Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/9cbaf85a Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/9cbaf85a Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/9cbaf85a Branch: refs/heads/master Commit: 9cbaf85ab1389cde9fb79f58e29c0adb6044c493 Parents: 30f72e8 Author: Matthias Boehm <mboe...@gmail.com> Authored: Sat Mar 18 21:39:28 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Sat Mar 18 21:39:28 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 60 +++++++++----------- .../hops/codegen/template/CPlanMemoTable.java | 34 +++++++++++ 2 files changed, 60 insertions(+), 34 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9cbaf85a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index 54e67b6..6479917 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -44,6 +44,7 @@ import org.apache.sysml.hops.codegen.template.BaseTpl.CloseType; import org.apache.sysml.hops.codegen.template.BaseTpl.TemplateType; import org.apache.sysml.hops.codegen.template.CPlanMemoTable; import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; +import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntrySet; import org.apache.sysml.hops.codegen.template.TemplateUtils; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.OpOp1; @@ -318,68 +319,59 @@ public class SpoofCompiler throws DMLException { //top-down memoization of processed dag nodes - if( memo.contains(hop.getHopID()) ) + if( memo.contains(hop.getHopID()) || memo.containsHop(hop) ) return; - //recursively process child nodes + //recursive candidate exploration for( Hop c : hop.getInput() ) rExploreCPlans(c, memo, compileLiterals); - //generate new node plans + //open initial operator plans, if possible for( BaseTpl tpl : TemplateUtils.TEMPLATES ) if( tpl.open(hop) ) memo.add(hop, tpl.getType()); + //fuse and merge operator plans for( Hop c : hop.getInput() ) { if( memo.contains(c.getHopID()) ) for( MemoTableEntry me : memo.get(c.getHopID()) ) { BaseTpl tpl = TemplateUtils.createTemplate(me.type, me.closed); - if( tpl.fuse(hop, c) ) - genExplorePlans(tpl, hop, memo, hop.getInput(), c); + if( tpl.fuse(hop, c) ) { + int pos = hop.getInput().indexOf(c); + MemoTableEntrySet P = new MemoTableEntrySet(tpl.getType(), pos, c.getHopID(), tpl.isClosed()); + for(int k=0; k<hop.getInput().size(); k++) + if( k != pos ) { + Hop input2 = hop.getInput().get(k); + if( memo.contains(input2.getHopID()) && !memo.get(input2.getHopID()).get(0).closed + && memo.get(input2.getHopID()).get(0).type == TemplateType.CellTpl && tpl.merge(hop, input2) ) + P.crossProduct(k, -1L, input2.getHopID()); + else + P.crossProduct(k, -1L); + } + memo.addAll(hop, P); + } } } //prune subsumed / redundant plans memo.pruneRedundant(hop.getHopID()); - //check if templates require close + //close operator plans, if required if( memo.contains(hop.getHopID()) ) { Iterator<MemoTableEntry> iter = memo.get(hop.getHopID()).iterator(); while( iter.hasNext() ) { MemoTableEntry me = iter.next(); BaseTpl tpl = TemplateUtils.createTemplate(me.type); CloseType ccode = tpl.close(hop); - if( ccode != CloseType.OPEN ) { + if( ccode == CloseType.CLOSED_INVALID ) + iter.remove(); + else if( ccode == CloseType.CLOSED_VALID ) me.closed = true; - if( ccode == CloseType.CLOSED_INVALID ) - iter.remove(); - } } } - } - - private static void genExplorePlans(BaseTpl tpl, Hop hop, CPlanMemoTable memo, ArrayList<Hop> inputs, Hop exclude) - { - //handle unary operators - if( hop.getInput().size() == 1 ) { - memo.add(hop, tpl.getType(), exclude.getHopID()); - } - //handle binary operators - //TODO rework plan exploration step - else if( hop.getInput().size() == 2 ) { - int input2ix = (inputs.get(0)==exclude ? 1:0); - Hop input2 = inputs.get(input2ix); - long[] refs = (input2ix==1) ? new long[]{exclude.getHopID(), -1} : new long[]{-1, exclude.getHopID()}; - memo.add(hop, tpl.getType(), refs[0], refs[1]); - if( memo.contains(input2.getHopID()) && !memo.get(input2.getHopID()).get(0).closed - && memo.get(input2.getHopID()).get(0).type == TemplateType.CellTpl && tpl.merge(hop, input2) ) { - refs[input2ix] = input2.getHopID(); - memo.add(hop, tpl.getType(), refs[0], refs[1]); - } - } - else { - LOG.warn("genExplorePlans currently only supports unary and binary operators."); - } + + //mark visited even if no plans found (e.g., unsupported ops) + memo.addHop(hop); } private static void rConstructCPlans(Hop hop, CPlanMemoTable memo, HashMap<Long, Pair<Hop[],CNodeTpl>> cplans, boolean compileLiterals) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9cbaf85a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java index ee41521..b0bf75b 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java @@ -52,6 +52,14 @@ public class CPlanMemoTable _plansBlacklist = new HashSet<Long>(); } + public void addHop(Hop hop) { + _hopRefs.put(hop.getHopID(), hop); + } + + public boolean containsHop(Hop hop) { + return _hopRefs.containsKey(hop.getHopID()); + } + public boolean contains(long hopID) { return _plans.containsKey(hopID); } @@ -79,6 +87,13 @@ public class CPlanMemoTable _plans.put(hop.getHopID(), new ArrayList<MemoTableEntry>()); _plans.get(hop.getHopID()).add(new MemoTableEntry(type, in1, in2, in3)); } + + public void addAll(Hop hop, MemoTableEntrySet P) { + _hopRefs.put(hop.getHopID(), hop); + if( !_plans.containsKey(hop.getHopID()) ) + _plans.put(hop.getHopID(), new ArrayList<MemoTableEntry>()); + _plans.get(hop.getHopID()).addAll(P.plans); + } @SuppressWarnings("unchecked") public void pruneRedundant(long hopID) { @@ -277,4 +292,23 @@ public class CPlanMemoTable return type.name()+"("+input1+","+input2+","+input3+")"; } } + + public static class MemoTableEntrySet + { + public ArrayList<MemoTableEntry> plans = new ArrayList<MemoTableEntry>(); + + public MemoTableEntrySet(TemplateType type, int pos, long hopID, boolean close) { + plans.add(new MemoTableEntry(type, (pos==0)?hopID:-1, + (pos==1)?hopID:-1, (pos==2)?hopID:-1)); + } + + public void crossProduct(int pos, Long... refs) { + ArrayList<MemoTableEntry> tmp = new ArrayList<MemoTableEntry>(); + for( MemoTableEntry me : plans ) + for( Long ref : refs ) + tmp.add(new MemoTableEntry(me.type, (pos==0)?ref:me.input1, + (pos==1)?ref:me.input2, (pos==2)?ref:me.input3)); + plans = tmp; + } + } }