Repository: incubator-systemml
Updated Branches:
  refs/heads/master 30f72e83f -> 2893e1aed


[SYSTEMML-1374] Improved candidate exploration of code generation plans

This patch improves the codegen candidate exploration algorithm by (1)
better memoization (operators without any plans, e.g., unsupported
operators, are now also recorded as visited so they are not re-explored)
and (2) a simplified, generic creation of merge plans over all operator
inputs (which now also applies to ternary operators).

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/9cbaf85a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/9cbaf85a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/9cbaf85a

Branch: refs/heads/master
Commit: 9cbaf85ab1389cde9fb79f58e29c0adb6044c493
Parents: 30f72e8
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Sat Mar 18 21:39:28 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Sat Mar 18 21:39:28 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java       | 60 +++++++++-----------
 .../hops/codegen/template/CPlanMemoTable.java   | 34 +++++++++++
 2 files changed, 60 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9cbaf85a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index 54e67b6..6479917 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -44,6 +44,7 @@ import org.apache.sysml.hops.codegen.template.BaseTpl.CloseType;
 import org.apache.sysml.hops.codegen.template.BaseTpl.TemplateType;
 import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
 import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntrySet;
 import org.apache.sysml.hops.codegen.template.TemplateUtils;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.Hop.OpOp1;
@@ -318,68 +319,59 @@ public class SpoofCompiler
                throws DMLException
        {               
                //top-down memoization of processed dag nodes
-               if( memo.contains(hop.getHopID()) )
+               if( memo.contains(hop.getHopID()) || memo.containsHop(hop) )
                        return;
                
-               //recursively process child nodes
+               //recursive candidate exploration
                for( Hop c : hop.getInput() )
                        rExploreCPlans(c, memo, compileLiterals);
                
-               //generate new node plans
+               //open initial operator plans, if possible
                for( BaseTpl tpl : TemplateUtils.TEMPLATES )
                        if( tpl.open(hop) )
                                memo.add(hop, tpl.getType());
                
+               //fuse and merge operator plans
                for( Hop c : hop.getInput() ) {
                        if( memo.contains(c.getHopID()) )
                                for( MemoTableEntry me : memo.get(c.getHopID()) 
) {
                                        BaseTpl tpl = 
TemplateUtils.createTemplate(me.type, me.closed);
-                                       if( tpl.fuse(hop, c) )
-                                               genExplorePlans(tpl, hop, memo, 
hop.getInput(), c);     
+                                       if( tpl.fuse(hop, c) ) {
+                                               int pos = 
hop.getInput().indexOf(c);
+                                               MemoTableEntrySet P = new 
MemoTableEntrySet(tpl.getType(), pos, c.getHopID(), tpl.isClosed());
+                                               for(int k=0; 
k<hop.getInput().size(); k++)
+                                                       if( k != pos ) {
+                                                               Hop input2 = 
hop.getInput().get(k);
+                                                               if( 
memo.contains(input2.getHopID()) && !memo.get(input2.getHopID()).get(0).closed
+                                                                       && 
memo.get(input2.getHopID()).get(0).type == TemplateType.CellTpl && 
tpl.merge(hop, input2) ) 
+                                                                       
P.crossProduct(k, -1L, input2.getHopID());
+                                                               else
+                                                                       
P.crossProduct(k, -1L);
+                                                       }
+                                               memo.addAll(hop, P);
+                                       }
                                }       
                }
                
                //prune subsumed / redundant plans
                memo.pruneRedundant(hop.getHopID());
                
-               //check if templates require close
+               //close operator plans, if required
                if( memo.contains(hop.getHopID()) ) {
                        Iterator<MemoTableEntry> iter = 
memo.get(hop.getHopID()).iterator();
                        while( iter.hasNext() ) {
                                MemoTableEntry me = iter.next();
                                BaseTpl tpl = 
TemplateUtils.createTemplate(me.type);
                                CloseType ccode = tpl.close(hop);
-                               if( ccode != CloseType.OPEN ) {
+                               if( ccode == CloseType.CLOSED_INVALID )
+                                       iter.remove();
+                               else if( ccode == CloseType.CLOSED_VALID )
                                        me.closed = true;
-                                       if( ccode == CloseType.CLOSED_INVALID )
-                                               iter.remove();
-                               }
                        }
                }
-       }
-       
-       private static void genExplorePlans(BaseTpl tpl, Hop hop, 
CPlanMemoTable memo, ArrayList<Hop> inputs, Hop exclude)
-       {
-               //handle unary operators
-               if( hop.getInput().size() == 1 ) {
-                       memo.add(hop, tpl.getType(), exclude.getHopID());
-               }
-               //handle binary operators
-               //TODO rework plan exploration step
-               else if( hop.getInput().size() == 2 ) {
-                       int input2ix = (inputs.get(0)==exclude ? 1:0);
-                       Hop input2 = inputs.get(input2ix); 
-                       long[] refs = (input2ix==1) ? new 
long[]{exclude.getHopID(), -1} : new long[]{-1, exclude.getHopID()};
-                       memo.add(hop, tpl.getType(), refs[0], refs[1]);         
-                       if( memo.contains(input2.getHopID()) && 
!memo.get(input2.getHopID()).get(0).closed
-                               && memo.get(input2.getHopID()).get(0).type == 
TemplateType.CellTpl && tpl.merge(hop, input2) ) {
-                               refs[input2ix] = input2.getHopID();
-                               memo.add(hop, tpl.getType(), refs[0], refs[1]); 
        
-                       }
-               }
-               else {
-                       LOG.warn("genExplorePlans currently only supports unary 
and binary operators.");
-               }
+               
+               //mark visited even if no plans found (e.g., unsupported ops)
+               memo.addHop(hop);
        }
        
        private static void rConstructCPlans(Hop hop, CPlanMemoTable memo, 
HashMap<Long, Pair<Hop[],CNodeTpl>> cplans, boolean compileLiterals) 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9cbaf85a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
index ee41521..b0bf75b 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
@@ -52,6 +52,14 @@ public class CPlanMemoTable
                _plansBlacklist = new HashSet<Long>();
        }
        
