This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 0cc2c9f  [SYSTEMDS-2949] Fix function call hoisting out of expressions
0cc2c9f is described below

commit 0cc2c9f98ca2bd242d6d7d2e20c3802e52f83f9b
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed May 12 23:10:16 2021 +0200

    [SYSTEMDS-2949] Fix function call hoisting out of expressions
    
    This patch fixes parsing issues where partially incorrect hoisting of
    function calls out of complex expressions led to null pointer
    exceptions (as a result of hoisting multiple functions). We now add
    more tests and take a conservative approach: statement blocks are cut
    both before and after each hoisted function call. These blocks are
    later merged with other blocks again, but the cuts ensure the validity
    of the simplifying assumption made in the parser (namely, that a
    function call statement is the first statement in its block).
---
 .../org/apache/sysds/parser/DMLTranslator.java     | 52 ++++++++++------------
 .../org/apache/sysds/parser/StatementBlock.java    | 34 +++++++-------
 .../functions/misc/FunctionInExpressionTest.java   | 48 +++++++++++---------
 .../scripts/functions/misc/FunInExpression8.dml    | 27 +++++++++++
 .../scripts/functions/misc/FunInExpression9.dml    | 32 +++++++++++++
 5 files changed, 128 insertions(+), 65 deletions(-)

diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index b050a3b..3b2146a 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -1736,44 +1736,40 @@ public class DMLTranslator
                Hop left  = processExpression(source.getLeft(),  null, hops);
                Hop right = processExpression(source.getRight(), null, hops);
 
