[SYSTEMML-1444] Fix function call hoisting in nested control flow. The recent simplification of the nn-lstm layer with UDFs in expressions revealed a hidden issue of function call hoisting in nested control flow. Specifically, newly added statement blocks were not written back into the bodies of the enclosing function, while, if/else, and for statements and thus were lost. This led to validation errors due to non-existing input variables in statements that had been altered during hoisting.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1f63b09c Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1f63b09c Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1f63b09c Branch: refs/heads/master Commit: 1f63b09cdd0498a94f6bf4e8dfefbfa3cc27d948 Parents: 63be18a Author: Matthias Boehm <[email protected]> Authored: Mon Mar 5 15:12:32 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Mon Mar 5 15:12:49 2018 -0800 ---------------------------------------------------------------------- .../org/apache/sysml/parser/StatementBlock.java | 24 +++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/1f63b09c/src/main/java/org/apache/sysml/parser/StatementBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/StatementBlock.java b/src/main/java/org/apache/sysml/parser/StatementBlock.java index f7901c1..04e4a34 100644 --- a/src/main/java/org/apache/sysml/parser/StatementBlock.java +++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java @@ -412,31 +412,43 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo if (current instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock)current; FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + ArrayList<StatementBlock> tmp = new ArrayList<>(); for (StatementBlock sb : fstmt.getBody()) - rHoistFunctionCallsFromExpressions(sb); + tmp.addAll(rHoistFunctionCallsFromExpressions(sb)); + fstmt.setBody(tmp); } else if (current instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) current; WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); //TODO handle predicates + ArrayList<StatementBlock> tmp = new 
ArrayList<>(); for (StatementBlock sb : wstmt.getBody()) - rHoistFunctionCallsFromExpressions(sb); + tmp.addAll(rHoistFunctionCallsFromExpressions(sb)); + wstmt.setBody(tmp); } else if (current instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) current; IfStatement istmt = (IfStatement)isb.getStatement(0); //TODO handle predicates + ArrayList<StatementBlock> tmp = new ArrayList<>(); for (StatementBlock sb : istmt.getIfBody()) - rHoistFunctionCallsFromExpressions(sb); - for (StatementBlock sb : istmt.getElseBody()) - rHoistFunctionCallsFromExpressions(sb); + tmp.addAll(rHoistFunctionCallsFromExpressions(sb)); + istmt.setIfBody(tmp); + if( istmt.getElseBody() != null && !istmt.getElseBody().isEmpty() ) { + ArrayList<StatementBlock> tmp2 = new ArrayList<>(); + for (StatementBlock sb : istmt.getElseBody()) + tmp2.addAll(rHoistFunctionCallsFromExpressions(sb)); + istmt.setElseBody(tmp2); + } } else if (current instanceof ForStatementBlock) { //incl parfor ForStatementBlock fsb = (ForStatementBlock) current; ForStatement fstmt = (ForStatement)fsb.getStatement(0); //TODO handle predicates + ArrayList<StatementBlock> tmp = new ArrayList<>(); for (StatementBlock sb : fstmt.getBody()) - rHoistFunctionCallsFromExpressions(sb); + tmp.addAll(rHoistFunctionCallsFromExpressions(sb)); + fstmt.setBody(tmp); } else { //generic (last-level) ArrayList<Statement> tmp = new ArrayList<>();
