[SYSTEMML-1444] Fix function call hoisting in nested control flow

The recent simplification of the nn-lstm layer with UDFs in expressions
revealed a hidden issue of function call hoisting in nested control
flow. Specifically, the statement blocks newly added during hoisting
were not written back into the bodies of the enclosing function, while,
if, and for statements and were thus lost. This led to validation errors
due to nonexistent input variables in statements that had been altered
during hoisting.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1f63b09c
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1f63b09c
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1f63b09c

Branch: refs/heads/master
Commit: 1f63b09cdd0498a94f6bf4e8dfefbfa3cc27d948
Parents: 63be18a
Author: Matthias Boehm <[email protected]>
Authored: Mon Mar 5 15:12:32 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Mon Mar 5 15:12:49 2018 -0800

----------------------------------------------------------------------
 .../org/apache/sysml/parser/StatementBlock.java | 24 +++++++++++++++-----
 1 file changed, 18 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/1f63b09c/src/main/java/org/apache/sysml/parser/StatementBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/StatementBlock.java 
b/src/main/java/org/apache/sysml/parser/StatementBlock.java
index f7901c1..04e4a34 100644
--- a/src/main/java/org/apache/sysml/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java
@@ -412,31 +412,43 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                if (current instanceof FunctionStatementBlock) {
                        FunctionStatementBlock fsb = 
(FunctionStatementBlock)current;
                        FunctionStatement fstmt = 
(FunctionStatement)fsb.getStatement(0);
+                       ArrayList<StatementBlock> tmp = new ArrayList<>();
                        for (StatementBlock sb : fstmt.getBody())
-                               rHoistFunctionCallsFromExpressions(sb);
+                               
tmp.addAll(rHoistFunctionCallsFromExpressions(sb));
+                       fstmt.setBody(tmp);
                }
                else if (current instanceof WhileStatementBlock) {
                        WhileStatementBlock wsb = (WhileStatementBlock) current;
                        WhileStatement wstmt = 
(WhileStatement)wsb.getStatement(0);
                        //TODO handle predicates
+                       ArrayList<StatementBlock> tmp = new ArrayList<>();
                        for (StatementBlock sb : wstmt.getBody())
-                               rHoistFunctionCallsFromExpressions(sb);
+                               
tmp.addAll(rHoistFunctionCallsFromExpressions(sb));
+                       wstmt.setBody(tmp);
                }
                else if (current instanceof IfStatementBlock) {
                        IfStatementBlock isb = (IfStatementBlock) current;
                        IfStatement istmt = (IfStatement)isb.getStatement(0);
                        //TODO handle predicates
+                       ArrayList<StatementBlock> tmp = new ArrayList<>();
                        for (StatementBlock sb : istmt.getIfBody())
-                               rHoistFunctionCallsFromExpressions(sb);
-                       for (StatementBlock sb : istmt.getElseBody())
-                               rHoistFunctionCallsFromExpressions(sb);
+                               
tmp.addAll(rHoistFunctionCallsFromExpressions(sb));
+                       istmt.setIfBody(tmp);
+                       if( istmt.getElseBody() != null && 
!istmt.getElseBody().isEmpty() ) {
+                               ArrayList<StatementBlock> tmp2 = new 
ArrayList<>();
+                               for (StatementBlock sb : istmt.getElseBody())
+                                       
tmp2.addAll(rHoistFunctionCallsFromExpressions(sb));
+                               istmt.setElseBody(tmp2);
+                       }
                }
                else if (current instanceof ForStatementBlock) { //incl parfor
                        ForStatementBlock fsb = (ForStatementBlock) current;
                        ForStatement fstmt = (ForStatement)fsb.getStatement(0);
                        //TODO handle predicates
+                       ArrayList<StatementBlock> tmp = new ArrayList<>();
                        for (StatementBlock sb : fstmt.getBody())
-                               rHoistFunctionCallsFromExpressions(sb);
+                               
tmp.addAll(rHoistFunctionCallsFromExpressions(sb));
+                       fstmt.setBody(tmp);
                }
                else { //generic (last-level)
                        ArrayList<Statement> tmp = new ArrayList<>();

Reply via email to