+       public void addHop(Hop hop) {
+               _hopRefs.put(hop.getHopID(), hop);
+       }
+       
+       public boolean containsHop(Hop hop) {
+               return _hopRefs.containsKey(hop.getHopID());
+       }
+       
        public boolean contains(long hopID) {
                return _plans.containsKey(hopID);
        }
@@ -79,6 +87,13 @@ public class CPlanMemoTable
                        _plans.put(hop.getHopID(), new 
ArrayList<MemoTableEntry>());
                _plans.get(hop.getHopID()).add(new MemoTableEntry(type, in1, 
in2, in3));
        }
+       
+       public void addAll(Hop hop, MemoTableEntrySet P) {
+               _hopRefs.put(hop.getHopID(), hop);
+               if( !_plans.containsKey(hop.getHopID()) )
+                       _plans.put(hop.getHopID(), new 
ArrayList<MemoTableEntry>());
+               _plans.get(hop.getHopID()).addAll(P.plans);
+       }
 
        @SuppressWarnings("unchecked")
        public void pruneRedundant(long hopID) {
@@ -277,4 +292,23 @@ public class CPlanMemoTable
                        return type.name()+"("+input1+","+input2+","+input3+")";
                }
        }
+       
+       public static class MemoTableEntrySet 
+       {
+               public ArrayList<MemoTableEntry> plans = new 
ArrayList<MemoTableEntry>();
+               
+               public MemoTableEntrySet(TemplateType type, int pos, long 
hopID, boolean close) {
+                       plans.add(new MemoTableEntry(type, (pos==0)?hopID:-1, 
+                                       (pos==1)?hopID:-1, (pos==2)?hopID:-1));
+               }
+               
+               public void crossProduct(int pos, Long... refs) {
+                       ArrayList<MemoTableEntry> tmp = new 
ArrayList<MemoTableEntry>();
+                       for( MemoTableEntry me : plans )
+                               for( Long ref : refs )
+                                       tmp.add(new MemoTableEntry(me.type, 
(pos==0)?ref:me.input1, 
+                                               (pos==1)?ref:me.input2, 
(pos==2)?ref:me.input3));
+                       plans = tmp;
+               }
+       }
 }

Reply via email to