This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new e581b5a [SYSTEMDS-2575] Fix eval function calls (incorrect pinning of
inputs)
e581b5a is described below
commit e581b5a6248b56a70e18ffe6ba699e8142a2d679
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Jul 20 21:37:21 2020 +0200
[SYSTEMDS-2575] Fix eval function calls (incorrect pinning of inputs)
This patch fixes an issue with indirect eval function calls, where incorrect
input variable names led to missing pinning of inputs and thus overly eager
cleanup of these variables (which caused crashes if the inputs were used
in other operations of the eval call).
The fix is simple: we avoid such inconsistent construction and
invocation of fcall instructions by using a narrower constructor interface
that derives the bound input names internally from the bound input operands.
---
.../runtime/controlprogram/paramserv/PSWorker.java | 4 +--
.../controlprogram/paramserv/ParamServer.java | 4 +--
.../instructions/cp/EvalNaryCPInstruction.java | 34 +++++++++-------------
.../instructions/cp/FunctionCallCPInstruction.java | 11 +++----
4 files changed, 22 insertions(+), 31 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
index 9f2311b..0eb9cf9 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
@@ -77,12 +77,10 @@ public abstract class PSWorker implements Serializable
CPOperand[] boundInputs = inputs.stream()
.map(input -> new CPOperand(input.getName(),
input.getValueType(), input.getDataType()))
.toArray(CPOperand[]::new);
- ArrayList<String> inputNames =
inputs.stream().map(DataIdentifier::getName)
- .collect(Collectors.toCollection(ArrayList::new));
ArrayList<String> outputNames =
outputs.stream().map(DataIdentifier::getName)
.collect(Collectors.toCollection(ArrayList::new));
_inst = new FunctionCallCPInstruction(ns, fname, boundInputs,
- inputNames, func.getInputParamNames(), outputNames,
"update function");
+ func.getInputParamNames(), outputNames, "update
function");
// Check the inputs of the update function
checkInput(false, inputs, DataType.MATRIX,
Statement.PS_FEATURES);
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 92e29b1..81cee33 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -104,12 +104,10 @@ public abstract class ParamServer
CPOperand[] boundInputs = inputs.stream()
.map(input -> new CPOperand(input.getName(),
input.getValueType(), input.getDataType()))
.toArray(CPOperand[]::new);
- ArrayList<String> inputNames =
inputs.stream().map(DataIdentifier::getName)
- .collect(Collectors.toCollection(ArrayList::new));
ArrayList<String> outputNames =
outputs.stream().map(DataIdentifier::getName)
.collect(Collectors.toCollection(ArrayList::new));
_inst = new FunctionCallCPInstruction(ns, fname, boundInputs,
- inputNames, func.getInputParamNames(), outputNames,
"aggregate function");
+ func.getInputParamNames(), outputNames, "aggregate
function");
}
public abstract void push(int workerID, ListObject value);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index f10e7bb..070a3fc 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -67,10 +67,6 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
CPOperand[] boundInputs = Arrays.copyOfRange(inputs, 1,
inputs.length);
List<String> boundOutputNames = new ArrayList<>();
boundOutputNames.add(output.getName());
- List<String> boundInputNames = new ArrayList<>();
- for (CPOperand input : boundInputs) {
- boundInputNames.add(input.getName());
- }
//2. copy the created output matrix
MatrixObject outputMO = new
MatrixObject(ec.getMatrixObject(output.getName()));
@@ -103,32 +99,30 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
ec.getVariables().put(varName, in);
boundInputs2[i] = new CPOperand(varName, in);
}
- boundInputNames = lo.isNamedList() ? lo.getNames() :
fpb.getInputParamNames();
boundInputs = boundInputs2;
}
//5. call the function
FunctionCallCPInstruction fcpi = new
FunctionCallCPInstruction(null, funcName,
- boundInputs, boundInputNames, fpb.getInputParamNames(),
boundOutputNames, "eval func");
+ boundInputs, fpb.getInputParamNames(),
boundOutputNames, "eval func");
fcpi.processInstruction(ec);
//6. convert the result to matrix
Data newOutput = ec.getVariable(output);
- if (newOutput instanceof MatrixObject) {
- return;
- }
- MatrixBlock mb = null;
- if (newOutput instanceof ScalarObject) {
- //convert scalar to matrix
- mb = new MatrixBlock(((ScalarObject)
newOutput).getDoubleValue());
- } else if (newOutput instanceof FrameObject) {
- //convert frame to matrix
- mb = DataConverter.convertToMatrixBlock(((FrameObject)
newOutput).acquireRead());
- ec.cleanupCacheableData((FrameObject) newOutput);
+ if (!(newOutput instanceof MatrixObject)) {
+ MatrixBlock mb = null;
+ if (newOutput instanceof ScalarObject) {
+ //convert scalar to matrix
+ mb = new MatrixBlock(((ScalarObject)
newOutput).getDoubleValue());
+ } else if (newOutput instanceof FrameObject) {
+ //convert frame to matrix
+ mb =
DataConverter.convertToMatrixBlock(((FrameObject) newOutput).acquireRead());
+ ec.cleanupCacheableData((FrameObject)
newOutput);
+ }
+ outputMO.acquireModify(mb);
+ outputMO.release();
+ ec.setVariable(output.getName(), outputMO);
}
- outputMO.acquireModify(mb);
- outputMO.release();
- ec.setVariable(output.getName(), outputMO);
//7. cleanup of variable expanded from list
if( boundInputs2 != null ) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index 695b07a..8b88647 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -20,8 +20,10 @@
package org.apache.sysds.runtime.instructions.cp;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
+import java.util.stream.Collectors;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.lops.Lop;
@@ -55,12 +57,13 @@ public class FunctionCallCPInstruction extends
CPInstruction {
private final List<String> _boundOutputNames;
public FunctionCallCPInstruction(String namespace, String functName,
CPOperand[] boundInputs,
- List<String> boundInputNames, List<String> funArgNames,
List<String> boundOutputNames, String istr) {
+ List<String> funArgNames, List<String>
boundOutputNames, String istr) {
super(CPType.External, null, functName, istr);
_functionName = functName;
_namespace = namespace;
_boundInputs = boundInputs;
- _boundInputNames = boundInputNames;
+ _boundInputNames = Arrays.stream(boundInputs).map(i ->
i.getName())
+ .collect(Collectors.toCollection(ArrayList::new));
_funArgNames = funArgNames;
_boundOutputNames = boundOutputNames;
}
@@ -81,19 +84,17 @@ public class FunctionCallCPInstruction extends
CPInstruction {
int numInputs = Integer.valueOf(parts[3]);
int numOutputs = Integer.valueOf(parts[4]);
CPOperand[] boundInputs = new CPOperand[numInputs];
- List<String> boundInputNames = new ArrayList<>();
List<String> funArgNames = new ArrayList<>();
List<String> boundOutputNames = new ArrayList<>();
for (int i = 0; i < numInputs; i++) {
String[] nameValue =
IOUtilFunctions.splitByFirst(parts[5 + i], "=");
boundInputs[i] = new CPOperand(nameValue[1]);
funArgNames.add(nameValue[0]);
- boundInputNames.add(boundInputs[i].getName());
}
for (int i = 0; i < numOutputs; i++)
boundOutputNames.add(parts[5 + numInputs + i]);
return new FunctionCallCPInstruction ( namespace, functionName,
- boundInputs, boundInputNames, funArgNames,
boundOutputNames, str );
+ boundInputs, funArgNames, boundOutputNames, str );
}
@Override