Repository: systemml
Updated Branches:
  refs/heads/master 162a5b0f6 -> 71f8c836d


[SYSTEMML-1444,1759,1785] Support for UDFs in expressions

This patch fixes a long-standing shortcoming of DML, namely the missing
support for user-defined functions (both dml-bodied and external Java
UDFs) in expressions such as foo(A, B) + 7, or R[i,j] = foo(A, B). In
detail, we use a very simple approach of hoisting these function calls
out of expressions directly after parsing. This approach allows for full
flexibility at script level, yet can reuse the entire infrastructure of
inlining, inter-procedural analysis, and dynamic recompilation in case
of unknown function outputs without compiler or runtime changes.

Right now, these function calls are supported in assignment statements,
multi-assignment statements, and print statements. In subsequent
patches, we will also add support for loop/branch predicates and
potentially output statements, but this requires additional improvements.
 

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/71f8c836
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/71f8c836
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/71f8c836

Branch: refs/heads/master
Commit: 71f8c836ddac637ef960b6069babb6aad925a11d
Parents: 162a5b0
Author: Matthias Boehm <[email protected]>
Authored: Sun Mar 4 01:42:36 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Sun Mar 4 01:45:40 2018 -0800

----------------------------------------------------------------------
 .../RewriteSplitDagDataDependentOperators.java  |  14 +-
 .../sysml/parser/AssignmentStatement.java       |  12 +-
 .../org/apache/sysml/parser/DMLProgram.java     |  16 ++
 .../org/apache/sysml/parser/StatementBlock.java | 165 ++++++++++++++++++-
 .../parser/common/CommonSyntacticValidator.java |  16 +-
 .../sysml/parser/dml/DMLParserWrapper.java      |   3 +
 .../sysml/parser/dml/DmlSyntacticValidator.java |   7 +-
 .../sysml/parser/pydml/PyDMLParserWrapper.java  |   3 +
 .../parser/pydml/PydmlSyntacticValidator.java   |   5 +-
 .../misc/FunctionInExpressionTest.java          |  88 ++++++++++
 .../scripts/functions/misc/FunInExpression1.dml |  29 ++++
 .../scripts/functions/misc/FunInExpression2.dml |  32 ++++
 .../scripts/functions/misc/FunInExpression3.dml |  36 ++++
 .../scripts/functions/misc/FunInExpression4.dml |  36 ++++
 .../functions/misc/ZPackageSuite.java           |   1 +
 15 files changed, 443 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
index a55ea41..7b4a733 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
@@ -67,7 +67,8 @@ import org.apache.sysml.runtime.matrix.data.Pair;
  */
 public class RewriteSplitDagDataDependentOperators extends 
StatementBlockRewriteRule
 {
-       private static String _varnamePredix = "_sbcvar";
+       private static final String SB_CUT_PREFIX = "_sbcvar";
+       private static final String FUN_CUT_PREFIX = "_funvar";
        private static IDSequence _seq = new IDSequence();
        
        @Override
@@ -151,7 +152,7 @@ public class RewriteSplitDagDataDependentOperators extends 
StatementBlockRewrite
                                        }
                                        else //create transient write to 
