Repository: systemml
Updated Branches:
  refs/heads/master bdf42c068 -> 622d36c4a


[SYSTEMML-2166] Fix scrambled print order after function inlining

With SystemML 1.1 a number of new rewrites were introduced that merge
sequences of statement blocks, which for example appear after branch
removal and function inlining. Although this greatly improves
performance for certain workloads, it might lead to a reordering of
operations without dependencies (such as prints) that end up in the
consolidated DAG. SYSTEMML-2050 tried to address this by ordering
prints by their line numbers. Unfortunately, this works only in certain
situations because after function inlining the print line numbers do not
correspond to the order of the original function calls.

This patch fixes this issue by merging the non-overlapping roots of
statement blocks s1 and s2 in order of (s1,s2) instead of (s2,s1), where
the latter originated from the fact that we merge s1 into s2.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/622d36c4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/622d36c4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/622d36c4

Branch: refs/heads/master
Commit: 622d36c4a8d0cea4cec4de4ede98c1bcccd7ca8a
Parents: bdf42c0
Author: Matthias Boehm <[email protected]>
Authored: Thu Mar 1 18:51:34 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Thu Mar 1 19:10:10 2018 -0800

----------------------------------------------------------------------
 scripts/nn/test/test.dml                            |  4 ++--
 .../hops/rewrite/RewriteMergeBlockSequence.java     | 16 ++++++++++++----
 .../java/org/apache/sysml/lops/compile/Dag.java     |  5 +++--
 3 files changed, 17 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/622d36c4/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index 2a04f97..391c7f0 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -853,9 +853,9 @@ compare_tanh_builtin_forward_with_old = function() {
 
 compare_tanh_builtin_backward_with_old = function() {
   /*
-   * Test for the `tanh` forward function.
+   * Test for the `tanh` backward function.
    */
-  print("Testing the tanh forward function.")
+  print("Testing the tanh backward function.")
 
   # Generate data
   N = 2  # num examples

http://git-wip-us.apache.org/repos/asf/systemml/blob/622d36c4/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java
index 775006d..cf7701c 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java
@@ -73,8 +73,12 @@ public class RewriteMergeBlockSequence extends 
StatementBlockRewriteRule
                                        && (!hasFunctionOpRoot(sb1) || 
!hasFunctionIOConflict(sb1,sb2))
                                        && (!hasFunctionOpRoot(sb2) || 
!hasFunctionIOConflict(sb2,sb1)) )
                                {
+                                       //note: we intend to merge sb1 into sb2 
to connect data dependencies
+                                       //however, we work with a temporary 
list of root nodes to preserve
+                                       //the original order of roots, which 
affects prints w/o dependencies
                                        ArrayList<Hop> sb1Hops = sb1.getHops();
                                        ArrayList<Hop> sb2Hops = sb2.getHops();
+                                       ArrayList<Hop> newHops = new 
ArrayList<>();
                                        
                                        //determine transient read inputs s2 
                                        Hop.resetVisitStatus(sb2Hops);
@@ -99,23 +103,27 @@ public class RewriteMergeBlockSequence extends 
StatementBlockRewriteRule
                                                        //add transient write 
if necessary
                                                        if( 
!twrites.containsKey(root.getName()) 
                                                                && 
sb2.liveOut().containsVariable(root.getName()) ) {
-                                                               
sb2Hops.add(HopRewriteUtils.createDataOp(
+                                                               
newHops.add(HopRewriteUtils.createDataOp(
                                                                        
root.getName(), in, DataOpTypes.TRANSIENTWRITE));
                                                        }
                                                }
                                                //add remaining roots from s1 
to s2
                                                else if( 
!(HopRewriteUtils.isData(root, DataOpTypes.TRANSIENTWRITE)
                                                        && 
(twrites.containsKey(root.getName()) || 
!sb2.liveOut().containsVariable(root.getName()))) ) {
-                                                       sb2Hops.add(root);
+                                                       newHops.add(root);
                                                }
                                        }
                                        //clear partial hops from the merged 
statement block to avoid problems with 
                                        //other statement block rewrites that 
iterate over the original program
                                        sb1Hops.clear();
                                        
+                                       //append all root nodes of s2 after 
root nodes of s1
+                                       newHops.addAll(sb2Hops);
+                                       sb2.setHops(newHops);
+                                       
                                        //run common-subexpression elimination
-                                       Hop.resetVisitStatus(sb2Hops);
-                                       rewriter.rewriteHopDAG(sb2Hops, new 
ProgramRewriteStatus());
+                                       Hop.resetVisitStatus(sb2.getHops());
+                                       rewriter.rewriteHopDAG(sb2.getHops(), 
new ProgramRewriteStatus());
                                        
                                        //modify live variable sets of s2
                                        sb2.setLiveIn(sb1.liveIn()); //liveOut 
remains unchanged

http://git-wip-us.apache.org/repos/asf/systemml/blob/622d36c4/src/main/java/org/apache/sysml/lops/compile/Dag.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/compile/Dag.java 
b/src/main/java/org/apache/sysml/lops/compile/Dag.java
index d662124..b777623 100644
--- a/src/main/java/org/apache/sysml/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysml/lops/compile/Dag.java
@@ -3627,10 +3627,11 @@ public class Dag<N extends Lop>
        private ArrayList<Lop> doTopologicalSortTwoLevelOrder(ArrayList<Lop> v) 
{
                //partition nodes into leaf/inner nodes and dag root nodes,
                //+ sort leaf/inner nodes by ID to force depth-first scheduling
-               //+ sort root nodes by line numbers to force ordering of prints 
+               //+ append root nodes in order of their original definition 
+               //  (which also preserves the original order of prints)
                Lop[] nodearray = Stream.concat(
                        v.stream().filter(l -> 
!l.getOutputs().isEmpty()).sorted(Comparator.comparing(l -> l.getID())),
-                       v.stream().filter(l -> 
l.getOutputs().isEmpty()).sorted(Comparator.comparing(l -> l.getBeginLine())))
+                       v.stream().filter(l -> l.getOutputs().isEmpty()))
                        .toArray(Lop[]::new);
                
                return createIDMapping(nodearray);

Reply via email to