This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new b5b6f37e20 [SYSTEMDS-3765] Fix time displacement through function hoisting
b5b6f37e20 is described below

commit b5b6f37e2064f8eda40e820e5821cd70f379e717
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Dec 7 12:11:22 2024 +0100

    [SYSTEMDS-3765] Fix time displacement through function hoisting
    
    This patch fixes issues with time() calls that are used to measure
    the execution time of parts of a program. When these calls appeared
    inside expressions (e.g., string concatenation in a print statement),
    normal DAG compilation could schedule them before the operation that
    was actually being measured, so the reported time excluded that
    operation. As with DML function calls, we now hoist these time()
    calls out of expressions.
---
 .../java/org/apache/sysds/parser/StatementBlock.java |  9 ++++++++-
 .../functions/rewrite/RewriteHoistingTimeTest.java   | 20 +++++++++++++++-----
 .../functions/rewrite/RewriteTimeHoisting.dml        |  2 +-
 3 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java 
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index cc37b252c6..82501f63c5 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -595,6 +595,13 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                        Expression[] clexpr = lexpr.getAllExpr();
                        for( int i=0; i<clexpr.length; i++ )
                                clexpr[i] = 
rHoistFunctionCallsFromExpressions(clexpr[i], false, tmp, prog);
+                       if( !root && lexpr.getOpCode()==Builtins.TIME ) { 
//core time hoisting
+                               String varname = 
StatementBlockRewriteRule.createCutVarName(true);
+                               DataIdentifier di = new DataIdentifier(varname);
+                               di.setDataType(lexpr.getDataType());
+                               di.setValueType(lexpr.getValueType());
+                               tmp.add(new AssignmentStatement(di, lexpr, di));
+                       }
                }
                else if( expr instanceof ParameterizedBuiltinFunctionExpression 
) {
                        ParameterizedBuiltinFunctionExpression lexpr = 
(ParameterizedBuiltinFunctionExpression) expr;
@@ -612,7 +619,7 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                        FunctionCallIdentifier fexpr = (FunctionCallIdentifier) 
expr;
                        for( ParameterExpression pexpr : fexpr.getParamExprs() )
                                
pexpr.setExpr(rHoistFunctionCallsFromExpressions(pexpr.getExpr(), false, tmp, 
prog));
-                       if( !root ) { //core hoisting
+                       if( !root ) { //core fcall hoisting
                                String varname = 
StatementBlockRewriteRule.createCutVarName(true);
                                DataIdentifier di = new DataIdentifier(varname);
                                di.setDataType(fexpr.getDataType());
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteHoistingTimeTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteHoistingTimeTest.java
index fd7f3bdda4..40cdb5c738 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteHoistingTimeTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteHoistingTimeTest.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.test.functions.rewrite;
 
+import org.junit.Assert;
 import org.junit.Test;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.common.Types.ExecType;
@@ -43,9 +44,14 @@ public class RewriteHoistingTimeTest extends 
AutomatedTestBase
        }
 
        @Test
-       public void testTimeHoisting() {
+       public void testTimeHoistingCP() {
                test(TEST_NAME1, ExecType.CP);
        }
+       
+       @Test
+       public void testTimeHoistingSpark() {
+               test(TEST_NAME1, ExecType.SPARK);
+       }
 
        private void test(String testname, ExecType et)
        {
@@ -58,11 +64,15 @@ public class RewriteHoistingTimeTest extends 
AutomatedTestBase
                        
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[] { "-explain", "-args",
+                       programArgs = new String[] {"-args",
                                String.valueOf(rows), String.valueOf(cols) };
-                       
-                       //FIXME need to hoist time() out of expression similar 
to function calls
-                       runTest(true, false, null, -1); 
+
+                       //test that time is not executed before 1k-by-1k rand
+                       setOutputBuffering(true);
+                       String out = runTest(true, false, null, -1).toString();
+                       double time = Double.parseDouble(out.split(";")[1]);
+                       System.out.println("Time = "+time+"s");
+                       Assert.assertTrue(time>0.001);
                }
                finally {
                        resetExecMode(platformOld);
diff --git a/src/test/scripts/functions/rewrite/RewriteTimeHoisting.dml 
b/src/test/scripts/functions/rewrite/RewriteTimeHoisting.dml
index b4a4457212..46c3623ccd 100644
--- a/src/test/scripts/functions/rewrite/RewriteTimeHoisting.dml
+++ b/src/test/scripts/functions/rewrite/RewriteTimeHoisting.dml
@@ -22,5 +22,5 @@
 t1 = time();
 X = rand(rows=$1, cols=$2);
 
-print("time = "+(time()-t1)/1e9+"s"+" "+sum(X));
+print(";"+(time()-t1)/1e9+";"+" "+sum(X));
 

Reply via email to