artificial variables
                                        {
-                                               varname = _varnamePredix + 
_seq.getNextID();
+                                               varname = 
createCutVarName(false);
                                                
                                                //create new transient read
                                                DataOp tread = new 
DataOp(varname, c.getDataType(), c.getValueType(),
@@ -350,7 +351,7 @@ public class RewriteSplitDagDataDependentOperators extends 
StatementBlockRewrite
                //step 3: create additional cuts
                for( Pair<Hop,Hop> p : candSet ) 
                {
-                       String varname = _varnamePredix + _seq.getNextID();
+                       String varname = createCutVarName(false);
                        
                        Hop hop = p.getKey();
                        Hop c = p.getValue();
@@ -474,4 +475,11 @@ public class RewriteSplitDagDataDependentOperators extends 
StatementBlockRewrite
                        ProgramRewriteStatus sate) throws HopsException {
                return sbs;
        }
+       
+       public static String createCutVarName(boolean fun) {
+               return fun ?
+                       FUN_CUT_PREFIX + _seq.getNextID() :
+                       SB_CUT_PREFIX + _seq.getNextID();
+               
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/AssignmentStatement.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/AssignmentStatement.java 
b/src/main/java/org/apache/sysml/parser/AssignmentStatement.java
index 3525d13..7746fb3 100644
--- a/src/main/java/org/apache/sysml/parser/AssignmentStatement.java
+++ b/src/main/java/org/apache/sysml/parser/AssignmentStatement.java
@@ -43,18 +43,20 @@ public class AssignmentStatement extends Statement
                AssignmentStatement retVal = new AssignmentStatement(newTarget, 
newSource, this);
                return retVal;
        }
-
-       public AssignmentStatement(DataIdentifier di, Expression exp, ParseInfo 
parseInfo) {
+       
+       public AssignmentStatement(DataIdentifier di, Expression exp) {
                _targetList = new ArrayList<>();
                _targetList.add(di);
                _source = exp;
+       }
+       
+       public AssignmentStatement(DataIdentifier di, Expression exp, ParseInfo 
parseInfo) {
+               this(di, exp);
                setParseInfo(parseInfo);
        }
 
        public AssignmentStatement(ParserRuleContext ctx, DataIdentifier di, 
Expression exp) throws LanguageException {
-               _targetList = new ArrayList<>();
-               _targetList.add(di);
-               _source = exp;
+               this(di, exp);
                setCtxValues(ctx);
        }
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/DMLProgram.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLProgram.java 
b/src/main/java/org/apache/sysml/parser/DMLProgram.java
index 6fa5b98..6bc5847 100644
--- a/src/main/java/org/apache/sysml/parser/DMLProgram.java
+++ b/src/main/java/org/apache/sysml/parser/DMLProgram.java
@@ -131,6 +131,22 @@ public class DMLProgram
                _blocks = StatementBlock.mergeStatementBlocks(_blocks);
        }
        
+       public void hoistFunctionCallsFromExpressions() {
+               try {
+                       //handle statement blocks of all functions
+                       for( FunctionStatementBlock fsb : 
getFunctionStatementBlocks() )
+                               
StatementBlock.rHoistFunctionCallsFromExpressions(fsb);
+                       //handle statement blocks of main program
+                       ArrayList<StatementBlock> tmp = new ArrayList<>();
+                       for( StatementBlock sb : _blocks )
+                               
tmp.addAll(StatementBlock.rHoistFunctionCallsFromExpressions(sb));
+                       _blocks = tmp;
+               }
+               catch(LanguageException ex) {
+                       throw new RuntimeException(ex);
+               }
+       }
+       
        @Override
        public String toString(){
                StringBuilder sb = new StringBuilder();

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/StatementBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/StatementBlock.java 
b/src/main/java/org/apache/sysml/parser/StatementBlock.java
index 34a023a..f7901c1 100644
--- a/src/main/java/org/apache/sysml/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java
@@ -21,6 +21,7 @@ package org.apache.sysml.parser;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 
@@ -31,6 +32,7 @@ import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.hops.recompile.Recompiler;
+import org.apache.sysml.hops.rewrite.RewriteSplitDagDataDependentOperators;
 import org.apache.sysml.lops.Lop;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.Expression.FormatType;
@@ -71,6 +73,12 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                _constVarsOut = new HashMap<>();
                _updateInPlaceVars = new ArrayList<>();
        }
+       
+       public StatementBlock(StatementBlock sb) {
+               this();
+               setParseInfo(sb);
+               _dmlProg = sb._dmlProg;
+       }
 
        public void setDMLProg(DMLProgram dmlProg){
                _dmlProg = dmlProg;
@@ -399,8 +407,160 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                return result;
 
        }
+       
+       public static List<StatementBlock> 
rHoistFunctionCallsFromExpressions(StatementBlock current) {
+               if (current instanceof FunctionStatementBlock) {
+                       FunctionStatementBlock fsb = 
(FunctionStatementBlock)current;
+                       FunctionStatement fstmt = 
(FunctionStatement)fsb.getStatement(0);
+                       for (StatementBlock sb : fstmt.getBody())
+                               rHoistFunctionCallsFromExpressions(sb);
+               }
+               else if (current instanceof WhileStatementBlock) {
+                       WhileStatementBlock wsb = (WhileStatementBlock) current;
+                       WhileStatement wstmt = 
(WhileStatement)wsb.getStatement(0);
+                       //TODO handle predicates
+                       for (StatementBlock sb : wstmt.getBody())
+                               rHoistFunctionCallsFromExpressions(sb);
+               }
+               else if (current instanceof IfStatementBlock) {
+                       IfStatementBlock isb = (IfStatementBlock) current;
+                       IfStatement istmt = (IfStatement)isb.getStatement(0);
+                       //TODO handle predicates
+                       for (StatementBlock sb : istmt.getIfBody())
+                               rHoistFunctionCallsFromExpressions(sb);
+                       for (StatementBlock sb : istmt.getElseBody())
+                               rHoistFunctionCallsFromExpressions(sb);
+               }
+               else if (current instanceof ForStatementBlock) { //incl parfor
+                       ForStatementBlock fsb = (ForStatementBlock) current;
+                       ForStatement fstmt = (ForStatement)fsb.getStatement(0);
+                       //TODO handle predicates
+                       for (StatementBlock sb : fstmt.getBody())
+                               rHoistFunctionCallsFromExpressions(sb);
+               }
+               else { //generic (last-level)
+                       ArrayList<Statement> tmp = new ArrayList<>();
+                       for(Statement stmt : current.getStatements())
+                               
tmp.addAll(rHoistFunctionCallsFromExpressions(stmt));
+                       if( current.getStatements().size() != tmp.size() )
+                               return createStatementBlocks(current, tmp);
+               }
+               return Arrays.asList(current);
+       }
 
-
+       public static List<Statement> 
rHoistFunctionCallsFromExpressions(Statement stmt) {
+               ArrayList<Statement> tmp = new ArrayList<>();
+               if( stmt instanceof AssignmentStatement ) {
+                       AssignmentStatement astmt = (AssignmentStatement)stmt;
+                       boolean ix = (astmt.getTargetList().get(0) instanceof 
IndexedIdentifier);
+                       rHoistFunctionCallsFromExpressions(astmt.getSource(), 
!ix, tmp);
+                       if( ix && astmt.getSource() instanceof 
FunctionCallIdentifier ) {
+                               AssignmentStatement lstmt = 
(AssignmentStatement) tmp.get(tmp.size()-1);
+                               astmt.setSource(copy(lstmt.getTarget()));
+                       }
+               }
+               else if( stmt instanceof MultiAssignmentStatement ) {
+                       MultiAssignmentStatement mstmt = 
(MultiAssignmentStatement)stmt;
+                       rHoistFunctionCallsFromExpressions(mstmt.getSource(), 
true, tmp);
+               }
+               else if( stmt instanceof PrintStatement ) {
+                       PrintStatement pstmt = (PrintStatement)stmt;
+                       for(int i=0; i<pstmt.expressions.size(); i++) {
+                               Expression lexpr = 
pstmt.getExpressions().get(i);
+                               rHoistFunctionCallsFromExpressions(lexpr, 
false, tmp);
+                               if( lexpr instanceof FunctionCallIdentifier ) {
+                                       AssignmentStatement lstmt = 
(AssignmentStatement) tmp.get(tmp.size()-1);
+                                       pstmt.getExpressions().set(i, 
copy(lstmt.getTarget()));
+                               }
+                       }
+               }
+               
+               //most statements will be returned unchanged, while expressions 
with
+               //function calls are split into potentially many statements
+               List<Statement> ret = tmp.isEmpty() ? Arrays.asList(stmt) : tmp;
+               if( !tmp.isEmpty() ) {
+                       for( Statement ltmp : tmp )
+                               ltmp.setParseInfo(stmt);
+                       tmp.add(stmt);
+               }
+               return ret;
+       }
+       
+       public static Expression rHoistFunctionCallsFromExpressions(Expression 
expr, boolean root, ArrayList<Statement> tmp) {
+               if( expr == null || expr instanceof ConstIdentifier )
+                       return expr; //do nothing
+               if( expr instanceof BinaryExpression ) {
+                       BinaryExpression lexpr = (BinaryExpression) expr;
+                       
lexpr.setLeft(rHoistFunctionCallsFromExpressions(lexpr.getLeft(), false, tmp));
+                       
lexpr.setRight(rHoistFunctionCallsFromExpressions(lexpr.getRight(), false, 
tmp));
+               }
+               else if( expr instanceof RelationalExpression ) {
+                       RelationalExpression lexpr = (RelationalExpression) 
expr;
+                       
lexpr.setLeft(rHoistFunctionCallsFromExpressions(lexpr.getLeft(), false, tmp));
+                       
lexpr.setRight(rHoistFunctionCallsFromExpressions(lexpr.getRight(), false, 
tmp));
+               }
+               else if( expr instanceof BooleanExpression ) {
+                       BooleanExpression lexpr = (BooleanExpression) expr;
+                       
lexpr.setLeft(rHoistFunctionCallsFromExpressions(lexpr.getLeft(), false, tmp));
+                       
lexpr.setRight(rHoistFunctionCallsFromExpressions(lexpr.getRight(), false, 
tmp));
+               }
+               else if( expr instanceof BuiltinFunctionExpression ) {
+                       BuiltinFunctionExpression lexpr = 
(BuiltinFunctionExpression) expr;
+                       Expression[] clexpr = lexpr.getAllExpr();
+                       for( int i=0; i<clexpr.length; i++ )
+                               clexpr[i] = 
rHoistFunctionCallsFromExpressions(clexpr[i], false, tmp);
+               }
+               else if( expr instanceof ParameterizedBuiltinFunctionExpression 
) {
+                       ParameterizedBuiltinFunctionExpression lexpr = 
(ParameterizedBuiltinFunctionExpression) expr;
+                       HashMap<String, Expression> clexpr = 
lexpr.getVarParams();
+                       for( String key : clexpr.keySet() )
+                               clexpr.put(key, 
rHoistFunctionCallsFromExpressions(clexpr.get(key), false, tmp));
+               }
+               else if( expr instanceof DataExpression ) {
+                       DataExpression lexpr = (DataExpression) expr;
+                       HashMap<String, Expression> clexpr = 
lexpr.getVarParams();
+                       for( String key : clexpr.keySet() )
+                               clexpr.put(key, 
rHoistFunctionCallsFromExpressions(clexpr.get(key), false, tmp));
+               }
+               else if( expr instanceof FunctionCallIdentifier ) {
+                       FunctionCallIdentifier fexpr = (FunctionCallIdentifier) 
expr;
+                       for( ParameterExpression pexpr : fexpr.getParamExprs() )
+                               
pexpr.setExpr(rHoistFunctionCallsFromExpressions(pexpr.getExpr(), false, tmp));
+                       if( !root ) { //core hoisting
+                               String varname = 
RewriteSplitDagDataDependentOperators.createCutVarName(true);
+                               DataIdentifier di = new DataIdentifier(varname);
+                               di.setDataType(fexpr.getDataType());
+                               di.setValueType(fexpr.getValueType());
+                               tmp.add(new AssignmentStatement(di, fexpr, di));
+                               return di;
+                       }
+               }
+               //note: all remaining expressions data identifiers remain 
unchanged
+               return expr;
+       }
+       
+       private static DataIdentifier copy(DataIdentifier di) {
+               return new DataIdentifier(di);
+       }
+       
+       private static List<StatementBlock> 
createStatementBlocks(StatementBlock sb, List<Statement> stmts) {
+               List<StatementBlock> ret = new ArrayList<StatementBlock>();
+               StatementBlock current = new StatementBlock(sb);
+               for(Statement stmt : stmts) {
+                       current.addStatement(stmt);
+                       //cut the statement block after the current function
+                       if( stmt instanceof AssignmentStatement
+                               && ((AssignmentStatement)stmt).getSource()
+                               instanceof FunctionCallIdentifier ) {
+                               ret.add(current);
+                               current = new StatementBlock(sb);
+                       }
+               }
+               if( current.getNumStatements() > 0 )
+                       ret.add(current);
+               return ret;
+       }
+       
        public ArrayList<Statement> rewriteFunctionCallStatements (DMLProgram 
dmlProg, ArrayList<Statement> statements) throws LanguageException {
 
                ArrayList<Statement> newStatements = new ArrayList<>();
@@ -1063,5 +1223,4 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
        public void setUpdateInPlaceVars( ArrayList<String> vars ) {
                _updateInPlaceVars = vars;
        }
-
-}  // end class
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java 
b/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java
index 890bad2..67fbaf0 100644
--- a/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java
+++ b/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java
@@ -679,16 +679,20 @@ public abstract class CommonSyntacticValidator {
                                return;
                }
 
-               // If builtin functions weren't found...
+               // handle user-defined functions
+               setAssignmentStatement(ctx, info, target,
+                       createFunctionCall(ctx, namespace, functionName, 
paramExpression));
+       }
+       
+       protected FunctionCallIdentifier createFunctionCall(ParserRuleContext 
ctx,
+               String namespace, String functionName, 
ArrayList<ParameterExpression> paramExpression) {
                FunctionCallIdentifier functCall = new 
FunctionCallIdentifier(paramExpression);
                functCall.setFunctionName(functionName);
-               // Override default namespace for imported non-built-in function
-               String inferNamespace = (sourceNamespace != null && 
sourceNamespace.length() > 0 && DMLProgram.DEFAULT_NAMESPACE.equals(namespace)) 
? sourceNamespace : namespace;
+               String inferNamespace = (sourceNamespace != null && 
sourceNamespace.length() > 0
+                       && DMLProgram.DEFAULT_NAMESPACE.equals(namespace)) ? 
sourceNamespace : namespace;
                functCall.setFunctionNamespace(inferNamespace);
-
                functCall.setCtxValuesAndFilename(ctx, currentFile);
-
-               setAssignmentStatement(ctx, info, target, functCall);
+               return functCall;
        }
 
        /**

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java 
b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
index fb13289..8714968 100644
--- a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
+++ b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
@@ -253,7 +253,10 @@ public class DMLParserWrapper extends ParserWrapper
                        dmlPgm.addStatementBlock(getStatementBlock(current));
                }
 
+               //post-processing
+               dmlPgm.hoistFunctionCallsFromExpressions();
                dmlPgm.mergeStatementBlocks();
+               
                return dmlPgm;
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java 
b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java
index 93b670b..4a1d6c3 100644
--- a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java
+++ b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java
@@ -534,11 +534,14 @@ public class DmlSyntacticValidator extends 
CommonSyntacticValidator implements D
                Action f = new Action() {
                        @Override public void execute(Expression e) { info.expr 
= e; }
                };
+               
+               // handle built-in functions
                boolean validBIF = buildForBuiltInFunction(ctx, functionName, 
paramExpression, f);
                if (validBIF)
                        return;
-
-               notifyErrorListeners("only builtin functions allowed as part of 
expression", ctx.start);
+               
+               // handle user-defined functions
+               info.expr = createFunctionCall(ctx, namespace, functionName, 
paramExpression);
        }
 
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java 
b/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java
index 6865683..8a81b46 100644
--- a/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java
+++ b/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java
@@ -245,7 +245,10 @@ public class PyDMLParserWrapper extends ParserWrapper
                        dmlPgm.addStatementBlock(getStatementBlock(current));
                }
 
+               //post-processing
+               dmlPgm.hoistFunctionCallsFromExpressions();
                dmlPgm.mergeStatementBlocks();
+               
                return dmlPgm;
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java 
b/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java
index 9c1510b..858eca6 100644
--- a/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java
+++ b/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java
@@ -1149,11 +1149,14 @@ public class PydmlSyntacticValidator extends 
CommonSyntacticValidator implements
                Action f = new Action() {
                        @Override public void execute(Expression e) { info.expr 
= e; }
                };
+               
+               //handle builtin functions
                boolean validBIF = buildForBuiltInFunction(ctx, functionName, 
paramExpression, f);
                if (validBIF)
                        return;
 
-               notifyErrorListeners("only builtin functions allowed as part of 
expression", ctx.start);
+               // handle user-defined functions
+               info.expr = createFunctionCall(ctx, namespace, functionName, 
paramExpression);
        }
 
        @Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionInExpressionTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionInExpressionTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionInExpressionTest.java
new file mode 100644
index 0000000..eefbf52
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionInExpressionTest.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+public class FunctionInExpressionTest extends AutomatedTestBase 
+{
+       private final static String TEST_NAME1 = "FunInExpression1"; 
+       private final static String TEST_NAME2 = "FunInExpression2"; 
+       private final static String TEST_NAME3 = "FunInExpression3"; 
+       private final static String TEST_NAME4 = "FunInExpression4"; 
+       
+       private final static String TEST_DIR = "functions/misc/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FunctionInExpressionTest.class.getSimpleName() + "/";
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
+       }
+
+       @Test
+       public void testFunInExpression1() {
+               runFunInExpressionTest( TEST_NAME1 );
+       }
+       
+       @Test
+       public void testFunInExpression2() {
+               runFunInExpressionTest( TEST_NAME2 );
+       }
+       
+       @Test
+       public void testFunInExpression3() {
+               runFunInExpressionTest( TEST_NAME3 );
+       }
+       
+       @Test
+       public void testFunInExpression4() {
+               runFunInExpressionTest( TEST_NAME4 );
+       }
+       
+       private void runFunInExpressionTest( String testName )
+       {
+               TestConfiguration config = getTestConfiguration(testName);
+               loadTestConfiguration(config);
+               
+               String HOME = SCRIPT_DIR + TEST_DIR;
+               fullDMLScriptName = HOME + testName + ".dml";
+               programArgs = new String[]{"-explain", "-stats", "-args", 
output("R") };
+               
+               fullRScriptName = HOME + testName + ".R";
+               rCmd = getRCmd(expectedDir());
+
+               //run script and compare output
+               runTest(true, false, null, -1); 
+               
+               //compare results
+               double val = readDMLMatrixFromHDFS("R").get(new CellIndex(1,1));
+               Assert.assertTrue("Wrong result: 7 vs "+val, 
Math.abs(val-7)<Math.pow(10, -14));
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/test/scripts/functions/misc/FunInExpression1.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/FunInExpression1.dml 
b/src/test/scripts/functions/misc/FunInExpression1.dml
new file mode 100644
index 0000000..dff7113
--- /dev/null
+++ b/src/test/scripts/functions/misc/FunInExpression1.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Test: a single-statement UDF (marked #inlined) called inside an
+# expression, i.e. sum(foo(A)/2), rather than as a standalone statement.
+foo = function(Matrix[Double] A) return(Matrix[Double] B) {
+   B = A + A; #inlined
+}
+
+A = matrix(0.07, 10, 10);
+# foo(A) is a 10x10 matrix of 0.14, so sum(foo(A)/2) = 100 * 0.07 = 7,
+# the value the Java driver test asserts.
+R = as.matrix(sum(foo(A)/2));
+
+# $1: output path for the 1x1 result matrix R
+write(R, $1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/test/scripts/functions/misc/FunInExpression2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/FunInExpression2.dml b/src/test/scripts/functions/misc/FunInExpression2.dml
new file mode 100644
index 0000000..e9fabeb
--- /dev/null
+++ b/src/test/scripts/functions/misc/FunInExpression2.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Test: a UDF whose body contains control flow (if/else), called inside
+# an expression.
+foo = function(Matrix[Double] A) return(Matrix[Double] B) {
+   if( sum(A) != 0 )
+      B = A + A;
+   else
+      B = A + 2;
+}
+
+A = matrix(0.07, 10, 10);
+# sum(A) = 7 != 0, so foo(A) = A + A (all cells 0.14);
+# sum(foo(A)/2) = 100 * 0.07 = 7, the value the Java driver test asserts.
+R = as.matrix(sum(foo(A)/2));
+
+# $1: output path for the 1x1 result matrix R
+write(R, $1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/test/scripts/functions/misc/FunInExpression3.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/FunInExpression3.dml b/src/test/scripts/functions/misc/FunInExpression3.dml
new file mode 100644
index 0000000..7b30d90
--- /dev/null
+++ b/src/test/scripts/functions/misc/FunInExpression3.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Test: two different UDFs (one single-statement, one with if/else),
+# each called inside the same expression.
+foo1 = function(Matrix[Double] A) return(Matrix[Double] B) {
+   B = A + A; #inlined
+}
+
+foo2 = function(Matrix[Double] A) return(Matrix[Double] B) {
+   if( sum(A) != 0 )
+      B = A + A;
+   else
+      B = A + 2;
+}
+
+A = matrix(0.07, 10, 10);
+# foo1(1.5*A) = 0.21 per cell; sum(A/2) = 3.5 != 0, so foo2(A/2) = 0.07 per cell;
+# (1 + 0.21 - 0.07 - 1) / 2 = 0.07 per cell, summed over 100 cells = 7.
+R = as.matrix(sum((1 + foo1(1.5*A) - foo2(A/2) - 1) / 2));
+
+# $1: output path for the 1x1 result matrix R
+write(R, $1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/test/scripts/functions/misc/FunInExpression4.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/FunInExpression4.dml b/src/test/scripts/functions/misc/FunInExpression4.dml
new file mode 100644
index 0000000..25d85b8
--- /dev/null
+++ b/src/test/scripts/functions/misc/FunInExpression4.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Test: the same UDF called twice inside the right-hand side of a
+# left-indexing assignment R0[...] = expr.
+foo = function(Matrix[Double] A) return(Matrix[Double] B) {
+   if( sum(A) != 0 )
+      B = A + A;
+   else
+      B = A + 2;
+}
+
+A = matrix(0.07, 10, 10);
+R0 = matrix(0, 11, 11);
+# foo(1.5*A) = 0.21 and foo(A/2) = 0.07 per cell, so the 10x10 block
+# receives (1 + 0.21 - 0.07 - 1) / 2 = 0.07 per cell; row/column 11 stay 0.
+R0[1:10,1:10] = (1 + foo(1.5*A) - foo(A/2) - 1) / 2;
+# NOTE(review): empty loop presumably inserted to split statement blocks
+# (so later operations are not fused with the indexing above); confirm intent.
+while(FALSE){}
+
+# sum(R0) = 100 * 0.07 = 7, the value the Java driver test asserts.
+R = as.matrix(sum(R0));
+
+# $1: output path for the 1x1 result matrix R
+write(R, $1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index 80db7c7..1849b51 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -29,6 +29,7 @@ import org.junit.runners.Suite;
        ConditionalValidateTest.class,
        DataTypeCastingTest.class,
        DataTypeChangeTest.class,
+       FunctionInExpressionTest.class,
        FunctionInliningTest.class,
        FunctionNamespaceTest.class,
        FunctionReturnTest.class,

Reply via email to