This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new e581b5a  [SYSTEMDS-2575] Fix eval function calls (incorrect pinning of 
inputs)
e581b5a is described below

commit e581b5a6248b56a70e18ffe6ba699e8142a2d679
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Jul 20 21:37:21 2020 +0200

    [SYSTEMDS-2575] Fix eval function calls (incorrect pinning of inputs)
    
    This patch fixes an issue with indirect eval function calls, where wrong
    input variable names led to missing pinning of inputs and thus overly
    eager cleanup of these variables (which caused crashes if the inputs were
    used in other operations of the eval call).
    
    The fix is simple: we avoid such inconsistent construction and
    invocation of fcall instructions by using a narrower constructor
    interface and constructing the materialized (bound input) names
    internally in the fcall, derived from the given input operands.
---
 .../runtime/controlprogram/paramserv/PSWorker.java |  4 +--
 .../controlprogram/paramserv/ParamServer.java      |  4 +--
 .../instructions/cp/EvalNaryCPInstruction.java     | 34 +++++++++-------------
 .../instructions/cp/FunctionCallCPInstruction.java | 11 +++----
 4 files changed, 22 insertions(+), 31 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
index 9f2311b..0eb9cf9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
@@ -77,12 +77,10 @@ public abstract class PSWorker implements Serializable
                CPOperand[] boundInputs = inputs.stream()
                        .map(input -> new CPOperand(input.getName(), 
input.getValueType(), input.getDataType()))
                        .toArray(CPOperand[]::new);
-               ArrayList<String> inputNames = 
inputs.stream().map(DataIdentifier::getName)
-                       .collect(Collectors.toCollection(ArrayList::new));
                ArrayList<String> outputNames = 
outputs.stream().map(DataIdentifier::getName)
                        .collect(Collectors.toCollection(ArrayList::new));
                _inst = new FunctionCallCPInstruction(ns, fname, boundInputs,
-                       inputNames, func.getInputParamNames(), outputNames, 
"update function");
+                       func.getInputParamNames(), outputNames, "update 
function");
 
                // Check the inputs of the update function
                checkInput(false, inputs, DataType.MATRIX, 
Statement.PS_FEATURES);
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 92e29b1..81cee33 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -104,12 +104,10 @@ public abstract class ParamServer
                CPOperand[] boundInputs = inputs.stream()
                        .map(input -> new CPOperand(input.getName(), 
input.getValueType(), input.getDataType()))
                        .toArray(CPOperand[]::new);
-               ArrayList<String> inputNames = 
inputs.stream().map(DataIdentifier::getName)
-                       .collect(Collectors.toCollection(ArrayList::new));
                ArrayList<String> outputNames = 
outputs.stream().map(DataIdentifier::getName)
                        .collect(Collectors.toCollection(ArrayList::new));
                _inst = new FunctionCallCPInstruction(ns, fname, boundInputs,
-                       inputNames, func.getInputParamNames(), outputNames, 
"aggregate function");
+                       func.getInputParamNames(), outputNames, "aggregate 
function");
        }
 
        public abstract void push(int workerID, ListObject value);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index f10e7bb..070a3fc 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -67,10 +67,6 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                CPOperand[] boundInputs = Arrays.copyOfRange(inputs, 1, 
inputs.length);
                List<String> boundOutputNames = new ArrayList<>();
                boundOutputNames.add(output.getName());
-               List<String> boundInputNames = new ArrayList<>();
-               for (CPOperand input : boundInputs) {
-                       boundInputNames.add(input.getName());
-               }
 
                //2. copy the created output matrix
                MatrixObject outputMO = new 
MatrixObject(ec.getMatrixObject(output.getName()));
@@ -103,32 +99,30 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                                ec.getVariables().put(varName, in);
                                boundInputs2[i] = new CPOperand(varName, in);
                        }
-                       boundInputNames = lo.isNamedList() ? lo.getNames() : 
fpb.getInputParamNames();
                        boundInputs = boundInputs2;
                }
                
                //5. call the function
                FunctionCallCPInstruction fcpi = new 
