This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 5f1cdf3  [SYSTEMDS-291] Extended eval lazy function compilation (nested builtins)
5f1cdf3 is described below

commit 5f1cdf367b0616359461f1fd198898d59f0598a4
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Apr 13 18:39:47 2020 +0200

    [SYSTEMDS-291] Extended eval lazy function compilation (nested builtins)
    
    This patch extends the lazy function compilation of dml-bodied builtin
    functions called through eval. We now support nested dml-bodied function
    calls (e.g., eval -> lm -> lmDS/lmCG), which is crucial for generic
    primitives of hyper-parameter optimization and the enumeration of
    cleaning pipelines.
---
 .../sysds/hops/rewrite/RewriteConstantFolding.java |  2 +-
 .../java/org/apache/sysds/parser/DMLProgram.java   |  4 ++
 .../org/apache/sysds/parser/DMLTranslator.java     |  2 +-
 .../sysds/parser/FunctionCallIdentifier.java       |  8 +--
 .../sysds/parser/FunctionStatementBlock.java       | 14 ++---
 .../org/apache/sysds/parser/IfStatementBlock.java  |  4 +-
 .../org/apache/sysds/parser/StatementBlock.java    |  2 +-
 .../sysds/parser/dml/DmlSyntacticValidator.java    |  8 ++-
 .../sysds/runtime/controlprogram/Program.java      | 18 +++++-
 .../controlprogram/paramserv/ParamservUtils.java   |  2 +-
 .../instructions/cp/EvalNaryCPInstruction.java     | 70 ++++++++++++++--------
 .../sysds/runtime/lineage/LineageRewriteReuse.java |  2 +-
 .../test/functions/mlcontext/MLContextTest.java    | 10 ++++
 .../mlcontext/eval4-nested_builtin-test.dml        | 30 ++++++++++
 14 files changed, 129 insertions(+), 47 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
index ec098e6..6e04082 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
@@ -184,7 +184,7 @@ public class RewriteConstantFolding extends HopRewriteRule
        
        private BasicProgramBlock getProgramBlock() {
                if( _tmpPB == null )
-                       _tmpPB = new BasicProgramBlock( new Program() );
+                       _tmpPB = new BasicProgramBlock(new Program());
                return _tmpPB;
        }
        
diff --git a/src/main/java/org/apache/sysds/parser/DMLProgram.java 
b/src/main/java/org/apache/sysds/parser/DMLProgram.java
index e86464c..4e5e229 100644
--- a/src/main/java/org/apache/sysds/parser/DMLProgram.java
+++ b/src/main/java/org/apache/sysds/parser/DMLProgram.java
@@ -131,6 +131,10 @@ public class DMLProgram
                return ret;
        }
 
+       public boolean containsFunctionStatementBlock(String name) {
+               return _functionBlocks.containsKey(name);
+       }
+       
        public void addFunctionStatementBlock(String fname, 
FunctionStatementBlock fsb) {
                _functionBlocks.put(fname, fsb);
        }
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 9e41f9b..e61c928 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -412,7 +412,7 @@ public class DMLTranslator
                throws LanguageException, DMLRuntimeException, LopsException, 
HopsException 
        {       
                // constructor resets the set of registered functions
-               Program rtprog = new Program();
+               Program rtprog = new Program(prog);
                
                // for all namespaces, translate function statement blocks into 
function program blocks
                for (String namespace : prog.getNamespaces().keySet()){
diff --git a/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java 
b/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java
index fc5e1d8..497d591 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java
@@ -115,8 +115,8 @@ public class FunctionCallIdentifier extends DataIdentifier
                }
                if (hasNamed && hasUnnamed){
                        raiseValidateError(" In DML, functions can only have 
named parameters " +
-                                       "(e.g., name1=value1, name2=value2) or 
unnamed parameters (e.g, value1, value2). " + 
-                                       _name + " has both parameter types.", 
conditional);
+                               "(e.g., name1=value1, name2=value2) or unnamed 
parameters (e.g, value1, value2). " + 
+                               _name + " has both parameter types.", 
conditional);
                }
                
                // Step 4: validate expressions for each passed parameter
@@ -176,8 +176,8 @@ public class FunctionCallIdentifier extends DataIdentifier
                if (_namespace != null && _namespace.length() > 0 && 
!_namespace.equals(DMLProgram.DEFAULT_NAMESPACE))
                        sb.append(_namespace + "::");
                sb.append(_name);