-               if (left == null || right == null){
-                       left  = processExpression(source.getLeft(),  null, 
hops);
-                       right = processExpression(source.getRight(), null, 
hops);
+               if (left == null || right == null) {
+                       throw new ParseException("Missing input in binary 
expressions (" + source.toString()+"): "
+                               + 
((left==null)?source.getLeft():source.getRight())+", 
line="+source.getBeginLine());
                }
-       
-               Hop currBop = null;
-
+               
                //prepare target identifier and ensure that output type is of 
inferred type 
-        //(type should not be determined by target (e.g., string for print)
+               //(type should not be determined by target (e.g., string for 
print)
                if (target == null) {
-                   target = createTarget(source);
+                       target = createTarget(source);
                }
                target.setValueType(source.getOutput().getValueType());
                
-               if (source.getOpCode() == Expression.BinaryOp.PLUS) {
-                       currBop = new BinaryOp(target.getName(), 
target.getDataType(), target.getValueType(), OpOp2.PLUS, left, right);
-               } else if (source.getOpCode() == Expression.BinaryOp.MINUS) {
-                       currBop = new BinaryOp(target.getName(), 
target.getDataType(), target.getValueType(), OpOp2.MINUS, left, right);
-               } else if (source.getOpCode() == Expression.BinaryOp.MULT) {
-                       currBop = new BinaryOp(target.getName(), 
target.getDataType(), target.getValueType(), OpOp2.MULT, left, right);
-               } else if (source.getOpCode() == Expression.BinaryOp.DIV) {
-                       currBop = new BinaryOp(target.getName(), 
target.getDataType(), target.getValueType(), OpOp2.DIV, left, right);
-               } else if (source.getOpCode() == Expression.BinaryOp.MODULUS) {
-                       currBop = new BinaryOp(target.getName(), 
target.getDataType(), target.getValueType(), OpOp2.MODULUS, left, right);
-               } else if (source.getOpCode() == Expression.BinaryOp.INTDIV) {
-                       currBop = new BinaryOp(target.getName(), 
target.getDataType(), target.getValueType(), OpOp2.INTDIV, left, right);
-               } else if (source.getOpCode() == Expression.BinaryOp.MATMULT) {
-                       currBop = new AggBinaryOp(target.getName(), 
target.getDataType(), target.getValueType(), OpOp2.MULT, 
org.apache.sysds.common.Types.AggOp.SUM, left, right);
-               } else if (source.getOpCode() == Expression.BinaryOp.POW) {
-                       currBop = new BinaryOp(target.getName(), 
target.getDataType(), target.getValueType(), OpOp2.POW, left, right);
-               }
-               else {
-                       throw new ParseException("Unsupported parsing of binary 
expression: "+source.getOpCode());
+               Hop currBop = null;
+               switch( source.getOpCode() ) {
+                       case PLUS:
+                       case MINUS:
+                       case MULT:
+                       case DIV:
+                       case MODULUS:
+                       case POW:
+                       case INTDIV:
+                               currBop = new BinaryOp(target.getName(), 
target.getDataType(),
+                                       target.getValueType(), 
OpOp2.valueOf(source.getOpCode().name()), left, right);
+                               break;
+                       case MATMULT:
+                               currBop = new AggBinaryOp(target.getName(), 
target.getDataType(), target.getValueType(), OpOp2.MULT, 
org.apache.sysds.common.Types.AggOp.SUM, left, right);
+                               break;
+                       default:
+                               throw new ParseException("Unsupported parsing 
of binary expression: "+source.getOpCode());
                }
+               
                setIdentifierParams(currBop, source.getOutput());
                currBop.setParseInfo(source);
                return currBop;
-               
        }
 
        private Hop processRelationalExpression(RelationalExpression source, 
DataIdentifier target, HashMap<String, Hop> hops) {
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java 
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index c4876d4..570af39 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -183,10 +183,10 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
        public boolean isMergeableFunctionCallBlock(DMLProgram dmlProg) {
                // check whether targetIndex stmt block is for a mergable 
function call
                Statement stmt = this.getStatement(0);
-
+               
                // Check whether targetIndex block is: control stmt block or 
stmt block for un-mergable function call
                if (   stmt instanceof WhileStatement || stmt instanceof 
IfStatement || stmt instanceof ForStatement
-                       || stmt instanceof FunctionStatement || 
isMergeablePrintStatement(stmt) /*|| stmt instanceof ELStatement*/ )
+                       || stmt instanceof FunctionStatement || 
isMergeablePrintStatement(stmt) )
                {
                        return false;
                }
@@ -232,7 +232,7 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                        }
                }
 
-               // regular function block
+               // regular statement block
                return true;
        }
 
@@ -360,18 +360,17 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                ArrayList<StatementBlock> result = new ArrayList<>();
                StatementBlock currentBlock = null;
 
-               for (int i = 0; i < body.size(); i++){
+               for (int i = 0; i < body.size(); i++) {
                        StatementBlock current = body.get(i);
                        if (current.isMergeableFunctionCallBlock(dmlProg)){
-                               if (currentBlock != null) {
+                               if (currentBlock != null)
                                        currentBlock.addStatementBlock(current);
-                               } else {
+                               else
                                        currentBlock = current;
-                               }
-                       } else {
-                               if (currentBlock != null) {
+                       }
+                       else {
+                               if (currentBlock != null)
                                        result.add(currentBlock);
-                               }
                                result.add(current);
                                currentBlock = null;
                        }
@@ -465,7 +464,6 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                }
 
                return result;
-
        }
        
        public static List<StatementBlock> 
rHoistFunctionCallsFromExpressions(StatementBlock current, DMLProgram prog) {
@@ -634,11 +632,17 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                List<StatementBlock> ret = new ArrayList<>();
                StatementBlock current = new StatementBlock(sb);
                for(Statement stmt : stmts) {
+                       //cut the statement block before and after the current 
function
+                       //(cut before is precondition for subsequent merge 
steps which 
+                       //assume function statements as the first statement in 
the block)
+                       boolean cut = stmt instanceof AssignmentStatement
+                               && ((AssignmentStatement)stmt).getSource() 
instanceof FunctionCallIdentifier;
+                       if( cut && current.getNumStatements() > 0 ) { //before
+                               ret.add(current);
+                               current = new StatementBlock(sb);
+                       }
                        current.addStatement(stmt);
-                       //cut the statement block after the current function
-                       if( stmt instanceof AssignmentStatement
-                               && ((AssignmentStatement)stmt).getSource()
-                               instanceof FunctionCallIdentifier ) {
+                       if( cut ) { //after
                                ret.add(current);
                                current = new StatementBlock(sb);
                        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/misc/FunctionInExpressionTest.java
 
b/src/test/java/org/apache/sysds/test/functions/misc/FunctionInExpressionTest.java
index 78ad721..da54458 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/misc/FunctionInExpressionTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/misc/FunctionInExpressionTest.java
@@ -29,13 +29,12 @@ import org.apache.sysds.test.TestUtils;
 
 public class FunctionInExpressionTest extends AutomatedTestBase 
 {
-       private final static String TEST_NAME1 = "FunInExpression1";
-       private final static String TEST_NAME2 = "FunInExpression2";
-       private final static String TEST_NAME3 = "FunInExpression3";
-       private final static String TEST_NAME4 = "FunInExpression4";
-       private final static String TEST_NAME5 = "FunInExpression5";
-       private final static String TEST_NAME6 = "FunInExpression6";
-       private final static String TEST_NAME7 = "FunInExpression7"; 
//dml-bodied builtin
+       private final static String[] TEST_NAMES = new String[] {
+               "FunInExpression1", "FunInExpression2", "FunInExpression3",
+               "FunInExpression4", "FunInExpression5", "FunInExpression6",
+                //dml-bodied functions (w/ and w/o CSEs)
+               "FunInExpression7", "FunInExpression8", "FunInExpression9"
+       };
        
        private final static String TEST_DIR = "functions/misc/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FunctionInExpressionTest.class.getSimpleName() + "/";
@@ -43,48 +42,53 @@ public class FunctionInExpressionTest extends 
AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
-               addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
-               addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
-               addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
-               addTestConfiguration( TEST_NAME5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) );
-               addTestConfiguration( TEST_NAME6, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) );
-               addTestConfiguration( TEST_NAME7, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] { "R" }) );
+               for(int i=0; i<TEST_NAMES.length; i++)
+                       addTestConfiguration(TEST_NAMES[i], new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAMES[i], new String[] {"R"}));
        }
 
        @Test
        public void testFunInExpression1() {
-               runFunInExpressionTest( TEST_NAME1 );
+               runFunInExpressionTest( TEST_NAMES[0] );
        }
        
        @Test
        public void testFunInExpression2() {
-               runFunInExpressionTest( TEST_NAME2 );
+               runFunInExpressionTest( TEST_NAMES[1] );
        }
        
        @Test
        public void testFunInExpression3() {
-               runFunInExpressionTest( TEST_NAME3 );
+               runFunInExpressionTest( TEST_NAMES[2] );
        }
        
        @Test
        public void testFunInExpression4() {
-               runFunInExpressionTest( TEST_NAME4 );
+               runFunInExpressionTest( TEST_NAMES[3] );
        }
 
        @Test
        public void testFunInExpression5() {
-               runFunInExpressionTest( TEST_NAME5 );
+               runFunInExpressionTest( TEST_NAMES[4] );
        }
 
        @Test
        public void testFunInExpression6() {
-               runFunInExpressionTest( TEST_NAME6 );
+               runFunInExpressionTest( TEST_NAMES[5] );
        }
        
        @Test
        public void testFunInExpression7() {
-               runFunInExpressionTest( TEST_NAME7 );
+               runFunInExpressionTest( TEST_NAMES[6] );
+       }
+       
+       @Test
+       public void testFunInExpression8() {
+               runFunInExpressionTest( TEST_NAMES[7] );
+       }
+       
+       @Test
+       public void testFunInExpression9() {
+               runFunInExpressionTest( TEST_NAMES[8] );
        }
        
        private void runFunInExpressionTest( String testName )
