This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new b9114b9  [SYSTEMDS-291] Extended eval function calls (named/unnamed 
list args)
b9114b9 is described below

commit b9114b90a180d79e4ab6de7647e3a35cd7ce3f78
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Apr 14 21:37:28 2020 +0200

    [SYSTEMDS-291] Extended eval function calls (named/unnamed list args)
    
    This patch improves the existing eval function calls by adding support for
    named and unnamed list inputs. During the function call these list
    arguments are expanded and ordered on demand (if the function
    doesn't have a signature of a single list argument).
    
    Furthermore, this patch integrates the tests for gridSearch, and makes a
    couple of smaller improvements to list data types (e.g., list append,
    proper function return handling with unknown value types).
---
 scripts/builtin/gridSearch.dml                     | 30 ++++-----
 .../sysds/parser/BuiltinFunctionExpression.java    |  2 +-
 .../sysds/parser/FunctionStatementBlock.java       |  8 +--
 .../controlprogram/context/ExecutionContext.java   |  2 +-
 .../sysds/runtime/instructions/cp/CPOperand.java   |  9 ++-
 .../instructions/cp/EvalNaryCPInstruction.java     | 74 +++++++++++++++++++---
 .../instructions/cp/FunctionCallCPInstruction.java |  2 +-
 .../cp/ListAppendRemoveCPInstruction.java          | 12 +++-
 .../sysds/runtime/instructions/cp/ListObject.java  | 10 +++
 .../functions/builtin/BuiltinGridSearchTest.java   |  9 +--
 .../scripts/functions/builtin/GridSearchLM.dml     | 10 +--
 11 files changed, 121 insertions(+), 47 deletions(-)

diff --git a/scripts/builtin/gridSearch.dml b/scripts/builtin/gridSearch.dml
index 227b863..ca37745 100644
--- a/scripts/builtin/gridSearch.dml
+++ b/scripts/builtin/gridSearch.dml
@@ -35,8 +35,8 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y, 
String train, String
     vect = as.matrix(paramValues[j,1]);
     paramVals[j,1:nrow(vect)] = t(vect);
   }
-       cumLens = rev(cumprod(rev(paramLens))/rev(paramLens));
-       numConfigs = prod(paramLens);
+  cumLens = rev(cumprod(rev(paramLens))/rev(paramLens));
+  numConfigs = prod(paramLens);
   
   # Step 1) materialize hyper-parameter combinations 
   # (simplify debugging and compared to compute negligible)
@@ -53,28 +53,22 @@ m_gridSearch = function(Matrix[Double] X, Matrix[Double] y, 
String train, String
   # TODO integrate cross validation
   Rbeta = matrix(0, nrow(HP), ncol(X));
   Rloss = matrix(0, nrow(HP), 1);
-  arguments = list(X=X, y=y);
+  # TODO pass arguments for function call from outside
+  arguments = list(X=X, y=y, icpt=0, reg=-1, tol=-1, maxi=-1, verbose=FALSE);
 
   parfor( i in 1:nrow(HP) ) {
-    # a) prepare training arguments
+    # a) replace training arguments
     largs = arguments;
-    for( j in 1:numParams ) {
-      key = as.scalar(params[j]);
-      value = as.scalar(HP[i,j]);
-      largs = append(largs, list(key=value));
-    }
-
-    # b) core training/scoring
-    lbeta = eval(train, largs);
-    lloss = eval(predict, list(X, y, lbeta));
-
-    # c) write models and loss back to output
-    Rbeta[i,] = lbeta;
-    Rloss[i,] = lloss;
+    for( j in 1:numParams )
+      largs[as.scalar(params[j])] = as.scalar(HP[i,j]);
+    # b) core training/scoring and write-back
+    # TODO investigate rmvar handling with explicit binding (lbeta)
+    Rbeta[i,] = t(eval(train, largs));
+    Rloss[i,] = eval(predict, list(X, y, t(Rbeta[i,])));
   }
 
   # Step 3) select best parameter combination
   ix = as.scalar(rowIndexMin(t(Rloss)));
-  B = Rbeta[ix,];          # optimal model
+  B = t(Rbeta[ix,]);       # optimal model
   opt = as.frame(HP[ix,]); # optimal hyper-parameters
 }
diff --git 
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index cfda652..2c5d61a 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -786,7 +786,7 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                                //list append
                                
if(getFirstExpr().getOutput().getDataType().isList() )
                                        for(int i=1; i<getAllExpr().length; i++)