-               sb.append(" ( ");               
-                               
+               sb.append(" ( ");
+               
                for (int i = 0; i < _paramExprs.size(); i++){
                        sb.append(_paramExprs.get(i).toString());
                        if (i<_paramExprs.size() - 1) 
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
index aad710b..7d32816 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
@@ -51,13 +51,13 @@ public class FunctionStatementBlock extends StatementBlock
                        
                // validate all function input parameters
                ArrayList<DataIdentifier> inputValues = fstmt.getInputParams();
-        for( DataIdentifier inputValue : inputValues ) {
-            //check all input matrices have value type double
-            if( inputValue.getDataType()==DataType.MATRIX && 
inputValue.getValueType()!=ValueType.FP64 ) {
-                raiseValidateError("for function " + fstmt.getName() + ", 
input variable " + inputValue.getName() 
-                                 + " has an unsupported value type of " + 
inputValue.getValueType() + ".", false);
-            }
-        }
+               for( DataIdentifier inputValue : inputValues ) {
+                       //check all input matrices have value type double
+                       if( inputValue.getDataType()==DataType.MATRIX && 
inputValue.getValueType()!=ValueType.FP64 ) {
+                               raiseValidateError("for function " + 
fstmt.getName() + ", input variable " + inputValue.getName() 
+                                       + " has an unsupported value type of " 
+ inputValue.getValueType() + ".", false);
+                       }
+               }
                
                // handle DML-bodied functions
                // perform validate for function body
diff --git a/src/main/java/org/apache/sysds/parser/IfStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
index 7322bce..4762a14 100644
--- a/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
@@ -55,9 +55,9 @@ public class IfStatementBlock extends StatementBlock
                HashMap<String,ConstIdentifier> constVarsIfCopy = new 
HashMap<>(constVars);
                HashMap<String,ConstIdentifier> constVarsElseCopy = new 
HashMap<> (constVars);
                
-               VariableSet idsIfCopy   = new VariableSet(ids);
+               VariableSet idsIfCopy   = new VariableSet(ids);
                VariableSet idsElseCopy = new VariableSet(ids);
-               VariableSet     idsOrigCopy = new VariableSet(ids);
+               VariableSet idsOrigCopy = new VariableSet(ids);
 
                // handle if stmt body
                _dmlProg = dmlProg;
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java 
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 5991315..f275a84 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -230,7 +230,7 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                return true;
        }
 