@@ -94,7 +98,7 @@ public class FunctionInExpressionTest extends 
AutomatedTestBase
                
                String HOME = SCRIPT_DIR + TEST_DIR;
                fullDMLScriptName = HOME + testName + ".dml";
-               programArgs = new String[]{"-args", output("R") };
+               programArgs = new String[]{"-explain","-args", output("R") };
                
                fullRScriptName = HOME + testName + ".R";
                rCmd = getRCmd(expectedDir());
diff --git a/src/test/scripts/functions/misc/FunInExpression8.dml 
b/src/test/scripts/functions/misc/FunInExpression8.dml
new file mode 100644
index 0000000..70c8080
--- /dev/null
+++ b/src/test/scripts/functions/misc/FunInExpression8.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = matrix(4, rows=10, cols=10);
+R1 = sigmoid(A) + 7;
+R2 = sigmoid(A) - 7;
+R = as.matrix(sum(abs(R2-R1+14)<1e-10)*7/100)
+print(toString(R))
+write(R, $1);
diff --git a/src/test/scripts/functions/misc/FunInExpression9.dml 
b/src/test/scripts/functions/misc/FunInExpression9.dml
new file mode 100644
index 0000000..254b870
--- /dev/null
+++ b/src/test/scripts/functions/misc/FunInExpression9.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+foo = function(Matrix[Double] A, Double val)
+  return(Matrix[Double] R)
+{
+  R1 = sigmoid(A) + val;
+  R2 = sigmoid(A) - val;
+  R = as.matrix(sum(abs(R2-R1+14)<1e-10)*7/100)
+}
+
+A = matrix(4, rows=10, cols=10);
+R = foo(A, 7)
+write(R, $1);

Reply via email to