This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new b5b6f37e20 [SYSTEMDS-3765] Fix time displacement through function hoisting
b5b6f37e20 is described below
commit b5b6f37e2064f8eda40e820e5821cd70f379e717
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Dec 7 12:11:22 2024 +0100
[SYSTEMDS-3765] Fix time displacement through function hoisting
This patch fixes issues with time() functions, which are used to
measure the execution time of parts of a program. When these functions
were used inside expressions (e.g., a string concatenation passed to
print), the normal DAG compilation could move them before the operation
that was actually being measured. Similar to DML function calls, we now
hoist these time functions out of expressions.
---
.../java/org/apache/sysds/parser/StatementBlock.java | 9 ++++++++-
.../functions/rewrite/RewriteHoistingTimeTest.java | 20 +++++++++++++++-----
.../functions/rewrite/RewriteTimeHoisting.dml | 2 +-
3 files changed, 24 insertions(+), 7 deletions(-)
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index cc37b252c6..82501f63c5 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -595,6 +595,13 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo
Expression[] clexpr = lexpr.getAllExpr();
for( int i=0; i<clexpr.length; i++ )
clexpr[i] =
rHoistFunctionCallsFromExpressions(clexpr[i], false, tmp, prog);
+ if( !root && lexpr.getOpCode()==Builtins.TIME ) {
//core time hoisting
+ String varname =
StatementBlockRewriteRule.createCutVarName(true);
+ DataIdentifier di = new DataIdentifier(varname);
+ di.setDataType(lexpr.getDataType());
+ di.setValueType(lexpr.getValueType());
+ tmp.add(new AssignmentStatement(di, lexpr, di));
+ }
}
else if( expr instanceof ParameterizedBuiltinFunctionExpression
) {
ParameterizedBuiltinFunctionExpression lexpr =
(ParameterizedBuiltinFunctionExpression) expr;
@@ -612,7 +619,7 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo
FunctionCallIdentifier fexpr = (FunctionCallIdentifier)
expr;
for( ParameterExpression pexpr : fexpr.getParamExprs() )
pexpr.setExpr(rHoistFunctionCallsFromExpressions(pexpr.getExpr(), false, tmp,
prog));
- if( !root ) { //core hoisting
+ if( !root ) { //core fcall hoisting
String varname =
StatementBlockRewriteRule.createCutVarName(true);
DataIdentifier di = new DataIdentifier(varname);
di.setDataType(fexpr.getDataType());
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteHoistingTimeTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteHoistingTimeTest.java
index fd7f3bdda4..40cdb5c738 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteHoistingTimeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteHoistingTimeTest.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test.functions.rewrite;
+import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.ExecType;
@@ -43,9 +44,14 @@ public class RewriteHoistingTimeTest extends AutomatedTestBase
}
@Test
- public void testTimeHoisting() {
+ public void testTimeHoistingCP() {
test(TEST_NAME1, ExecType.CP);
}
+
+ @Test
+ public void testTimeHoistingSpark() {
+ test(TEST_NAME1, ExecType.SPARK);
+ }
private void test(String testname, ExecType et)
{
@@ -58,11 +64,15 @@ public class RewriteHoistingTimeTest extends AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[] { "-explain", "-args",
+ programArgs = new String[] {"-args",
String.valueOf(rows), String.valueOf(cols) };
-
- //FIXME need to hoist time() out of expression similar
to function calls
- runTest(true, false, null, -1);
+
+ //test that time is not executed before 1k-by-1k rand
+ setOutputBuffering(true);
+ String out = runTest(true, false, null, -1).toString();
+ double time = Double.parseDouble(out.split(";")[1]);
+ System.out.println("Time = "+time+"s");
+ Assert.assertTrue(time>0.001);
}
finally {
resetExecMode(platformOld);
diff --git a/src/test/scripts/functions/rewrite/RewriteTimeHoisting.dml b/src/test/scripts/functions/rewrite/RewriteTimeHoisting.dml
index b4a4457212..46c3623ccd 100644
--- a/src/test/scripts/functions/rewrite/RewriteTimeHoisting.dml
+++ b/src/test/scripts/functions/rewrite/RewriteTimeHoisting.dml
@@ -22,5 +22,5 @@
t1 = time();
X = rand(rows=$1, cols=$2);
-print("time = "+(time()-t1)/1e9+"s"+" "+sum(X));
+print(";"+(time()-t1)/1e9+";"+" "+sum(X));