This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new f1bffeb  [SYSTEMDS-344] New IPA pass for marking deterministic 
functions/SBs
f1bffeb is described below

commit f1bffeb299eec6a57d5290fd12c81ba92c9f03e2
Author: arnabp <[email protected]>
AuthorDate: Sun May 17 22:24:22 2020 +0200

    [SYSTEMDS-344] New IPA pass for marking deterministic functions/SBs
    
    This patch moves the fragile and less efficient runtime non-determinism
    check to compile time. It adds a new IPA pass that flags functions and
    statement blocks containing direct or transitive nondeterministic calls
    (e.g., rand with UNSPECIFIED_SEED), so that they are excluded from
    lineage caching and reuse.
    
    AMLS project SS 2020.
    Closes #911.
---
 docs/Tasks.txt                                     |   2 +-
 src/main/java/org/apache/sysds/hops/DataGenOp.java |   8 +
 .../sysds/hops/ipa/IPAPassFlagNonDeterminism.java  | 201 +++++++++++++++++++++
 .../sysds/hops/ipa/InterProceduralAnalysis.java    |   2 +
 .../apache/sysds/hops/rewrite/HopRewriteUtils.java |   7 +
 .../org/apache/sysds/parser/DMLTranslator.java     |   1 +
 .../sysds/parser/FunctionStatementBlock.java       |   9 +
 .../org/apache/sysds/parser/StatementBlock.java    |  10 +
 .../runtime/controlprogram/BasicProgramBlock.java  |   4 +-
 .../controlprogram/FunctionProgramBlock.java       |   9 +
 .../instructions/cp/FunctionCallCPInstruction.java |   6 +-
 .../apache/sysds/runtime/lineage/LineageCache.java |   8 +-
 .../sysds/runtime/util/ProgramConverter.java       |   4 +
 .../functions/lineage/FunctionFullReuseTest.java   |   7 +-
 .../functions/lineage/FunctionFullReuse8.dml       |  57 ++++++
 .../scripts/functions/lineage/LineageReuseAlg2.dml |   4 +-
 16 files changed, 322 insertions(+), 17 deletions(-)

diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index 8163e9f..6d5ff80 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -277,7 +277,7 @@ SYSTEMDS-340 Compiler Assisted Lineage Caching and Reuse
  * 341 Finalize unmarking of loop dependent operations
  * 342 Mark functions as last-use to enable early eviction
  * 343 Identify equal last level HOPs to ensure SB-level reuse
- * 344 Unmark functions/SBs containing non-determinism for caching
+ * 344 Unmark functions/SBs containing non-determinism for caching    OK
  * 345 Compiler assisted cache configuration
 
 SYSTEMDS-350 Data Cleaning Framework
diff --git a/src/main/java/org/apache/sysds/hops/DataGenOp.java 
b/src/main/java/org/apache/sysds/hops/DataGenOp.java
index edcb448..8fdf98d 100644
--- a/src/main/java/org/apache/sysds/hops/DataGenOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataGenOp.java
@@ -468,6 +468,14 @@ public class DataGenOp extends MultiThreadedHop
                return ret;
        }
        