-                                               checkDataTypeParam(getExpr(i), 
DataType.SCALAR, DataType.MATRIX, DataType.FRAME);
+                                               checkDataTypeParam(getExpr(i), 
DataType.SCALAR, DataType.MATRIX, DataType.FRAME, DataType.LIST);
                                //matrix append (rbind/cbind)
                                else
                                        for(int i=0; i<getAllExpr().length; i++)
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
index 7d32816..a8f3c75 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
@@ -86,7 +86,8 @@ public class FunctionStatementBlock extends StatementBlock
                                raiseValidateError("for function " + 
fstmt.getName() + ", return variable " + curr.getName() + " data type of " + 
curr.getDataType() + " does not match data type in function signature of " + 
returnValue.getDataType(), conditional);
                        }
                        
-                       if (curr.getValueType() != ValueType.UNKNOWN && 
!curr.getValueType().equals(returnValue.getValueType())){
+                       if (curr.getValueType() != ValueType.UNKNOWN && 
returnValue.getValueType() != ValueType.UNKNOWN
+                               && 
!curr.getValueType().equals(returnValue.getValueType())){
                                
                                // attempt to convert value type: handle 
conversion from scalar DOUBLE or INT
                                if (curr.getDataType() == DataType.SCALAR && 
returnValue.getDataType() == DataType.SCALAR){ 
@@ -121,9 +122,8 @@ public class FunctionStatementBlock extends StatementBlock
                                                                + " does not 
match value type in function signature of " 
                                                                + 
returnValue.getValueType() + " and cannot safely cast " + curr.getValueType() 
                                                                + " as " + 
returnValue.getValueType());
-                                               
-                                       } 
-                               }       
+                                       }
+                               }
                                else {
                                        throw new 
LanguageException(curr.printErrorLocation() + "for function " + fstmt.getName() 
+ ", return variable " + curr.getName() + " value type of " + 
curr.getValueType() + " does not match value type in function signature of " + 
returnValue.getValueType() + " and cannot safely cast " + curr.getValueType() + 
" as " + returnValue.getValueType());
                                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index d2d9887..2022224 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -223,7 +223,7 @@ public class ExecutionContext {
                if( dat == null )
                        throw new 
DMLRuntimeException(getNonExistingVarError(varname));
                if( !(dat instanceof MatrixObject) )
-                       throw new DMLRuntimeException("Variable '"+varname+"' 
is not a matrix.");
+                       throw new DMLRuntimeException("Variable '"+varname+"' 
is not a matrix: "+dat.getClass().getName());
                
                return (MatrixObject) dat;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
index 2f6e200..be46930 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
@@ -44,7 +44,7 @@ public class CPOperand
                split(str);
        }
        
-       public CPOperand(String name, ValueType vt, DataType dt ) {
+       public CPOperand(String name, ValueType vt, DataType dt) {
                this(name, vt, dt, false);
        }
 
@@ -69,6 +69,13 @@ public class CPOperand
                _isLiteral = variable._isLiteral;
                _literal = variable._literal;
        }
+       
+       public CPOperand(String name, Data dat) {
+               _name = name;
+               _valueType = dat.getValueType();
+               _dataType = dat.getDataType();
+               _isLiteral = false;
+       }
 
        public String getName() {
                return _name;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index 62f6c67..ce337b5 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -21,6 +21,8 @@ package org.apache.sysds.runtime.instructions.cp;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 
@@ -28,6 +30,7 @@ import org.apache.sysds.common.Builtins;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.rewrite.ProgramRewriter;
+import org.apache.sysds.lops.compile.Dag;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DMLTranslator;
 import org.apache.sysds.parser.FunctionStatementBlock;
@@ -62,9 +65,9 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                
                // bound the inputs to avoiding being deleted after the 
function call
                CPOperand[] boundInputs = Arrays.copyOfRange(inputs, 1, 
inputs.length);
-               ArrayList<String> boundOutputNames = new ArrayList<>();
+               List<String> boundOutputNames = new ArrayList<>();
                boundOutputNames.add(output.getName());
-               ArrayList<String> boundInputNames = new ArrayList<>();
+               List<String> boundInputNames = new ArrayList<>();
                for (CPOperand input : boundInputs) {
                        boundInputNames.add(input.getName());
                }
@@ -73,19 +76,43 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                MatrixObject outputMO = new 
MatrixObject(ec.getMatrixObject(output.getName()));
 
                //3. lazy loading of dml-bodied builtin functions (incl. rename 
-               // of function name to dml-bodied builtin scheme 
(data-type-specific) 
-               if( !ec.getProgram().containsFunctionProgramBlock(null, 
funcName) ) {
-                       compileFunctionProgramBlock(funcName, 
boundInputs[0].getDataType(), ec.getProgram());
-                       funcName = Builtins.getInternalFName(funcName, 
boundInputs[0].getDataType());
+               // of function name to dml-bodied builtin scheme 
(data-type-specific)
+               DataType dt1 = boundInputs[0].getDataType().isList() ? 
+                       DataType.MATRIX : boundInputs[0].getDataType();
+               String funcName2 = Builtins.getInternalFName(funcName, dt1);
+               if( !ec.getProgram().containsFunctionProgramBlock(null, 
funcName)) {
+                       if( 
!ec.getProgram().containsFunctionProgramBlock(null,funcName2) )
+                               compileFunctionProgramBlock(funcName, dt1, 
ec.getProgram());
+                       funcName = funcName2;
                }
                
-               //4. call the function
+               //4. expand list arguments if needed
+               CPOperand[] boundInputs2 = null;
                FunctionProgramBlock fpb = 
ec.getProgram().getFunctionProgramBlock(null, funcName);
+               if( boundInputs.length == 1 && 
boundInputs[0].getDataType().isList()
+                       && fpb.getInputParams().size() > 1 && 
!fpb.getInputParams().get(0).getDataType().isList()) 
+               {
+                       ListObject lo = ec.getListObject(boundInputs[0]);
+                       checkValidArguments(lo.getData(), lo.getNames(), 
fpb.getInputParamNames());
+                       if( lo.isNamedList() )
+                               lo = reorderNamedListForFunctionCall(lo, 
fpb.getInputParamNames());
+                       boundInputs2 = new CPOperand[lo.getLength()];
+                       for( int i=0; i<lo.getLength(); i++ ) {
+                               Data in = lo.getData(i);
+                               String varName = 
Dag.getNextUniqueVarname(in.getDataType());
+                               ec.getVariables().put(varName, in);
+                               boundInputs2[i] = new CPOperand(varName, in);
+                       }
+                       boundInputNames = lo.isNamedList() ? lo.getNames() : 
fpb.getInputParamNames();
+                       boundInputs = boundInputs2;
+               }
+               
+               //5. call the function
                FunctionCallCPInstruction fcpi = new 
FunctionCallCPInstruction(null, funcName,
                        boundInputs, boundInputNames, fpb.getInputParamNames(), 
boundOutputNames, "eval func");
                fcpi.processInstruction(ec);
 
-               //5. convert the result to matrix
+               //6. convert the result to matrix
                Data newOutput = ec.getVariable(output);
                if (newOutput instanceof MatrixObject) {
                        return;
@@ -102,6 +129,12 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                outputMO.acquireModify(mb);
                outputMO.release();
                ec.setVariable(output.getName(), outputMO);
+               
+               //7. cleanup of variable expanded from list
+               if( boundInputs2 != null ) {
+                       for( CPOperand op : boundInputs2 )
+                               
VariableCPInstruction.processRemoveVariableInstruction(ec, op.getName());
+               }
        }
        
        private static void compileFunctionProgramBlock(String name, DataType 
dt, Program prog) {
@@ -151,4 +184,29 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                        prog.addFunctionProgramBlock(null, fsb.getKey(), fpb);
                }
        }
+       
+       private void checkValidArguments(List<Data> loData, List<String> 
loNames, List<String> fArgNames) {
+               //check number of parameters
+               int listSize = (loNames != null) ? loNames.size() : 
loData.size();
+               if( listSize != fArgNames.size() )
+                       throw new DMLRuntimeException("Failed to expand list 
for function call "
+                               + "(mismatching number of arguments: 
"+listSize+" vs. "+fArgNames.size()+").");
+               
+               //check individual parameters
+               if( loNames != null ) {
+                       HashSet<String> probe = new HashSet<>();
+                       for( String var : fArgNames )
+                               probe.add(var);
+                       for( String var : loNames )
+                               if( !probe.contains(var) )
+                                       throw new DMLRuntimeException("List 
argument named '"+var+"' not in function signature.");
+               }
+       }
+       
+       private ListObject reorderNamedListForFunctionCall(ListObject in, 
List<String> fArgNames) {
+               List<Data> sortedData = new ArrayList<>();
+               for( String name : fArgNames )
+                       sortedData.add(in.getData(name));
+               return new ListObject(sortedData, new ArrayList<>(fArgNames));
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index e605a55..9c1eac0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -104,7 +104,7 @@ public class FunctionCallCPInstruction extends 
CPInstruction {
        @Override
        public void processInstruction(ExecutionContext ec) {
                if( LOG.isTraceEnabled() ){
-                       LOG.trace("Executing instruction : " + this.toString());
+                       LOG.trace("Executing instruction : " + toString());
                }
                // get the function program block (stored in the Program object)
                FunctionProgramBlock fpb = 
ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListAppendRemoveCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListAppendRemoveCPInstruction.java
index b672ff3..1721c4a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListAppendRemoveCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListAppendRemoveCPInstruction.java
@@ -46,8 +46,16 @@ public final class ListAppendRemoveCPInstruction extends 
AppendCPInstruction {
                if( getOpcode().equals("append") ) {
                        //copy on write and append unnamed argument
                        Data dat2 = ec.getVariable(input2);
-                       LineageItem li = DMLScript.LINEAGE ? 
ec.getLineage().get(input2):null;
-                       ListObject tmp = lo.copy().add(dat2, li);
+                       LineageItem li = DMLScript.LINEAGE ? 
ec.getLineage().get(input2) : null;
+                       ListObject tmp = null;
+                       if( dat2 instanceof ListObject && 
((ListObject)dat2).getLength() == 1 ) {
+                               //add unfolded elements for lists of size 1 
(e.g., named)
+                               ListObject lo2 = (ListObject) dat2;
+                               tmp = lo.copy().add(lo2.getName(0), 
lo2.getData(0), li);
+                       }
+                       else {
+                               tmp = lo.copy().add(dat2, li);
+                       }
                        //set output variable
                        ec.setVariable(output.getName(), tmp);
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index 8cfb682..8798799 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -104,6 +104,14 @@ public class ListObject extends Data {
                return _data;
        }
        
+       public Data getData(int ix) {
+               return _data.get(ix);
+       }
+       
+       public Data getData(String name) {
+               return slice(name);
+       }
+       
        public List<LineageItem> getLineageItems() {
                return _lineage;
        }
@@ -219,6 +227,8 @@ public class ListObject extends Data {
                if( _names != null && name == null )
                        throw new DMLRuntimeException("Cannot add to a named 
list");
                //otherwise append and ignore name
+               if( _names != null )
+                       _names.add(name);
                _data.add(dat);
                if (_lineage == null && li!= null) 
                        _lineage = new ArrayList<>();
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
index 556a7d7..232db2f 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
@@ -45,17 +45,14 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
        
        @Test
        public void testGridSearchCP() {
-               //TODO additional list features needed
-               //runGridSearch(ExecType.CP);
+               runGridSearch(ExecType.CP);
        }
        
        @Test
        public void testGridSearchSpark() {
-               //TODO additional list features needed
-               //runGridSearch(ExecType.SPARK);
+               runGridSearch(ExecType.SPARK);
        }
        
-       @SuppressWarnings("unused")
        private void runGridSearch(ExecType et)
        {
                ExecMode modeOld = setExecMode(et);
@@ -64,7 +61,7 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
                        String HOME = SCRIPT_DIR + TEST_DIR;
        
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[] {"-args", input("X"), 
input("y"), output("R")};
+                       programArgs = new String[] {"-explain","-args", 
input("X"), input("y"), output("R")};
                        double[][] X = getRandomMatrix(rows, cols, 0, 1, 0.8, 
-1);
                        double[][] y = getRandomMatrix(rows, 1, 0, 1, 0.8, -1);
                        writeInputMatrixWithMTD("X", X, true);
diff --git a/src/test/scripts/functions/builtin/GridSearchLM.dml 
b/src/test/scripts/functions/builtin/GridSearchLM.dml
index 9b33713..41a6fa1 100644
--- a/src/test/scripts/functions/builtin/GridSearchLM.dml
+++ b/src/test/scripts/functions/builtin/GridSearchLM.dml
@@ -19,8 +19,8 @@
 #
 #-------------------------------------------------------------
 
-l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return 
(Double loss) {
-  loss = sum((y - X%*%B)^2);
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) return 
(Matrix[Double] loss) {
+  loss = as.matrix(sum((y - X%*%B)^2));
 }
 
 X = read($1);
@@ -33,12 +33,12 @@ Xtest = X[(N+1):nrow(X),];
 ytest = y[(N+1):nrow(X),];
 
 params = list("reg", "tol", "maxi");
-paramRanges = list(10^seq(0,-4), 10^seq(-5,-9), 10^seq(1,3));
-[B1, opt] = gridSearch(Xtrain, ytrain, "lm", "lmPredict", params, paramRanges, 
TRUE);
+paramRanges = list(10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
+[B1, opt] = gridSearch(Xtrain, ytrain, "lm", "l2norm", params, paramRanges, 
TRUE);
 B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
 
 l1 = l2norm(Xtest, ytest, B1);
 l2 = l2norm(Xtest, ytest, B2);
-R = l1 <= l2;
+R = as.scalar(l1 <= l2);
 
 write(R, $3)

Reply via email to