-    public boolean isRewritableFunctionCall(Statement stmt, DMLProgram 
dmlProg) {
+       public boolean isRewritableFunctionCall(Statement stmt, DMLProgram 
dmlProg) {
 
                // for regular stmt, check if this is a function call stmt block
                if (stmt instanceof AssignmentStatement || stmt instanceof 
MultiAssignmentStatement){
diff --git 
a/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java 
b/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
index 5e2cae5..5841e3b 100644
--- a/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
+++ b/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
@@ -610,18 +610,20 @@ public class DmlSyntacticValidator implements DmlListener 
{
                }
        }
        
-       public static FunctionStatementBlock loadAndParseBuiltinFunction(String 
name, String namespace, DataType dt) {
+       public static Map<String,FunctionStatementBlock> 
loadAndParseBuiltinFunction(String name, String namespace) {
                if( !Builtins.contains(name, true, false) ) {
                        throw new DMLRuntimeException("Function "
                                + DMLProgram.constructFunctionKey(namespace, 
name)+" is not a builtin function.");
                }
                //load and add builtin DML-bodied functions (via tmp validator 
instance)
+               //including nested builtin function calls unless already loaded
                DmlSyntacticValidator tmp = new DmlSyntacticValidator(
                        new CustomErrorListener(), new HashMap<>(), namespace, 
new HashSet<>());
                String filePath = Builtins.getFilePath(name);
                DMLProgram prog = tmp.parseAndAddImportedFunctions(namespace, 
filePath, null);
-               String name2 = Builtins.getInternalFName(name, dt);
-               return prog.getNamedFunctionStatementBlocks().get(name2);
+               
+               //construct output map of all functions
+               return prog.getNamedFunctionStatementBlocks();
        }
 
 
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java
index e868a38..03a516b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java
@@ -33,7 +33,8 @@ public class Program
 {
        public static final String KEY_DELIM = "::";
        
-       public ArrayList<ProgramBlock> _programBlocks;
+       private DMLProgram _prog;
+       private ArrayList<ProgramBlock> _programBlocks;
 
        private HashMap<String, HashMap<String,FunctionProgramBlock>> 
_namespaceFunctions;
        
@@ -42,7 +43,20 @@ public class Program
                _namespaceFunctions.put(DMLProgram.DEFAULT_NAMESPACE, new 
HashMap<>());
                _programBlocks = new ArrayList<>();
        }
+       
+       public Program(DMLProgram prog) {
+               this();
+               setDMLProg(prog);
+       }
 
+       public void setDMLProg(DMLProgram prog) {
+               _prog = prog;
+       }
+       
+       public DMLProgram getDMLProg() {
+               return _prog;
+       }
+       
        public synchronized void addFunctionProgramBlock(String namespace, 
String fname, FunctionProgramBlock fpb) {
                if( fpb == null )
                        throw new DMLRuntimeException("Invalid null function 
program block.");
@@ -124,7 +138,7 @@ public class Program
        public Program clone(boolean deep) {
                if( deep )
                        throw new NotImplementedException();
-               Program ret = new Program();
+               Program ret = new Program(_prog);
                //shallow copy of all program blocks
                ret._programBlocks.addAll(_programBlocks);
                //shallow copy of all functions, except external 
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index c8b8a3a..84fc2c9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -252,7 +252,7 @@ public class ParamservUtils {
        }
        
        private static Program copyProgramFunctions(Program prog) {
-               Program newProg = new Program();
+               Program newProg = new Program(prog.getDMLProg());
                prog.getFunctionProgramBlocks()
                        .forEach((func, pb) -> putFunction(newProg, 
copyFunction(func, pb)));
                return newProg;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index 92b7227..62f6c67 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -21,7 +21,10 @@ package org.apache.sysds.runtime.instructions.cp;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Map;
+import java.util.Map.Entry;
 
+import org.apache.sysds.common.Builtins;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.rewrite.ProgramRewriter;
@@ -69,11 +72,11 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                //2. copy the created output matrix
                MatrixObject outputMO = new 
MatrixObject(ec.getMatrixObject(output.getName()));
 
-               //3. lazy loading of dml-bodied builtin functions
+               //3. lazy loading of dml-bodied builtin functions (incl. rename 
+               // of function name to dml-bodied builtin scheme 
(data-type-specific) 
                if( !ec.getProgram().containsFunctionProgramBlock(null, 
funcName) ) {
-                       FunctionProgramBlock fpb = compileFunctionProgramBlock(
-                               funcName, boundInputs[0].getDataType(), 
ec.getProgram());
-                       ec.getProgram().addFunctionProgramBlock(null, funcName, 
fpb);
+                       compileFunctionProgramBlock(funcName, 
boundInputs[0].getDataType(), ec.getProgram());
+                       funcName = Builtins.getInternalFName(funcName, 
boundInputs[0].getDataType());
                }
                
                //4. call the function
@@ -101,32 +104,51 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                ec.setVariable(output.getName(), outputMO);
        }
        
-       private static FunctionProgramBlock compileFunctionProgramBlock(String 
name, DataType dt, Program prog) {
+       private static void compileFunctionProgramBlock(String name, DataType 
dt, Program prog) {
                //load builtin file and parse function statement block
-               FunctionStatementBlock fsb = DmlSyntacticValidator
-                       .loadAndParseBuiltinFunction(name, 
DMLProgram.DEFAULT_NAMESPACE, dt);
+               Map<String,FunctionStatementBlock> fsbs = DmlSyntacticValidator
+                       .loadAndParseBuiltinFunction(name, 
DMLProgram.DEFAULT_NAMESPACE);
+               if( fsbs.isEmpty() )
+                       throw new DMLRuntimeException("Failed to compile 
function '"+name+"'.");
                
-               // validate function (could be avoided for performance because 
known builtin functions)
-               DMLProgram dmlp = fsb.getDMLProg();
+               // prepare common data structures, including a consolidated dml 
program
+               // to facilitate function validation which tries to inline 
lazily loaded
+               // and existing functions.
+               DMLProgram dmlp = (prog.getDMLProg() != null) ? 
prog.getDMLProg() :
+                       fsbs.get(Builtins.getInternalFName(name, 
dt)).getDMLProg();
+               for( Entry<String,FunctionStatementBlock> fsb : fsbs.entrySet() 
) {
+                       if( !dmlp.containsFunctionStatementBlock(fsb.getKey()) )
+                               dmlp.addFunctionStatementBlock(fsb.getKey(), 
fsb.getValue());
+                       fsb.getValue().setDMLProg(dmlp);
+               }
                DMLTranslator dmlt = new DMLTranslator(dmlp);
-               dmlt.liveVariableAnalysisFunction(dmlp, fsb);
-               dmlt.validateFunction(dmlp, fsb);
-               
-               // compile hop dags, rewrite hop dags and compile lop dags
-               dmlt.constructHops(fsb);
                ProgramRewriter rewriter = new ProgramRewriter(true, false);
-               rewriter.rewriteHopDAGsFunction(fsb, false); //rewrite and merge
-               DMLTranslator.resetHopsDAGVisitStatus(fsb);
-               rewriter.rewriteHopDAGsFunction(fsb, true); //rewrite and split
-               DMLTranslator.resetHopsDAGVisitStatus(fsb);
                ProgramRewriter rewriter2 = new ProgramRewriter(false, true);
-               rewriter2.rewriteHopDAGsFunction(fsb, true);
-               DMLTranslator.resetHopsDAGVisitStatus(fsb);
-               DMLTranslator.refreshMemEstimates(fsb);
-               dmlt.constructLops(fsb);
+               
+               // validate functions, in two passes for cross references
+               for( FunctionStatementBlock fsb : fsbs.values() ) {
+                       dmlt.liveVariableAnalysisFunction(dmlp, fsb);
+                       dmlt.validateFunction(dmlp, fsb);
+               }
+               
+               // compile hop dags, rewrite hop dags and compile lop dags
+               for( FunctionStatementBlock fsb : fsbs.values() ) {
+                       dmlt.constructHops(fsb);
+                       rewriter.rewriteHopDAGsFunction(fsb, false); //rewrite 
and merge
+                       DMLTranslator.resetHopsDAGVisitStatus(fsb);
+                       rewriter.rewriteHopDAGsFunction(fsb, true); //rewrite 
and split
+                       DMLTranslator.resetHopsDAGVisitStatus(fsb);
+                       rewriter2.rewriteHopDAGsFunction(fsb, true);
+                       DMLTranslator.resetHopsDAGVisitStatus(fsb);
+                       DMLTranslator.refreshMemEstimates(fsb);
+                       dmlt.constructLops(fsb);
+               }
                
                // compile runtime program
-               return (FunctionProgramBlock) dmlt.createRuntimeProgramBlock(
-                       prog, fsb, ConfigurationManager.getDMLConfig());
+               for( Entry<String,FunctionStatementBlock> fsb : fsbs.entrySet() 
) {
+                       FunctionProgramBlock fpb = (FunctionProgramBlock) dmlt
+                               .createRuntimeProgramBlock(prog, 
fsb.getValue(), ConfigurationManager.getDMLConfig());
+                       prog.addFunctionProgramBlock(null, fsb.getKey(), fpb);
+               }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
index d400623..f48c869 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
@@ -789,7 +789,7 @@ public class LineageRewriteReuse
 
        private static BasicProgramBlock getProgramBlock() {
                if( _lrPB == null )
-                       _lrPB = new BasicProgramBlock( new Program() );
+                       _lrPB = new BasicProgramBlock(new Program());
                return _lrPB;
        }
 }
\ No newline at end of file
diff --git 
a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java 
b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
index c2add4d..ce7df49 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
@@ -117,6 +117,16 @@ public class MLContextTest extends MLContextTestBase {
                ml.execute(script);
                ml.setExplain(false);
        }
+       
+       @Test
+       public void testExecuteEvalNestedBuiltinTest() {
+               System.out.println("MLContextTest - eval builtin test");
+               setExpectedStdOut("TRUE");
+               ml.setExplain(true);
+               Script script = dmlFromFile(baseDirectory + File.separator + 
"eval4-nested_builtin-test.dml");
+               ml.execute(script);
+               ml.setExplain(false);
+       }
 
        @Test
        public void testCreateDMLScriptBasedOnStringAndExecute() {
diff --git a/src/test/scripts/functions/mlcontext/eval4-nested_builtin-test.dml 
b/src/test/scripts/functions/mlcontext/eval4-nested_builtin-test.dml
new file mode 100644
index 0000000..2217085
--- /dev/null
+++ b/src/test/scripts/functions/mlcontext/eval4-nested_builtin-test.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=100, cols=10, seed=37)
+y = rand(rows=100, cols=1, seed=38)
+
+F = cbind(as.frame("lm"),as.frame("mlogreg"));
+ix = ifelse(sum(X)>1, 1, 2);
+R1 = eval(as.scalar(F[1,ix]), X, y, 0, 1e-7, 1e-7, 0, FALSE); #calls lm->lmDS
+R2 = lmCG(X=X, y=y, verbose=FALSE);
+
+print(sum(abs(R1-R2)<1e-6)==ncol(X));

Reply via email to