+       public boolean hasUnspecifiedSeed() {
+               if (_op == OpOpDG.RAND || _op == OpOpDG.SINIT) {
+                       Hop seed = 
getInput().get(_paramIndexMap.get(DataExpression.RAND_SEED));
+                       return 
seed.getName().equals(String.valueOf(DataGenOp.UNSPECIFIED_SEED));
+               }
+               return false;
+       }
+       
        public Hop getConstantValue() {
                return 
getInput().get(_paramIndexMap.get(DataExpression.RAND_MIN));
        }
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
new file mode 100644
index 0000000..a000096
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.ipa;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+
+import org.apache.sysds.hops.FunctionOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.HopsException;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.LanguageException;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+
+public class IPAPassFlagNonDeterminism extends IPAPass {
+       @Override
+       public boolean isApplicable(FunctionCallGraph fgraph) {
+               return InterProceduralAnalysis.REMOVE_UNUSED_FUNCTIONS
+                       && !fgraph.containsSecondOrderCall();
+       }
+
+       @Override
+       public void rewriteProgram (DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes) 
+       {
+               if (!LineageCacheConfig.isMultiLevelReuse())
+                       return;
+               
+               try {
+                       // Find the individual functions and statementblocks 
with non-determinism.
+                       HashSet<String> ndfncs = new HashSet<>();
+                       for (String fkey : fgraph.getReachableFunctions()) {
+                               FunctionStatementBlock fsblock = 
prog.getFunctionStatementBlock(fkey);
+                               FunctionStatement fnstmt = 
(FunctionStatement)fsblock.getStatement(0);
+                               String fname = 
DMLProgram.splitFunctionKey(fkey)[1];
+                               if (rIsNonDeterministicFnc(fname, 
fnstmt.getBody()))
+                                       ndfncs.add(fkey);
+                       }
+
+                       // Find the callers of the nondeterministic functions.
+                       propagate2Callers(fgraph, ndfncs, new 
HashSet<String>(), null);
+                       
+                       // Mark the corresponding FunctionStatementBlocks
+                       ndfncs.forEach(fkey -> {
+                               FunctionStatementBlock fsblock = 
prog.getFunctionStatementBlock(fkey); 
+                               fsblock.setNondeterministic(true);
+                       });
+                       
+                       // Find and mark the StatementBlocks having calls to 
nondeterministic functions.
+                       rMarkNondeterministicSBs(prog.getStatementBlocks(), 
ndfncs);
+                       for (String fkey : fgraph.getReachableFunctions()) {
+                               FunctionStatementBlock fsblock = 
prog.getFunctionStatementBlock(fkey);
+                               FunctionStatement fnstmt = 
(FunctionStatement)fsblock.getStatement(0);
+                               rMarkNondeterministicSBs(fnstmt.getBody(), 
ndfncs);
+                       }
+               }
+               catch( LanguageException ex ) {
+                       throw new HopsException(ex);
+               }
+       }
+
+       private boolean rIsNonDeterministicFnc (String fname, 
ArrayList<StatementBlock> sbs) 
+       {
+               boolean isND = false;
+               for (StatementBlock sb : sbs)
+               {
+                       if (isND)
+                               break;
+
+                       if (sb instanceof ForStatementBlock) {
+                               ForStatement fstmt = 
(ForStatement)sb.getStatement(0);
+                               isND = rIsNonDeterministicFnc(fname, 
fstmt.getBody());
+                       }
+                       else if (sb instanceof WhileStatementBlock) {
+                               WhileStatement wstmt = 
(WhileStatement)sb.getStatement(0);
+                               isND = rIsNonDeterministicFnc(fname, 
wstmt.getBody());
+                       }
+                       else if (sb instanceof IfStatementBlock) {
+                               IfStatement ifstmt = 
(IfStatement)sb.getStatement(0);
+                               isND = rIsNonDeterministicFnc(fname, 
ifstmt.getIfBody());
+                               if (ifstmt.getElseBody() != null)
+                                       isND = rIsNonDeterministicFnc(fname, 
ifstmt.getElseBody());
+                       }
+                       else {
+                               if (sb.getHops() != null) {
+                                       Hop.resetVisitStatus(sb.getHops());
+                                       for (Hop hop : sb.getHops()) 
+                                               isND |= 
rIsNonDeterministicHop(hop);
+                                       Hop.resetVisitStatus(sb.getHops());
+                                       // Mark the statementblock
+                                       sb.setNondeterministic(isND);
+                               }
+                       }
+               }
+               return isND;
+       }
+       
+       private void rMarkNondeterministicSBs (ArrayList<StatementBlock> sbs, 
HashSet<String> ndfncs)
+       {
+               for (StatementBlock sb : sbs)
+               {
+                       if (sb instanceof ForStatementBlock) {
+                               ForStatement fstmt = 
(ForStatement)sb.getStatement(0);
+                               rMarkNondeterministicSBs(fstmt.getBody(), 
ndfncs);
+                       }
+                       else if (sb instanceof WhileStatementBlock) {
+                               WhileStatement wstmt = 
(WhileStatement)sb.getStatement(0);
+                               rMarkNondeterministicSBs(wstmt.getBody(), 
ndfncs);
+                       }
+                       else if (sb instanceof IfStatementBlock) {
+                               IfStatement ifstmt = 
(IfStatement)sb.getStatement(0);
+                               rMarkNondeterministicSBs(ifstmt.getIfBody(), 
ndfncs);
+                               if (ifstmt.getElseBody() != null)
+                                       
rMarkNondeterministicSBs(ifstmt.getElseBody(), ndfncs);
+                       }
+                       else {
+                               if (sb.getHops() != null) {
+                                       boolean callsND = false;
+                                       Hop.resetVisitStatus(sb.getHops());
+                                       for (Hop hop : sb.getHops())
+                                               callsND |= 
rMarkNondeterministicHop(hop, ndfncs);
+                                       Hop.resetVisitStatus(sb.getHops());
+                                       if (callsND)
+                                               sb.setNondeterministic(callsND);
+                               }
+                       }
+               }
+       }
+       
+       private boolean rMarkNondeterministicHop(Hop hop, HashSet<String> 
ndfncs) {
+               if (hop.isVisited())
+                       return false;
+
+               boolean callsND = hop instanceof FunctionOp && 
ndfncs.contains(hop.getName());
+                       
+               if (!callsND)
+                       for (Hop hi : hop.getInput())
+                               callsND |= rMarkNondeterministicHop(hi, ndfncs);
+               hop.setVisited();
+               return callsND;
+       }
+       
+       private boolean rIsNonDeterministicHop(Hop hop) {
+               if (hop.isVisited())
+                       return false;
+
+               boolean isND = 
HopRewriteUtils.isDataGenOpWithNonDeterminism(hop);
+               
+               if (!isND)
+                       for (Hop hi : hop.getInput())
+                               isND |= rIsNonDeterministicHop(hi);
+               hop.setVisited();
+               return isND;
+       }
+       
+       private void propagate2Callers (FunctionCallGraph fgraph, 
HashSet<String> ndfncs, HashSet<String> fstack, String fkey) {
+               Collection<String> cfkeys = fgraph.getCalledFunctions(fkey);
+               if (cfkeys != null) {
+                       for (String cfkey : cfkeys) {
+                               if (fstack.contains(cfkey) && 
fgraph.isRecursiveFunction(cfkey)) {
+                                       if (ndfncs.contains(cfkey) && fkey 
!=null)
+                                               ndfncs.add(fkey);
+                               }
+                               else {
+                                       fstack.add(cfkey);
+                                       propagate2Callers(fgraph, ndfncs, 
fstack, cfkey);
+                                       fstack.remove(cfkey);
+                                       if (ndfncs.contains(cfkey) && fkey 
!=null)
+                                               ndfncs.add(fkey);
+                               }
+                       }
+               }
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java 
b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index 23d9a76..0710cb3 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -94,6 +94,7 @@ public class InterProceduralAnalysis
        protected static final int     INLINING_MAX_NUM_OPS           = 10;   
//inline single-statement functions w/ #ops <= threshold, other than dataops 
and literals
        protected static final boolean ELIMINATE_DEAD_CODE            = true; 
//remove dead code (e.g., assigments) not used later on
        protected static final boolean FORWARD_SIMPLE_FUN_CALLS       = true; 
//replace a call to a simple forwarding function with the function itself
+       protected static final boolean FLAG_NONDETERMINISM            = true; 
//flag functions which directly or transitively contain non-deterministic calls
        
        static {
                // for internal debugging only
@@ -136,6 +137,7 @@ public class InterProceduralAnalysis
                _passes.add(new IPAPassPropagateReplaceLiterals());
                _passes.add(new IPAPassInlineFunctions());
                _passes.add(new IPAPassEliminateDeadCode());
+               _passes.add(new IPAPassFlagNonDeterminism());
                //note: apply rewrites last because statement block rewrites
                //might merge relevant statement blocks in special cases, which 
                //would require an update of the function call graph
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 91118bf..9d86b4d 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -524,6 +524,13 @@ public class HopRewriteUtils
                        && ((DataGenOp)hop).hasConstantValue(value);
        }
        
+       public static boolean isDataGenOpWithNonDeterminism(Hop hop) {
+               if (!isDataGenOp(hop, OpOpDG.RAND, OpOpDG.SAMPLE))
+                       return false;
+               return isDataGenOp(hop, OpOpDG.SAMPLE) || (isDataGenOp(hop, 
OpOpDG.RAND) 
+                       && !((DataGenOp)hop).hasConstantValue() && 
((DataGenOp)hop).hasUnspecifiedSeed());
+       }
+       
        public static Hop getDataGenOpConstantValue(Hop hop) {
                return ((DataGenOp) hop).getConstantValue();
        }
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index a15f923..de1e3ce 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -444,6 +444,7 @@ public class DMLTranslator
                                FunctionProgramBlock rtpb = 
(FunctionProgramBlock)createRuntimeProgramBlock(rtprog, fsb, config);
                                rtprog.addFunctionProgramBlock(namespace, 
fname, rtpb);
                                rtpb.setRecompileOnce( fsb.isRecompileOnce() );
+                               
rtpb.setNondeterministic(fsb.isNondeterministic());
                        }
                }
                
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
index a8f3c75..b056b7e 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
@@ -31,6 +31,7 @@ public class FunctionStatementBlock extends StatementBlock
 {
                
        private boolean _recompileOnce = false;
+       private boolean _nondeterministic = false;
        
        /**
         *  TODO: DRB:  This needs to be changed to reflect:
@@ -241,4 +242,12 @@ public class FunctionStatementBlock extends StatementBlock
        public boolean isRecompileOnce() {
                return _recompileOnce;
        }
+       
+       public void setNondeterministic(boolean flag) {
+               _nondeterministic = flag;
+       }
+       
+       public boolean isNondeterministic() {
+               return _nondeterministic;
+       }
 }
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java 
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 17a0847..5a1f967 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -62,6 +62,7 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
        private ArrayList<String> _updateInPlaceVars = null;
        private boolean _requiresRecompile = false;
        private boolean _splitDag = false;
+       private boolean _nondeterministic = false;
 
        public StatementBlock() {
                _ID = getNextSBID();
@@ -83,6 +84,7 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                this();
                setParseInfo(sb);
                _dmlProg = sb._dmlProg;
+               _nondeterministic = sb.isNondeterministic();
        }
 
        public void setDMLProg(DMLProgram dmlProg){
@@ -1333,4 +1335,12 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
        public void setUpdateInPlaceVars( ArrayList<String> vars ) {
                _updateInPlaceVars = vars;
        }
+       
+       public void setNondeterministic(boolean flag) {
+               _nondeterministic = flag;
+       }
+       
+       public boolean isNondeterministic() {
+               return _nondeterministic;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java
index 3b18b8b..fd95f38 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java
@@ -109,7 +109,7 @@ public class BasicProgramBlock extends ProgramBlock
                //statement-block-level, lineage-based reuse
                LineageItem[] liInputs = null;
                long t0 = 0;
-               if (_sb != null && LineageCacheConfig.isMultiLevelReuse()) {
+               if (_sb != null && LineageCacheConfig.isMultiLevelReuse() && 
!_sb.isNondeterministic()) {
                        liInputs = 
LineageItemUtils.getLineageItemInputstoSB(_sb.getInputstoSB(), ec);
                        List<String> outNames = _sb.getOutputNamesofSB();
                        if(liInputs != null && LineageCache.reuse(outNames, 
_sb.getOutputsofSB(), 
@@ -125,7 +125,7 @@ public class BasicProgramBlock extends ProgramBlock
                executeInstructions(tmp, ec);
                
                //statement-block-level, lineage-based caching
-               if (_sb != null && liInputs != null)
+               if (_sb != null && liInputs != null && 
!_sb.isNondeterministic())
                        LineageCache.putValue(_sb.getOutputsofSB(),
                                liInputs, _sb.getName(), ec, 
System.nanoTime()-t0);
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
index f250930..54f071c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
@@ -44,6 +44,7 @@ public class FunctionProgramBlock extends ProgramBlock
        protected ArrayList<DataIdentifier> _outputParams;
        
        private boolean _recompileOnce = false;
+       private boolean _nondeterministic = false;
        
        public FunctionProgramBlock( Program prog, ArrayList<DataIdentifier> 
inputParams, ArrayList<DataIdentifier> outputParams) {
                super(prog);
@@ -160,6 +161,14 @@ public class FunctionProgramBlock extends ProgramBlock
        public boolean isRecompileOnce() {
                return _recompileOnce;
        }
+
+       public void setNondeterministic(boolean flag) {
+               _nondeterministic = flag;
+       }
+       
+       public boolean isNondeterministic() {
+               return _nondeterministic;
+       }
        
        @Override
        public String printBlockErrorLocation(){
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index 3c4e1a9..0f0951e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -119,7 +119,7 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
                // check if function outputs can be reused from cache
                LineageItem[] liInputs = DMLScript.LINEAGE && 
LineageCacheConfig.isMultiLevelReuse() ?
                        LineageItemUtils.getLineage(ec, _boundInputs) : null;
-               if( reuseFunctionOutputs(liInputs, fpb, ec) )
+               if (!fpb.isNondeterministic() && reuseFunctionOutputs(liInputs, 
fpb, ec))
                        return; //only if all the outputs are found in cache
                
                // create bindings to formal parameters for given function call
@@ -228,9 +228,9 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
                }
 
                //update lineage cache with the functions outputs
-               if( DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() 
) {
+               if (DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() 
&& !fpb.isNondeterministic()) {
                        LineageCache.putValue(fpb.getOutputParams(), liInputs, 
-                                       getCacheFunctionName(_functionName, 
fpb), ec, t1-t0);
+                                       getCacheFunctionName(_functionName, 
fpb), fn_ec, t1-t0);
                        //FIXME: send _boundOutputNames instead of 
fpb.getOutputParams as 
                        //those are already replaced by boundoutput names in 
the lineage map.
                }
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 5de0ee1..ac54b70 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -43,9 +43,7 @@ import 
org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.meta.MetaDataFormat;
 
-import java.util.Arrays;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 
@@ -257,11 +255,7 @@ public class LineageCache
                        LineageItem boundLI = ec.getLineage().get(boundVarName);
                        if (boundLI != null)
                                boundLI.resetVisitStatus();
-                       if (boundLI == null 
-                               || !LineageCache.probe(li)
-                               //TODO remove this brittle constraint (if the 
placeholder is removed
-                               //it might crash threads that are already 
waiting for its results)
-                               || LineageItemUtils.containsRandDataGen(new 
HashSet<>(Arrays.asList(liInputs)), boundLI)) {
+                       if (boundLI == null || !LineageCache.probe(li)) {
                                AllOutputsCacheable = false;
                        }
                        FuncLIMap.put(li, boundLI);
diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java 
b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
index e5e83be..2cc9bb3 100644
--- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
@@ -503,6 +503,7 @@ public class ProgramConverter
                                        Recompiler.updateFunctionNames( hops, 
pid );
                                ret.setHops( hops );
                                ret.updateRecompilationFlag();
+                               
ret.setNondeterministic(sb.isNondeterministic());
                        }
                        else {
                                ret = sb;
@@ -541,6 +542,7 @@ public class ProgramConverter
                                Hop hops = Recompiler.deepCopyHopsDag( 
sb.getPredicateHops() );
                                ret.setPredicateHops( hops );
                                ret.updatePredicateRecompilationFlag();
+                               
ret.setNondeterministic(sb.isNondeterministic());
                        }
                        else {
                                ret = sb;
@@ -580,6 +582,7 @@ public class ProgramConverter
                                Hop hops = Recompiler.deepCopyHopsDag( 
sb.getPredicateHops() );
                                ret.setPredicateHops( hops );
                                ret.updatePredicateRecompilationFlag();
+                               
ret.setNondeterministic(sb.isNondeterministic());
                        }
                        else {
                                ret = sb;
@@ -633,6 +636,7 @@ public class ProgramConverter
                                        ret.setIncrementHops( hops );
                                }
                                ret.updatePredicateRecompilationFlags();
+                               
ret.setNondeterministic(sb.isNondeterministic());
                        }
                        else {
                                ret = sb;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
 
b/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
index b0ff854..46a69b4 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
@@ -42,7 +42,7 @@ public class FunctionFullReuseTest extends AutomatedTestBase
 {
        protected static final String TEST_DIR = "functions/lineage/";
        protected static final String TEST_NAME = "FunctionFullReuse";
-       protected static final int TEST_VARIANTS = 7;
+       protected static final int TEST_VARIANTS = 8;
        
        protected String TEST_CLASS_DIR = TEST_DIR + 
FunctionFullReuseTest.class.getSimpleName() + "/";
        
@@ -82,6 +82,11 @@ public class FunctionFullReuseTest extends AutomatedTestBase
        public void testParforIssue2() {
                testLineageTrace(TEST_NAME+"7");
        }
+
+       @Test
+       public void testCompilerAssistedNondeterminism() {
+               testLineageTrace(TEST_NAME+"8");
+       }
        
        public void testLineageTrace(String testname) {
                boolean old_simplification = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
diff --git a/src/test/scripts/functions/lineage/FunctionFullReuse8.dml 
b/src/test/scripts/functions/lineage/FunctionFullReuse8.dml
new file mode 100644
index 0000000..4e33613
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FunctionFullReuse8.dml
@@ -0,0 +1,57 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Increase rows and cols for better performance gains
+
+foo = function(Matrix[Double] X) return (Matrix[Double] S) 
+{
+  if (ncol(X) == 15)
+    S = X;
+  else {
+    S = bar(X);
+    X = cbind(X, matrix(1, nrow(X), 1));
+  }
+}
+bar = function(Matrix[Double] X) return (Matrix[Double] S) 
+{
+  if (ncol(X) == 15)
+    S = X;
+  else {
+    X = cbind(X, rand(rows=nrow(X), cols=1, seed=42));
+    S = foo(X);
+    #S = X;
+  }
+}
+
+r = 100
+c = 10
+
+X = rand(rows=r, cols=c, seed=42);
+y = rand(rows=r, cols=1, seed=43);
+R = matrix(0, 1, 2);
+
+S = foo(X);
+R[,1] = sum(S);
+
+S = foo(X);
+R[,2] = sum(S);
+
+write(R, $1, format="text");
diff --git a/src/test/scripts/functions/lineage/LineageReuseAlg2.dml 
b/src/test/scripts/functions/lineage/LineageReuseAlg2.dml
index 8639f20..dcc69cd 100644
--- a/src/test/scripts/functions/lineage/LineageReuseAlg2.dml
+++ b/src/test/scripts/functions/lineage/LineageReuseAlg2.dml
@@ -25,9 +25,7 @@ l2norm = function(Matrix[Double] X, Matrix[Double] y, 
Matrix[Double] B) return (
 
 randColSet = function(Matrix[Double] X, Integer seed, Double sample) return 
(Matrix[Double] Xi) {
   temp = rand(rows=ncol(X), cols=1, min = 0, max = 1, sparsity=1, seed=seed) 
<= sample
-  sel = diag(temp)
-  sel = removeEmpty(target = sel, margin = "cols")
-  Xi = X %*% sel
+  Xi = removeEmpty(target = X, margin = "cols", select = temp);
 }
 
 X = rand(rows=100, cols=100, sparsity=1.0, seed=1);

Reply via email to