FunctionCallCPInstruction(null, funcName,
-                       boundInputs, boundInputNames, fpb.getInputParamNames(), 
boundOutputNames, "eval func");
+                       boundInputs, fpb.getInputParamNames(), 
boundOutputNames, "eval func");
                fcpi.processInstruction(ec);
 
                //6. convert the result to matrix
                Data newOutput = ec.getVariable(output);
-               if (newOutput instanceof MatrixObject) {
-                       return;
-               }
-               MatrixBlock mb = null;
-               if (newOutput instanceof ScalarObject) {
-                       //convert scalar to matrix
-                       mb = new MatrixBlock(((ScalarObject) 
newOutput).getDoubleValue());
-               } else if (newOutput instanceof FrameObject) {
-                       //convert frame to matrix
-                       mb = DataConverter.convertToMatrixBlock(((FrameObject) 
newOutput).acquireRead());
-                       ec.cleanupCacheableData((FrameObject) newOutput);
+               if (!(newOutput instanceof MatrixObject)) {
+                       MatrixBlock mb = null;
+                       if (newOutput instanceof ScalarObject) {
+                               //convert scalar to matrix
+                               mb = new MatrixBlock(((ScalarObject) 
newOutput).getDoubleValue());
+                       } else if (newOutput instanceof FrameObject) {
+                               //convert frame to matrix
+                               mb = 
DataConverter.convertToMatrixBlock(((FrameObject) newOutput).acquireRead());
+                               ec.cleanupCacheableData((FrameObject) 
newOutput);
+                       }
+                       outputMO.acquireModify(mb);
+                       outputMO.release();
+                       ec.setVariable(output.getName(), outputMO);
                }
-               outputMO.acquireModify(mb);
-               outputMO.release();
-               ec.setVariable(output.getName(), outputMO);
                
                //7. cleanup of variable expanded from list
                if( boundInputs2 != null ) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index 695b07a..8b88647 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -20,8 +20,10 @@
 package org.apache.sysds.runtime.instructions.cp;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashSet;
 import java.util.List;
+import java.util.stream.Collectors;
 
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.lops.Lop;
@@ -55,12 +57,13 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
        private final List<String> _boundOutputNames;
 
        public FunctionCallCPInstruction(String namespace, String functName, 
CPOperand[] boundInputs,
-                       List<String> boundInputNames, List<String> funArgNames, 
List<String> boundOutputNames, String istr) {
+                       List<String> funArgNames, List<String> 
boundOutputNames, String istr) {
                super(CPType.External, null, functName, istr);
                _functionName = functName;
                _namespace = namespace;
                _boundInputs = boundInputs;
-               _boundInputNames = boundInputNames;
+               _boundInputNames = Arrays.stream(boundInputs).map(i -> 
i.getName())
+                       .collect(Collectors.toCollection(ArrayList::new));
                _funArgNames = funArgNames;
                _boundOutputNames = boundOutputNames;
        }
@@ -81,19 +84,17 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
                int numInputs = Integer.valueOf(parts[3]);
                int numOutputs = Integer.valueOf(parts[4]);
                CPOperand[] boundInputs = new CPOperand[numInputs];
-               List<String> boundInputNames = new ArrayList<>();
                List<String> funArgNames = new ArrayList<>();
                List<String> boundOutputNames = new ArrayList<>();
                for (int i = 0; i < numInputs; i++) {
                        String[] nameValue = 
IOUtilFunctions.splitByFirst(parts[5 + i], "=");
                        boundInputs[i] = new CPOperand(nameValue[1]);
                        funArgNames.add(nameValue[0]);
-                       boundInputNames.add(boundInputs[i].getName());
                }
                for (int i = 0; i < numOutputs; i++)
                        boundOutputNames.add(parts[5 + numInputs + i]);
                return new FunctionCallCPInstruction ( namespace, functionName,
-                       boundInputs, boundInputNames, funArgNames, 
boundOutputNames, str );
+                       boundInputs, funArgNames, boundOutputNames, str );
        }
        
        @Override

Reply via email to