This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 44e88d843a [SYSTEMDS-3509] Fix IPA pass function forwarding (named args ordering)
44e88d843a is described below

commit 44e88d843a819b107c13478b329c578ff809d6d6
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Mar 25 16:32:05 2023 +0100

    [SYSTEMDS-3509] Fix IPA pass function forwarding (named args ordering)
    
    This patch fixes the inter-procedural analysis (IPA) rewrite pass
    'function forwarding', which collapses a chain of function calls into
    a single function call. Previously, if function arguments with the
    same names were passed in different orders, they could be misassigned,
    but only when the rewrite applied. We now wire the function arguments
    in the correct order according to their argument names.
---
 scripts/builtin/decisionTreePredict.dml                        |  1 -
 .../org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java | 10 ++++++----
 .../org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java   |  1 +
 .../apache/sysds/runtime/instructions/cp/CPInstruction.java    |  1 -
 src/test/scripts/functions/builtin/decisionTreePredict.dml     |  3 +--
 5 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/scripts/builtin/decisionTreePredict.dml 
b/scripts/builtin/decisionTreePredict.dml
index b312910a48..c4e75b4fe1 100644
--- a/scripts/builtin/decisionTreePredict.dml
+++ b/scripts/builtin/decisionTreePredict.dml
@@ -46,7 +46,6 @@ m_decisionTreePredict = function(Matrix[Double] X, 
Matrix[Double] y = matrix(0,0
     Matrix[Double] ctypes, Matrix[Double] M, String strategy="TT", Boolean 
verbose = FALSE)
   return (Matrix[Double] yhat)
 {
-    print(toString(M))
   if( strategy == "TT" )
     yhat = predict_TT(M, X);
   else if( strategy == "GEMM" )
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
index 8b57742f26..e35f2f50d6 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
@@ -55,7 +55,7 @@ public class IPAPassForwardFunctionCalls extends IPAPass
                        
                        //step 1: basic application filter: simple forwarding 
call
                        if( fstmt.getBody().size() != 1 || 
!singleFunctionOp(fstmt.getBody().get(0).getHops())
-                               || 
!hasOnlySimplyArguments((FunctionOp)fstmt.getBody().get(0).getHops().get(0)))
+                               || 
!hasOnlySimpleArguments((FunctionOp)fstmt.getBody().get(0).getHops().get(0)))
                                continue;
                        if( LOG.isDebugEnabled() )
                                LOG.debug("IPA: Forward-function-call candidate 
L1: '"+fkey+"'");
@@ -96,7 +96,7 @@ public class IPAPassForwardFunctionCalls extends IPAPass
                return hops.get(0) instanceof FunctionOp;
        }
        
-       private static boolean hasOnlySimplyArguments(FunctionOp fop) {
+       private static boolean hasOnlySimpleArguments(FunctionOp fop) {
                return fop.getInput().stream().allMatch(h -> h instanceof 
LiteralOp 
                        || HopRewriteUtils.isData(h, OpOpData.TRANSIENTREAD));
        }
@@ -127,15 +127,17 @@ public class IPAPassForwardFunctionCalls extends IPAPass
                for( int i=0; i<call2.getInput().size(); i++ )
                        probe.put(call2.getInputVariableNames()[i], 
call2.getInput().get(i));
                
-               //construct new inputs for call1
+               //construct new named inputs for call1 (in right order)
+               ArrayList<String> varNames = new ArrayList<>();
                ArrayList<Hop> inputs = new ArrayList<>();
                for( int i=0; i<call1.getInput().size(); i++ )
                        if( probe.containsKey(call1.getInputVariableNames()[i]) 
) {
+                               varNames.add(call1.getInputVariableNames()[i]);
                                inputs.add( 
(probe.get(call1.getInputVariableNames()[i]) instanceof LiteralOp) ? 
                                        
probe.get(call1.getInputVariableNames()[i]) : call1.getInput().get(i));
                        }
                HopRewriteUtils.removeAllChildReferences(call1);
                call1.addAllInputs(inputs);
-               call1.setInputVariableNames(call2.getInputVariableNames());
+               call1.setInputVariableNames(varNames.toArray(new String[0]));
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java 
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
index d4f976a795..1cf5761423 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
@@ -57,6 +57,7 @@ public class RewriteAddChkpointLop extends LopRewriteRule
                Map<Long, Integer> operatorJobCount = new HashMap<>();
                markPersistableSparkOps(sparkRoots, operatorJobCount);
                // TODO: A rewrite pass to remove less effective chkpoints
+               @SuppressWarnings("unused")
                List<Lop> nodesWithChkpt = addChkpointLop(lops, 
operatorJobCount);
                //New node is added inplace in the Lop DAG
                return List.of(sb);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index aa17fa2cab..174e4f2d27 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -31,7 +31,6 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.CPInstructionParser;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.fed.FEDInstructionUtils;
-import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
 
diff --git a/src/test/scripts/functions/builtin/decisionTreePredict.dml 
b/src/test/scripts/functions/builtin/decisionTreePredict.dml
index e87b01c581..733d363b78 100644
--- a/src/test/scripts/functions/builtin/decisionTreePredict.dml
+++ b/src/test/scripts/functions/builtin/decisionTreePredict.dml
@@ -21,6 +21,5 @@
 
 M = read($1);
 X = read($2);
-# FIXME reordering of M and X yields wrong passing
-Y = decisionTreePredict(M=M, X=X, ctypes=matrix(2,1,ncol(X)+1), strategy=$3);
+Y = decisionTreePredict(X=X, M=M, ctypes=matrix(2,1,ncol(X)+1), strategy=$3);
 write(Y, $4);

Reply via email to