This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new b9114b9 [SYSTEMDS-291] Extended eval function calls (named/unnamed list args)
b9114b9 is described below
commit b9114b90a180d79e4ab6de7647e3a35cd7ce3f78
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Apr 14 21:37:28 2020 +0200
[SYSTEMDS-291] Extended eval function calls (named/unnamed list args)
This patch extends the existing eval function calls with support for
named and unnamed list inputs. During the function call, these list
arguments are expanded on demand (and, for named lists, reordered to
match the function signature) unless the function's signature consists
of a single list argument.
Furthermore, this patch enables the previously disabled gridSearch tests
and makes a couple of smaller improvements to list data types (e.g., list
append and proper handling of function return values with unknown value types).
---
scripts/builtin/gridSearch.dml | 30 ++++-----
.../sysds/parser/BuiltinFunctionExpression.java | 2 +-
.../sysds/parser/FunctionStatementBlock.java | 8 +--
.../controlprogram/context/ExecutionContext.java | 2 +-
.../sysds/runtime/instructions/cp/CPOperand.java | 9 ++-
.../instructions/cp/EvalNaryCPInstruction.java | 74 +++++++++++++++++++---
.../instructions/cp/FunctionCallCPInstruction.java | 2 +-
.../cp/ListAppendRemoveCPInstruction.java | 12 +++-
.../sysds/runtime/instructions/cp/ListObject.java | 10 +++
.../functions/builtin/BuiltinGridSearchTest.java | 9 +--
.../scripts/functions/builtin/GridSearchLM.dml | 10 +--
11 files changed, 121 insertions(+), 47 deletions(-)
diff --git a/scripts/builtin/gridSearch.dml b/scripts/builtin/gridSearch.dml
index 227b863..ca37745 100644
--- a/scripts/builtin/gridSearch.dml
+++ b/scripts/builtin/gridSearch.dml
@@ -35,8 +35,8 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y,
String train, String
vect = as.matrix(paramValues[j,1]);
paramVals[j,1:nrow(vect)] = t(vect);
}
- cumLens = rev(cumprod(rev(paramLens))/rev(paramLens));
- numConfigs = prod(paramLens);
+ cumLens = rev(cumprod(rev(paramLens))/rev(paramLens));
+ numConfigs = prod(paramLens);
# Step 1) materialize hyper-parameter combinations
# (simplify debugging and compared to compute negligible)
@@ -53,28 +53,22 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y,
String train, String
# TODO integrate cross validation
Rbeta = matrix(0, nrow(HP), ncol(X));
Rloss = matrix(0, nrow(HP), 1);
- arguments = list(X=X, y=y);
+ # TODO pass arguments for function call from outside
+ arguments = list(X=X, y=y, icpt=0, reg=-1, tol=-1, maxi=-1, verbose=FALSE);
parfor( i in 1:nrow(HP) ) {
- # a) prepare training arguments
+ # a) replace training arguments
largs = arguments;
- for( j in 1:numParams ) {
- key = as.scalar(params[j]);
- value = as.scalar(HP[i,j]);
- largs = append(largs, list(key=value));
- }
-
- # b) core training/scoring
- lbeta = eval(train, largs);
- lloss = eval(predict, list(X, y, lbeta));
-
- # c) write models and loss back to output
- Rbeta[i,] = lbeta;
- Rloss[i,] = lloss;
+ for( j in 1:numParams )
+ largs[as.scalar(params[j])] = as.scalar(HP[i,j]);
+ # b) core training/scoring and write-back
+ # TODO investigate rmvar handling with explicit binding (lbeta)
+ Rbeta[i,] = t(eval(train, largs));
+ Rloss[i,] = eval(predict, list(X, y, t(Rbeta[i,])));
}
# Step 3) select best parameter combination
ix = as.scalar(rowIndexMin(t(Rloss)));
- B = Rbeta[ix,]; # optimal model
+ B = t(Rbeta[ix,]); # optimal model
opt = as.frame(HP[ix,]); # optimal hyper-parameters
}
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index cfda652..2c5d61a 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -786,7 +786,7 @@ public class BuiltinFunctionExpression extends
DataIdentifier
//list append
if(getFirstExpr().getOutput().getDataType().isList() )
for(int i=1; i<getAllExpr().length; i++)
- checkDataTypeParam(getExpr(i),
DataType.SCALAR, DataType.MATRIX, DataType.FRAME);
+ checkDataTypeParam(getExpr(i),
DataType.SCALAR, DataType.MATRIX, DataType.FRAME, DataType.LIST);
//matrix append (rbind/cbind)
else
for(int i=0; i<getAllExpr().length; i++)
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
index 7d32816..a8f3c75 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
@@ -86,7 +86,8 @@ public class FunctionStatementBlock extends StatementBlock
raiseValidateError("for function " +
fstmt.getName() + ", return variable " + curr.getName() + " data type of " +
curr.getDataType() + " does not match data type in function signature of " +
returnValue.getDataType(), conditional);
}
- if (curr.getValueType() != ValueType.UNKNOWN &&
!curr.getValueType().equals(returnValue.getValueType())){
+ if (curr.getValueType() != ValueType.UNKNOWN &&
returnValue.getValueType() != ValueType.UNKNOWN
+ &&
!curr.getValueType().equals(returnValue.getValueType())){
// attempt to convert value type: handle
conversion from scalar DOUBLE or INT
if (curr.getDataType() == DataType.SCALAR &&
returnValue.getDataType() == DataType.SCALAR){
@@ -121,9 +122,8 @@ public class FunctionStatementBlock extends StatementBlock
+ " does not
match value type in function signature of "
+
returnValue.getValueType() + " and cannot safely cast " + curr.getValueType()
+ " as " +
returnValue.getValueType());
-
- }
- }
+ }
+ }
else {
throw new
LanguageException(curr.printErrorLocation() + "for function " + fstmt.getName()
+ ", return variable " + curr.getName() + " value type of " +
curr.getValueType() + " does not match value type in function signature of " +
returnValue.getValueType() + " and cannot safely cast " + curr.getValueType() +
" as " + returnValue.getValueType());
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index d2d9887..2022224 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -223,7 +223,7 @@ public class ExecutionContext {
if( dat == null )
throw new
DMLRuntimeException(getNonExistingVarError(varname));
if( !(dat instanceof MatrixObject) )
- throw new DMLRuntimeException("Variable '"+varname+"'
is not a matrix.");
+ throw new DMLRuntimeException("Variable '"+varname+"'
is not a matrix: "+dat.getClass().getName());
return (MatrixObject) dat;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
index 2f6e200..be46930 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
@@ -44,7 +44,7 @@ public class CPOperand
split(str);
}
- public CPOperand(String name, ValueType vt, DataType dt ) {
+ public CPOperand(String name, ValueType vt, DataType dt) {
this(name, vt, dt, false);
}
@@ -69,6 +69,13 @@ public class CPOperand
_isLiteral = variable._isLiteral;
_literal = variable._literal;
}
+
+ public CPOperand(String name, Data dat) {
+ _name = name;
+ _valueType = dat.getValueType();
+ _dataType = dat.getDataType();
+ _isLiteral = false;
+ }
public String getName() {
return _name;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index 62f6c67..ce337b5 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -21,6 +21,8 @@ package org.apache.sysds.runtime.instructions.cp;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
@@ -28,6 +30,7 @@ import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.ProgramRewriter;
+import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.parser.FunctionStatementBlock;
@@ -62,9 +65,9 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
// bound the inputs to avoiding being deleted after the
function call
CPOperand[] boundInputs = Arrays.copyOfRange(inputs, 1,
inputs.length);
- ArrayList<String> boundOutputNames = new ArrayList<>();
+ List<String> boundOutputNames = new ArrayList<>();
boundOutputNames.add(output.getName());
- ArrayList<String> boundInputNames = new ArrayList<>();
+ List<String> boundInputNames = new ArrayList<>();
for (CPOperand input : boundInputs) {
boundInputNames.add(input.getName());
}
@@ -73,19 +76,43 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
MatrixObject outputMO = new
MatrixObject(ec.getMatrixObject(output.getName()));
//3. lazy loading of dml-bodied builtin functions (incl. rename
- // of function name to dml-bodied builtin scheme
(data-type-specific)
- if( !ec.getProgram().containsFunctionProgramBlock(null,
funcName) ) {
- compileFunctionProgramBlock(funcName,
boundInputs[0].getDataType(), ec.getProgram());
- funcName = Builtins.getInternalFName(funcName,
boundInputs[0].getDataType());
+ // of function name to dml-bodied builtin scheme
(data-type-specific)
+ DataType dt1 = boundInputs[0].getDataType().isList() ?
+ DataType.MATRIX : boundInputs[0].getDataType();
+ String funcName2 = Builtins.getInternalFName(funcName, dt1);
+ if( !ec.getProgram().containsFunctionProgramBlock(null,
funcName)) {
+ if(
!ec.getProgram().containsFunctionProgramBlock(null,funcName2) )
+ compileFunctionProgramBlock(funcName, dt1,
ec.getProgram());
+ funcName = funcName2;
}
- //4. call the function
+ //4. expand list arguments if needed
+ CPOperand[] boundInputs2 = null;
FunctionProgramBlock fpb =
ec.getProgram().getFunctionProgramBlock(null, funcName);
+ if( boundInputs.length == 1 &&
boundInputs[0].getDataType().isList()
+ && fpb.getInputParams().size() > 1 &&
!fpb.getInputParams().get(0).getDataType().isList())
+ {
+ ListObject lo = ec.getListObject(boundInputs[0]);
+ checkValidArguments(lo.getData(), lo.getNames(),
fpb.getInputParamNames());
+ if( lo.isNamedList() )
+ lo = reorderNamedListForFunctionCall(lo,
fpb.getInputParamNames());
+ boundInputs2 = new CPOperand[lo.getLength()];
+ for( int i=0; i<lo.getLength(); i++ ) {
+ Data in = lo.getData(i);
+ String varName =
Dag.getNextUniqueVarname(in.getDataType());
+ ec.getVariables().put(varName, in);
+ boundInputs2[i] = new CPOperand(varName, in);
+ }
+ boundInputNames = lo.isNamedList() ? lo.getNames() :
fpb.getInputParamNames();
+ boundInputs = boundInputs2;
+ }
+
+ //5. call the function
FunctionCallCPInstruction fcpi = new
FunctionCallCPInstruction(null, funcName,
boundInputs, boundInputNames, fpb.getInputParamNames(),
boundOutputNames, "eval func");
fcpi.processInstruction(ec);
- //5. convert the result to matrix
+ //6. convert the result to matrix
Data newOutput = ec.getVariable(output);
if (newOutput instanceof MatrixObject) {
return;
@@ -102,6 +129,12 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
outputMO.acquireModify(mb);
outputMO.release();
ec.setVariable(output.getName(), outputMO);
+
+ //7. cleanup of variable expanded from list
+ if( boundInputs2 != null ) {
+ for( CPOperand op : boundInputs2 )
+
VariableCPInstruction.processRemoveVariableInstruction(ec, op.getName());
+ }
}
private static void compileFunctionProgramBlock(String name, DataType
dt, Program prog) {
@@ -151,4 +184,29 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
prog.addFunctionProgramBlock(null, fsb.getKey(), fpb);
}
}
+
+ private void checkValidArguments(List<Data> loData, List<String>
loNames, List<String> fArgNames) {
+ //check number of parameters
+ int listSize = (loNames != null) ? loNames.size() :
loData.size();
+ if( listSize != fArgNames.size() )
+ throw new DMLRuntimeException("Failed to expand list
for function call "
+ + "(mismatching number of arguments:
"+listSize+" vs. "+fArgNames.size()+").");
+
+ //check individual parameters
+ if( loNames != null ) {
+ HashSet<String> probe = new HashSet<>();
+ for( String var : fArgNames )
+ probe.add(var);
+ for( String var : loNames )
+ if( !probe.contains(var) )
+ throw new DMLRuntimeException("List
argument named '"+var+"' not in function signature.");
+ }
+ }
+
+ private ListObject reorderNamedListForFunctionCall(ListObject in,
List<String> fArgNames) {
+ List<Data> sortedData = new ArrayList<>();
+ for( String name : fArgNames )
+ sortedData.add(in.getData(name));
+ return new ListObject(sortedData, new ArrayList<>(fArgNames));
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index e605a55..9c1eac0 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -104,7 +104,7 @@ public class FunctionCallCPInstruction extends
CPInstruction {
@Override
public void processInstruction(ExecutionContext ec) {
if( LOG.isTraceEnabled() ){
- LOG.trace("Executing instruction : " + this.toString());
+ LOG.trace("Executing instruction : " + toString());
}
// get the function program block (stored in the Program object)
FunctionProgramBlock fpb =
ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListAppendRemoveCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListAppendRemoveCPInstruction.java
index b672ff3..1721c4a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListAppendRemoveCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListAppendRemoveCPInstruction.java
@@ -46,8 +46,16 @@ public final class ListAppendRemoveCPInstruction extends
AppendCPInstruction {
if( getOpcode().equals("append") ) {
//copy on write and append unnamed argument
Data dat2 = ec.getVariable(input2);
- LineageItem li = DMLScript.LINEAGE ?
ec.getLineage().get(input2):null;
- ListObject tmp = lo.copy().add(dat2, li);
+ LineageItem li = DMLScript.LINEAGE ?
ec.getLineage().get(input2) : null;
+ ListObject tmp = null;
+ if( dat2 instanceof ListObject &&
((ListObject)dat2).getLength() == 1 ) {
+ //add unfolded elements for lists of size 1
(e.g., named)
+ ListObject lo2 = (ListObject) dat2;
+ tmp = lo.copy().add(lo2.getName(0),
lo2.getData(0), li);
+ }
+ else {
+ tmp = lo.copy().add(dat2, li);
+ }
//set output variable
ec.setVariable(output.getName(), tmp);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index 8cfb682..8798799 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -104,6 +104,14 @@ public class ListObject extends Data {
return _data;
}
+ public Data getData(int ix) {
+ return _data.get(ix);
+ }
+
+ public Data getData(String name) {
+ return slice(name);
+ }
+
public List<LineageItem> getLineageItems() {
return _lineage;
}
@@ -219,6 +227,8 @@ public class ListObject extends Data {
if( _names != null && name == null )
throw new DMLRuntimeException("Cannot add to a named
list");
//otherwise append and ignore name
+ if( _names != null )
+ _names.add(name);
_data.add(dat);
if (_lineage == null && li!= null)
_lineage = new ArrayList<>();
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
index 556a7d7..232db2f 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
@@ -45,17 +45,14 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
@Test
public void testGridSearchCP() {
- //TODO additional list features needed
- //runGridSearch(ExecType.CP);
+ runGridSearch(ExecType.CP);
}
@Test
public void testGridSearchSpark() {
- //TODO additional list features needed
- //runGridSearch(ExecType.SPARK);
+ runGridSearch(ExecType.SPARK);
}
- @SuppressWarnings("unused")
private void runGridSearch(ExecType et)
{
ExecMode modeOld = setExecMode(et);
@@ -64,7 +61,7 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-args", input("X"),
input("y"), output("R")};
+ programArgs = new String[] {"-explain","-args",
input("X"), input("y"), output("R")};
double[][] X = getRandomMatrix(rows, cols, 0, 1, 0.8,
-1);
double[][] y = getRandomMatrix(rows, 1, 0, 1, 0.8, -1);
writeInputMatrixWithMTD("X", X, true);
diff --git a/src/test/scripts/functions/builtin/GridSearchLM.dml
b/src/test/scripts/functions/builtin/GridSearchLM.dml
index 9b33713..41a6fa1 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM.dml
+++ b/src/test/scripts/functions/builtin/GridSearchLM.dml
@@ -19,8 +19,8 @@
#
#-------------------------------------------------------------
-l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return
(Double loss) {
- loss = sum((y - X%*%B)^2);
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return
(Matrix[Double] loss) {
+ loss = as.matrix(sum((y - X%*%B)^2));
}
X = read($1);
@@ -33,12 +33,12 @@ Xtest = X[(N+1):nrow(X),];
ytest = y[(N+1):nrow(X),];
params = list("reg", "tol", "maxi");
-paramRanges = list(10^seq(0,-4), 10^seq(-5,-9), 10^seq(1,3));
-[B1, opt] = gridSearch(Xtrain, ytrain, "lm", "lmPredict", params, paramRanges,
TRUE);
+paramRanges = list(10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
+[B1, opt] = gridSearch(Xtrain, ytrain, "lm", "l2norm", params, paramRanges,
TRUE);
B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
l1 = l2norm(Xtest, ytest, B1);
l2 = l2norm(Xtest, ytest, B2);
-R = l1 <= l2;
+R = as.scalar(l1 <= l2);
write(R, $3)