This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new f1bffeb [SYSTEMDS-344] New IPA pass for marking deterministic
functions/SBs
f1bffeb is described below
commit f1bffeb299eec6a57d5290fd12c81ba92c9f03e2
Author: arnabp <[email protected]>
AuthorDate: Sun May 17 22:24:22 2020 +0200
[SYSTEMDS-344] New IPA pass for marking deterministic functions/SBs
This patch moves the fragile and less efficient runtime non-determinism
check to compile time. It adds a new IPA pass that flags the functions
and StatementBlocks containing direct or transitive nondeterministic
calls (e.g. rand with UNSPECIFIED_SEED), so that they are excluded from
lineage caching.
AMLS project SS 2020.
Closes #911.
---
docs/Tasks.txt | 2 +-
src/main/java/org/apache/sysds/hops/DataGenOp.java | 8 +
.../sysds/hops/ipa/IPAPassFlagNonDeterminism.java | 201 +++++++++++++++++++++
.../sysds/hops/ipa/InterProceduralAnalysis.java | 2 +
.../apache/sysds/hops/rewrite/HopRewriteUtils.java | 7 +
.../org/apache/sysds/parser/DMLTranslator.java | 1 +
.../sysds/parser/FunctionStatementBlock.java | 9 +
.../org/apache/sysds/parser/StatementBlock.java | 10 +
.../runtime/controlprogram/BasicProgramBlock.java | 4 +-
.../controlprogram/FunctionProgramBlock.java | 9 +
.../instructions/cp/FunctionCallCPInstruction.java | 6 +-
.../apache/sysds/runtime/lineage/LineageCache.java | 8 +-
.../sysds/runtime/util/ProgramConverter.java | 4 +
.../functions/lineage/FunctionFullReuseTest.java | 7 +-
.../functions/lineage/FunctionFullReuse8.dml | 57 ++++++
.../scripts/functions/lineage/LineageReuseAlg2.dml | 4 +-
16 files changed, 322 insertions(+), 17 deletions(-)
diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index 8163e9f..6d5ff80 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -277,7 +277,7 @@ SYSTEMDS-340 Compiler Assisted Lineage Caching and Reuse
* 341 Finalize unmarking of loop dependent operations
* 342 Mark functions as last-use to enable early eviction
* 343 Identify equal last level HOPs to ensure SB-level reuse
- * 344 Unmark functions/SBs containing non-determinism for caching
+ * 344 Unmark functions/SBs containing non-determinism for caching OK
* 345 Compiler assisted cache configuration
SYSTEMDS-350 Data Cleaning Framework
diff --git a/src/main/java/org/apache/sysds/hops/DataGenOp.java
b/src/main/java/org/apache/sysds/hops/DataGenOp.java
index edcb448..8fdf98d 100644
--- a/src/main/java/org/apache/sysds/hops/DataGenOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataGenOp.java
@@ -468,6 +468,14 @@ public class DataGenOp extends MultiThreadedHop
return ret;
}
+ public boolean hasUnspecifiedSeed() {
+ if (_op == OpOpDG.RAND || _op == OpOpDG.SINIT) {
+ Hop seed =
getInput().get(_paramIndexMap.get(DataExpression.RAND_SEED));
+ return
seed.getName().equals(String.valueOf(DataGenOp.UNSPECIFIED_SEED));
+ }
+ return false;
+ }
+
public Hop getConstantValue() {
return
getInput().get(_paramIndexMap.get(DataExpression.RAND_MIN));
}
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
new file mode 100644
index 0000000..a000096
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.ipa;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+
+import org.apache.sysds.hops.FunctionOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.HopsException;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.LanguageException;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+
+public class IPAPassFlagNonDeterminism extends IPAPass {
+ @Override
+ public boolean isApplicable(FunctionCallGraph fgraph) {
+ return InterProceduralAnalysis.REMOVE_UNUSED_FUNCTIONS
+ && !fgraph.containsSecondOrderCall();
+ }
+
+ @Override
+ public void rewriteProgram (DMLProgram prog, FunctionCallGraph fgraph,
FunctionCallSizeInfo fcallSizes)
+ {
+ if (!LineageCacheConfig.isMultiLevelReuse())
+ return;
+
+ try {
+ // Find the individual functions and statementblocks
with non-determinism.
+ HashSet<String> ndfncs = new HashSet<>();
+ for (String fkey : fgraph.getReachableFunctions()) {
+ FunctionStatementBlock fsblock =
prog.getFunctionStatementBlock(fkey);
+ FunctionStatement fnstmt =
(FunctionStatement)fsblock.getStatement(0);
+ String fname =
DMLProgram.splitFunctionKey(fkey)[1];
+ if (rIsNonDeterministicFnc(fname,
fnstmt.getBody()))
+ ndfncs.add(fkey);
+ }
+
+ // Find the callers of the nondeterministic functions.
+ propagate2Callers(fgraph, ndfncs, new
HashSet<String>(), null);
+
+ // Mark the corresponding FunctionStatementBlocks
+ ndfncs.forEach(fkey -> {
+ FunctionStatementBlock fsblock =
prog.getFunctionStatementBlock(fkey);
+ fsblock.setNondeterministic(true);
+ });
+
+ // Find and mark the StatementBlocks having calls to
nondeterministic functions.
+ rMarkNondeterministicSBs(prog.getStatementBlocks(),
ndfncs);
+ for (String fkey : fgraph.getReachableFunctions()) {
+ FunctionStatementBlock fsblock =
prog.getFunctionStatementBlock(fkey);
+ FunctionStatement fnstmt =
(FunctionStatement)fsblock.getStatement(0);
+ rMarkNondeterministicSBs(fnstmt.getBody(),
ndfncs);
+ }
+ }
+ catch( LanguageException ex ) {
+ throw new HopsException(ex);
+ }
+ }
+
+ private boolean rIsNonDeterministicFnc (String fname,
ArrayList<StatementBlock> sbs)
+ {
+ boolean isND = false;
+ for (StatementBlock sb : sbs)
+ {
+ if (isND)
+ break;
+
+ if (sb instanceof ForStatementBlock) {
+ ForStatement fstmt =
(ForStatement)sb.getStatement(0);
+ isND = rIsNonDeterministicFnc(fname,
fstmt.getBody());
+ }
+ else if (sb instanceof WhileStatementBlock) {
+ WhileStatement wstmt =
(WhileStatement)sb.getStatement(0);
+ isND = rIsNonDeterministicFnc(fname,
wstmt.getBody());
+ }
+ else if (sb instanceof IfStatementBlock) {
+ IfStatement ifstmt =
(IfStatement)sb.getStatement(0);
+ isND = rIsNonDeterministicFnc(fname,
ifstmt.getIfBody());
+ if (ifstmt.getElseBody() != null)
+ isND = rIsNonDeterministicFnc(fname,
ifstmt.getElseBody());
+ }
+ else {
+ if (sb.getHops() != null) {
+ Hop.resetVisitStatus(sb.getHops());
+ for (Hop hop : sb.getHops())
+ isND |=
rIsNonDeterministicHop(hop);
+ Hop.resetVisitStatus(sb.getHops());
+ // Mark the statementblock
+ sb.setNondeterministic(isND);
+ }
+ }
+ }
+ return isND;
+ }
+
+ private void rMarkNondeterministicSBs (ArrayList<StatementBlock> sbs,
HashSet<String> ndfncs)
+ {
+ for (StatementBlock sb : sbs)
+ {
+ if (sb instanceof ForStatementBlock) {
+ ForStatement fstmt =
(ForStatement)sb.getStatement(0);
+ rMarkNondeterministicSBs(fstmt.getBody(),
ndfncs);
+ }
+ else if (sb instanceof WhileStatementBlock) {
+ WhileStatement wstmt =
(WhileStatement)sb.getStatement(0);
+ rMarkNondeterministicSBs(wstmt.getBody(),
ndfncs);
+ }
+ else if (sb instanceof IfStatementBlock) {
+ IfStatement ifstmt =
(IfStatement)sb.getStatement(0);
+ rMarkNondeterministicSBs(ifstmt.getIfBody(),
ndfncs);
+ if (ifstmt.getElseBody() != null)
+
rMarkNondeterministicSBs(ifstmt.getElseBody(), ndfncs);
+ }
+ else {
+ if (sb.getHops() != null) {
+ boolean callsND = false;
+ Hop.resetVisitStatus(sb.getHops());
+ for (Hop hop : sb.getHops())
+ callsND |=
rMarkNondeterministicHop(hop, ndfncs);
+ Hop.resetVisitStatus(sb.getHops());
+ if (callsND)
+ sb.setNondeterministic(callsND);
+ }
+ }
+ }
+ }
+
+ private boolean rMarkNondeterministicHop(Hop hop, HashSet<String>
ndfncs) {
+ if (hop.isVisited())
+ return false;
+
+ boolean callsND = hop instanceof FunctionOp &&
ndfncs.contains(hop.getName());
+
+ if (!callsND)
+ for (Hop hi : hop.getInput())
+ callsND |= rMarkNondeterministicHop(hi, ndfncs);
+ hop.setVisited();
+ return callsND;
+ }
+
+ private boolean rIsNonDeterministicHop(Hop hop) {
+ if (hop.isVisited())
+ return false;
+
+ boolean isND =
HopRewriteUtils.isDataGenOpWithNonDeterminism(hop);
+
+ if (!isND)
+ for (Hop hi : hop.getInput())
+ isND |= rIsNonDeterministicHop(hi);
+ hop.setVisited();
+ return isND;
+ }
+
+ private void propagate2Callers (FunctionCallGraph fgraph,
HashSet<String> ndfncs, HashSet<String> fstack, String fkey) {
+ Collection<String> cfkeys = fgraph.getCalledFunctions(fkey);
+ if (cfkeys != null) {
+ for (String cfkey : cfkeys) {
+ if (fstack.contains(cfkey) &&
fgraph.isRecursiveFunction(cfkey)) {
+ if (ndfncs.contains(cfkey) && fkey
!=null)
+ ndfncs.add(fkey);
+ }
+ else {
+ fstack.add(cfkey);
+ propagate2Callers(fgraph, ndfncs,
fstack, cfkey);
+ fstack.remove(cfkey);
+ if (ndfncs.contains(cfkey) && fkey
!=null)
+ ndfncs.add(fkey);
+ }
+ }
+ }
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index 23d9a76..0710cb3 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -94,6 +94,7 @@ public class InterProceduralAnalysis
protected static final int INLINING_MAX_NUM_OPS = 10;
//inline single-statement functions w/ #ops <= threshold, other than dataops
and literals
protected static final boolean ELIMINATE_DEAD_CODE = true;
//remove dead code (e.g., assigments) not used later on
protected static final boolean FORWARD_SIMPLE_FUN_CALLS = true;
//replace a call to a simple forwarding function with the function itself
+ protected static final boolean FLAG_NONDETERMINISM = true;
//flag functions which directly or transitively contain non-deterministic calls
static {
// for internal debugging only
@@ -136,6 +137,7 @@ public class InterProceduralAnalysis
_passes.add(new IPAPassPropagateReplaceLiterals());
_passes.add(new IPAPassInlineFunctions());
_passes.add(new IPAPassEliminateDeadCode());
+ _passes.add(new IPAPassFlagNonDeterminism());
//note: apply rewrites last because statement block rewrites
//might merge relevant statement blocks in special cases, which
//would require an update of the function call graph
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 91118bf..9d86b4d 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -524,6 +524,13 @@ public class HopRewriteUtils
&& ((DataGenOp)hop).hasConstantValue(value);
}
+ public static boolean isDataGenOpWithNonDeterminism(Hop hop) {
+ if (!isDataGenOp(hop, OpOpDG.RAND, OpOpDG.SAMPLE))
+ return false;
+ return isDataGenOp(hop, OpOpDG.SAMPLE) || (isDataGenOp(hop,
OpOpDG.RAND)
+ && !((DataGenOp)hop).hasConstantValue() &&
((DataGenOp)hop).hasUnspecifiedSeed());
+ }
+
public static Hop getDataGenOpConstantValue(Hop hop) {
return ((DataGenOp) hop).getConstantValue();
}
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index a15f923..de1e3ce 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -444,6 +444,7 @@ public class DMLTranslator
FunctionProgramBlock rtpb =
(FunctionProgramBlock)createRuntimeProgramBlock(rtprog, fsb, config);
rtprog.addFunctionProgramBlock(namespace,
fname, rtpb);
rtpb.setRecompileOnce( fsb.isRecompileOnce() );
+
rtpb.setNondeterministic(fsb.isNondeterministic());
}
}
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
index a8f3c75..b056b7e 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
@@ -31,6 +31,7 @@ public class FunctionStatementBlock extends StatementBlock
{
private boolean _recompileOnce = false;
+ private boolean _nondeterministic = false;
/**
* TODO: DRB: This needs to be changed to reflect:
@@ -241,4 +242,12 @@ public class FunctionStatementBlock extends StatementBlock
public boolean isRecompileOnce() {
return _recompileOnce;
}
+
+ public void setNondeterministic(boolean flag) {
+ _nondeterministic = flag;
+ }
+
+ public boolean isNondeterministic() {
+ return _nondeterministic;
+ }
}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 17a0847..5a1f967 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -62,6 +62,7 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
private ArrayList<String> _updateInPlaceVars = null;
private boolean _requiresRecompile = false;
private boolean _splitDag = false;
+ private boolean _nondeterministic = false;
public StatementBlock() {
_ID = getNextSBID();
@@ -83,6 +84,7 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
this();
setParseInfo(sb);
_dmlProg = sb._dmlProg;
+ _nondeterministic = sb.isNondeterministic();
}
public void setDMLProg(DMLProgram dmlProg){
@@ -1333,4 +1335,12 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
public void setUpdateInPlaceVars( ArrayList<String> vars ) {
_updateInPlaceVars = vars;
}
+
+ public void setNondeterministic(boolean flag) {
+ _nondeterministic = flag;
+ }
+
+ public boolean isNondeterministic() {
+ return _nondeterministic;
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java
index 3b18b8b..fd95f38 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java
@@ -109,7 +109,7 @@ public class BasicProgramBlock extends ProgramBlock
//statement-block-level, lineage-based reuse
LineageItem[] liInputs = null;
long t0 = 0;
- if (_sb != null && LineageCacheConfig.isMultiLevelReuse()) {
+ if (_sb != null && LineageCacheConfig.isMultiLevelReuse() &&
!_sb.isNondeterministic()) {
liInputs =
LineageItemUtils.getLineageItemInputstoSB(_sb.getInputstoSB(), ec);
List<String> outNames = _sb.getOutputNamesofSB();
if(liInputs != null && LineageCache.reuse(outNames,
_sb.getOutputsofSB(),
@@ -125,7 +125,7 @@ public class BasicProgramBlock extends ProgramBlock
executeInstructions(tmp, ec);
//statement-block-level, lineage-based caching
- if (_sb != null && liInputs != null)
+ if (_sb != null && liInputs != null &&
!_sb.isNondeterministic())
LineageCache.putValue(_sb.getOutputsofSB(),
liInputs, _sb.getName(), ec,
System.nanoTime()-t0);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
index f250930..54f071c 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
@@ -44,6 +44,7 @@ public class FunctionProgramBlock extends ProgramBlock
protected ArrayList<DataIdentifier> _outputParams;
private boolean _recompileOnce = false;
+ private boolean _nondeterministic = false;
public FunctionProgramBlock( Program prog, ArrayList<DataIdentifier>
inputParams, ArrayList<DataIdentifier> outputParams) {
super(prog);
@@ -160,6 +161,14 @@ public class FunctionProgramBlock extends ProgramBlock
public boolean isRecompileOnce() {
return _recompileOnce;
}
+
+ public void setNondeterministic(boolean flag) {
+ _nondeterministic = flag;
+ }
+
+ public boolean isNondeterministic() {
+ return _nondeterministic;
+ }
@Override
public String printBlockErrorLocation(){
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index 3c4e1a9..0f0951e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -119,7 +119,7 @@ public class FunctionCallCPInstruction extends
CPInstruction {
// check if function outputs can be reused from cache
LineageItem[] liInputs = DMLScript.LINEAGE &&
LineageCacheConfig.isMultiLevelReuse() ?
LineageItemUtils.getLineage(ec, _boundInputs) : null;
- if( reuseFunctionOutputs(liInputs, fpb, ec) )
+ if (!fpb.isNondeterministic() && reuseFunctionOutputs(liInputs,
fpb, ec))
return; //only if all the outputs are found in cache
// create bindings to formal parameters for given function call
@@ -228,9 +228,9 @@ public class FunctionCallCPInstruction extends
CPInstruction {
}
//update lineage cache with the functions outputs
- if( DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse()
) {
+ if (DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse()
&& !fpb.isNondeterministic()) {
LineageCache.putValue(fpb.getOutputParams(), liInputs,
- getCacheFunctionName(_functionName,
fpb), ec, t1-t0);
+ getCacheFunctionName(_functionName,
fpb), fn_ec, t1-t0);
//FIXME: send _boundOutputNames instead of
fpb.getOutputParams as
//those are already replaced by boundoutput names in
the lineage map.
}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 5de0ee1..ac54b70 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -43,9 +43,7 @@ import
org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.MetaDataFormat;
-import java.util.Arrays;
import java.util.HashMap;
-import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -257,11 +255,7 @@ public class LineageCache
LineageItem boundLI = ec.getLineage().get(boundVarName);
if (boundLI != null)
boundLI.resetVisitStatus();
- if (boundLI == null
- || !LineageCache.probe(li)
- //TODO remove this brittle constraint (if the
placeholder is removed
- //it might crash threads that are already
waiting for its results)
- || LineageItemUtils.containsRandDataGen(new
HashSet<>(Arrays.asList(liInputs)), boundLI)) {
+ if (boundLI == null || !LineageCache.probe(li)) {
AllOutputsCacheable = false;
}
FuncLIMap.put(li, boundLI);
diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
index e5e83be..2cc9bb3 100644
--- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
@@ -503,6 +503,7 @@ public class ProgramConverter
Recompiler.updateFunctionNames( hops,
pid );
ret.setHops( hops );
ret.updateRecompilationFlag();
+
ret.setNondeterministic(sb.isNondeterministic());
}
else {
ret = sb;
@@ -541,6 +542,7 @@ public class ProgramConverter
Hop hops = Recompiler.deepCopyHopsDag(
sb.getPredicateHops() );
ret.setPredicateHops( hops );
ret.updatePredicateRecompilationFlag();
+
ret.setNondeterministic(sb.isNondeterministic());
}
else {
ret = sb;
@@ -580,6 +582,7 @@ public class ProgramConverter
Hop hops = Recompiler.deepCopyHopsDag(
sb.getPredicateHops() );
ret.setPredicateHops( hops );
ret.updatePredicateRecompilationFlag();
+
ret.setNondeterministic(sb.isNondeterministic());
}
else {
ret = sb;
@@ -633,6 +636,7 @@ public class ProgramConverter
ret.setIncrementHops( hops );
}
ret.updatePredicateRecompilationFlags();
+
ret.setNondeterministic(sb.isNondeterministic());
}
else {
ret = sb;
diff --git
a/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
b/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
index b0ff854..46a69b4 100644
---
a/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
@@ -42,7 +42,7 @@ public class FunctionFullReuseTest extends AutomatedTestBase
{
protected static final String TEST_DIR = "functions/lineage/";
protected static final String TEST_NAME = "FunctionFullReuse";
- protected static final int TEST_VARIANTS = 7;
+ protected static final int TEST_VARIANTS = 8;
protected String TEST_CLASS_DIR = TEST_DIR +
FunctionFullReuseTest.class.getSimpleName() + "/";
@@ -82,6 +82,11 @@ public class FunctionFullReuseTest extends AutomatedTestBase
public void testParforIssue2() {
testLineageTrace(TEST_NAME+"7");
}
+
+ @Test
+ public void testCompilerAssistedNondeterminism() {
+ testLineageTrace(TEST_NAME+"8");
+ }
public void testLineageTrace(String testname) {
boolean old_simplification =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
diff --git a/src/test/scripts/functions/lineage/FunctionFullReuse8.dml
b/src/test/scripts/functions/lineage/FunctionFullReuse8.dml
new file mode 100644
index 0000000..4e33613
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FunctionFullReuse8.dml
@@ -0,0 +1,57 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Increase rows and cols for better performance gains
+
+foo = function(Matrix[Double] X) return (Matrix[Double] S)
+{
+ if (ncol(X) == 15)
+ S = X;
+ else {
+ S = bar(X);
+ X = cbind(X, matrix(1, nrow(X), 1));
+ }
+}
+bar = function(Matrix[Double] X) return (Matrix[Double] S)
+{
+ if (ncol(X) == 15)
+ S = X;
+ else {
+ X = cbind(X, rand(rows=nrow(X), cols=1, seed=42));
+ S = foo(X);
+ #S = X;
+ }
+}
+
+r = 100
+c = 10
+
+X = rand(rows=r, cols=c, seed=42);
+y = rand(rows=r, cols=1, seed=43);
+R = matrix(0, 1, 2);
+
+S = foo(X);
+R[,1] = sum(S);
+
+S = foo(X);
+R[,2] = sum(S);
+
+write(R, $1, format="text");
diff --git a/src/test/scripts/functions/lineage/LineageReuseAlg2.dml
b/src/test/scripts/functions/lineage/LineageReuseAlg2.dml
index 8639f20..dcc69cd 100644
--- a/src/test/scripts/functions/lineage/LineageReuseAlg2.dml
+++ b/src/test/scripts/functions/lineage/LineageReuseAlg2.dml
@@ -25,9 +25,7 @@ l2norm = function(Matrix[Double] X, Matrix[Double] y,
Matrix[Double] B) return (
randColSet = function(Matrix[Double] X, Integer seed, Double sample) return
(Matrix[Double] Xi) {
temp = rand(rows=ncol(X), cols=1, min = 0, max = 1, sparsity=1, seed=seed)
<= sample
- sel = diag(temp)
- sel = removeEmpty(target = sel, margin = "cols")
- Xi = X %*% sel
+ Xi = removeEmpty(target = X, margin = "cols", select = temp);
}
X = rand(rows=100, cols=100, sparsity=1.0, seed=1);