Repository: systemml Updated Branches: refs/heads/master 162a5b0f6 -> 71f8c836d
[SYSTEMML-1444,1759,1785] Support for UDFs in expressions This patch fixes a long-standing shortcoming of DML, namely the missing support for user-defined functions (both dml-bodied and external Java UDFs) in expressions such as foo(A, B) + 7 or R[i,j] = foo(A, B). In detail, we use a very simple approach of hoisting these function calls out of expressions directly after parsing. This approach allows for full flexibility at the script level, yet can reuse the entire infrastructure of inlining, inter-procedural analysis, and dynamic recompilation in case of unknown function outputs without compiler or runtime changes. Right now, these function calls are supported in assignment statements, multi-assignment statements, and print statements. In subsequent patches, we will also add support for loop/branch predicates and potentially output statements, but this requires additional improvements. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/71f8c836 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/71f8c836 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/71f8c836 Branch: refs/heads/master Commit: 71f8c836ddac637ef960b6069babb6aad925a11d Parents: 162a5b0 Author: Matthias Boehm <[email protected]> Authored: Sun Mar 4 01:42:36 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Sun Mar 4 01:45:40 2018 -0800 ---------------------------------------------------------------------- .../RewriteSplitDagDataDependentOperators.java | 14 +- .../sysml/parser/AssignmentStatement.java | 12 +- .../org/apache/sysml/parser/DMLProgram.java | 16 ++ .../org/apache/sysml/parser/StatementBlock.java | 165 ++++++++++++++++++- .../parser/common/CommonSyntacticValidator.java | 16 +- .../sysml/parser/dml/DMLParserWrapper.java | 3 + .../sysml/parser/dml/DmlSyntacticValidator.java | 7 +- .../sysml/parser/pydml/PyDMLParserWrapper.java | 3 + .../parser/pydml/PydmlSyntacticValidator.java | 5 +- 
.../misc/FunctionInExpressionTest.java | 88 ++++++++++ .../scripts/functions/misc/FunInExpression1.dml | 29 ++++ .../scripts/functions/misc/FunInExpression2.dml | 32 ++++ .../scripts/functions/misc/FunInExpression3.dml | 36 ++++ .../scripts/functions/misc/FunInExpression4.dml | 36 ++++ .../functions/misc/ZPackageSuite.java | 1 + 15 files changed, 443 insertions(+), 20 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java index a55ea41..7b4a733 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java @@ -67,7 +67,8 @@ import org.apache.sysml.runtime.matrix.data.Pair; */ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewriteRule { - private static String _varnamePredix = "_sbcvar"; + private static final String SB_CUT_PREFIX = "_sbcvar"; + private static final String FUN_CUT_PREFIX = "_funvar"; private static IDSequence _seq = new IDSequence(); @Override @@ -151,7 +152,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite } else //create transient write to artificial variables { - varname = _varnamePredix + _seq.getNextID(); + varname = createCutVarName(false); //create new transient read DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), @@ -350,7 +351,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite //step 3: create additional cuts for( Pair<Hop,Hop> p : candSet ) { - String varname = 
_varnamePredix + _seq.getNextID(); + String varname = createCutVarName(false); Hop hop = p.getKey(); Hop c = p.getValue(); @@ -474,4 +475,11 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite ProgramRewriteStatus sate) throws HopsException { return sbs; } + + public static String createCutVarName(boolean fun) { + return fun ? + FUN_CUT_PREFIX + _seq.getNextID() : + SB_CUT_PREFIX + _seq.getNextID(); + + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/AssignmentStatement.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/AssignmentStatement.java b/src/main/java/org/apache/sysml/parser/AssignmentStatement.java index 3525d13..7746fb3 100644 --- a/src/main/java/org/apache/sysml/parser/AssignmentStatement.java +++ b/src/main/java/org/apache/sysml/parser/AssignmentStatement.java @@ -43,18 +43,20 @@ public class AssignmentStatement extends Statement AssignmentStatement retVal = new AssignmentStatement(newTarget, newSource, this); return retVal; } - - public AssignmentStatement(DataIdentifier di, Expression exp, ParseInfo parseInfo) { + + public AssignmentStatement(DataIdentifier di, Expression exp) { _targetList = new ArrayList<>(); _targetList.add(di); _source = exp; + } + + public AssignmentStatement(DataIdentifier di, Expression exp, ParseInfo parseInfo) { + this(di, exp); setParseInfo(parseInfo); } public AssignmentStatement(ParserRuleContext ctx, DataIdentifier di, Expression exp) throws LanguageException { - _targetList = new ArrayList<>(); - _targetList.add(di); - _source = exp; + this(di, exp); setCtxValues(ctx); } http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/DMLProgram.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLProgram.java 
b/src/main/java/org/apache/sysml/parser/DMLProgram.java index 6fa5b98..6bc5847 100644 --- a/src/main/java/org/apache/sysml/parser/DMLProgram.java +++ b/src/main/java/org/apache/sysml/parser/DMLProgram.java @@ -131,6 +131,22 @@ public class DMLProgram _blocks = StatementBlock.mergeStatementBlocks(_blocks); } + public void hoistFunctionCallsFromExpressions() { + try { + //handle statement blocks of all functions + for( FunctionStatementBlock fsb : getFunctionStatementBlocks() ) + StatementBlock.rHoistFunctionCallsFromExpressions(fsb); + //handle statement blocks of main program + ArrayList<StatementBlock> tmp = new ArrayList<>(); + for( StatementBlock sb : _blocks ) + tmp.addAll(StatementBlock.rHoistFunctionCallsFromExpressions(sb)); + _blocks = tmp; + } + catch(LanguageException ex) { + throw new RuntimeException(ex); + } + } + @Override public String toString(){ StringBuilder sb = new StringBuilder(); http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/StatementBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/StatementBlock.java b/src/main/java/org/apache/sysml/parser/StatementBlock.java index 34a023a..f7901c1 100644 --- a/src/main/java/org/apache/sysml/parser/StatementBlock.java +++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java @@ -21,6 +21,7 @@ package org.apache.sysml.parser; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -31,6 +32,7 @@ import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.recompile.Recompiler; +import org.apache.sysml.hops.rewrite.RewriteSplitDagDataDependentOperators; import org.apache.sysml.lops.Lop; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.FormatType; @@ 
-71,6 +73,12 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo _constVarsOut = new HashMap<>(); _updateInPlaceVars = new ArrayList<>(); } + + public StatementBlock(StatementBlock sb) { + this(); + setParseInfo(sb); + _dmlProg = sb._dmlProg; + } public void setDMLProg(DMLProgram dmlProg){ _dmlProg = dmlProg; @@ -399,8 +407,160 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo return result; } + + public static List<StatementBlock> rHoistFunctionCallsFromExpressions(StatementBlock current) { + if (current instanceof FunctionStatementBlock) { + FunctionStatementBlock fsb = (FunctionStatementBlock)current; + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + for (StatementBlock sb : fstmt.getBody()) + rHoistFunctionCallsFromExpressions(sb); + } + else if (current instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) current; + WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); + //TODO handle predicates + for (StatementBlock sb : wstmt.getBody()) + rHoistFunctionCallsFromExpressions(sb); + } + else if (current instanceof IfStatementBlock) { + IfStatementBlock isb = (IfStatementBlock) current; + IfStatement istmt = (IfStatement)isb.getStatement(0); + //TODO handle predicates + for (StatementBlock sb : istmt.getIfBody()) + rHoistFunctionCallsFromExpressions(sb); + for (StatementBlock sb : istmt.getElseBody()) + rHoistFunctionCallsFromExpressions(sb); + } + else if (current instanceof ForStatementBlock) { //incl parfor + ForStatementBlock fsb = (ForStatementBlock) current; + ForStatement fstmt = (ForStatement)fsb.getStatement(0); + //TODO handle predicates + for (StatementBlock sb : fstmt.getBody()) + rHoistFunctionCallsFromExpressions(sb); + } + else { //generic (last-level) + ArrayList<Statement> tmp = new ArrayList<>(); + for(Statement stmt : current.getStatements()) + tmp.addAll(rHoistFunctionCallsFromExpressions(stmt)); + if( 
current.getStatements().size() != tmp.size() ) + return createStatementBlocks(current, tmp); + } + return Arrays.asList(current); + } - + public static List<Statement> rHoistFunctionCallsFromExpressions(Statement stmt) { + ArrayList<Statement> tmp = new ArrayList<>(); + if( stmt instanceof AssignmentStatement ) { + AssignmentStatement astmt = (AssignmentStatement)stmt; + boolean ix = (astmt.getTargetList().get(0) instanceof IndexedIdentifier); + rHoistFunctionCallsFromExpressions(astmt.getSource(), !ix, tmp); + if( ix && astmt.getSource() instanceof FunctionCallIdentifier ) { + AssignmentStatement lstmt = (AssignmentStatement) tmp.get(tmp.size()-1); + astmt.setSource(copy(lstmt.getTarget())); + } + } + else if( stmt instanceof MultiAssignmentStatement ) { + MultiAssignmentStatement mstmt = (MultiAssignmentStatement)stmt; + rHoistFunctionCallsFromExpressions(mstmt.getSource(), true, tmp); + } + else if( stmt instanceof PrintStatement ) { + PrintStatement pstmt = (PrintStatement)stmt; + for(int i=0; i<pstmt.expressions.size(); i++) { + Expression lexpr = pstmt.getExpressions().get(i); + rHoistFunctionCallsFromExpressions(lexpr, false, tmp); + if( lexpr instanceof FunctionCallIdentifier ) { + AssignmentStatement lstmt = (AssignmentStatement) tmp.get(tmp.size()-1); + pstmt.getExpressions().set(i, copy(lstmt.getTarget())); + } + } + } + + //most statements will be returned unchanged, while expressions with + //function calls are split into potentially many statements + List<Statement> ret = tmp.isEmpty() ? 
Arrays.asList(stmt) : tmp; + if( !tmp.isEmpty() ) { + for( Statement ltmp : tmp ) + ltmp.setParseInfo(stmt); + tmp.add(stmt); + } + return ret; + } + + public static Expression rHoistFunctionCallsFromExpressions(Expression expr, boolean root, ArrayList<Statement> tmp) { + if( expr == null || expr instanceof ConstIdentifier ) + return expr; //do nothing + if( expr instanceof BinaryExpression ) { + BinaryExpression lexpr = (BinaryExpression) expr; + lexpr.setLeft(rHoistFunctionCallsFromExpressions(lexpr.getLeft(), false, tmp)); + lexpr.setRight(rHoistFunctionCallsFromExpressions(lexpr.getRight(), false, tmp)); + } + else if( expr instanceof RelationalExpression ) { + RelationalExpression lexpr = (RelationalExpression) expr; + lexpr.setLeft(rHoistFunctionCallsFromExpressions(lexpr.getLeft(), false, tmp)); + lexpr.setRight(rHoistFunctionCallsFromExpressions(lexpr.getRight(), false, tmp)); + } + else if( expr instanceof BooleanExpression ) { + BooleanExpression lexpr = (BooleanExpression) expr; + lexpr.setLeft(rHoistFunctionCallsFromExpressions(lexpr.getLeft(), false, tmp)); + lexpr.setRight(rHoistFunctionCallsFromExpressions(lexpr.getRight(), false, tmp)); + } + else if( expr instanceof BuiltinFunctionExpression ) { + BuiltinFunctionExpression lexpr = (BuiltinFunctionExpression) expr; + Expression[] clexpr = lexpr.getAllExpr(); + for( int i=0; i<clexpr.length; i++ ) + clexpr[i] = rHoistFunctionCallsFromExpressions(clexpr[i], false, tmp); + } + else if( expr instanceof ParameterizedBuiltinFunctionExpression ) { + ParameterizedBuiltinFunctionExpression lexpr = (ParameterizedBuiltinFunctionExpression) expr; + HashMap<String, Expression> clexpr = lexpr.getVarParams(); + for( String key : clexpr.keySet() ) + clexpr.put(key, rHoistFunctionCallsFromExpressions(clexpr.get(key), false, tmp)); + } + else if( expr instanceof DataExpression ) { + DataExpression lexpr = (DataExpression) expr; + HashMap<String, Expression> clexpr = lexpr.getVarParams(); + for( String key : 
clexpr.keySet() ) + clexpr.put(key, rHoistFunctionCallsFromExpressions(clexpr.get(key), false, tmp)); + } + else if( expr instanceof FunctionCallIdentifier ) { + FunctionCallIdentifier fexpr = (FunctionCallIdentifier) expr; + for( ParameterExpression pexpr : fexpr.getParamExprs() ) + pexpr.setExpr(rHoistFunctionCallsFromExpressions(pexpr.getExpr(), false, tmp)); + if( !root ) { //core hoisting + String varname = RewriteSplitDagDataDependentOperators.createCutVarName(true); + DataIdentifier di = new DataIdentifier(varname); + di.setDataType(fexpr.getDataType()); + di.setValueType(fexpr.getValueType()); + tmp.add(new AssignmentStatement(di, fexpr, di)); + return di; + } + } + //note: all remaining expressions data identifiers remain unchanged + return expr; + } + + private static DataIdentifier copy(DataIdentifier di) { + return new DataIdentifier(di); + } + + private static List<StatementBlock> createStatementBlocks(StatementBlock sb, List<Statement> stmts) { + List<StatementBlock> ret = new ArrayList<StatementBlock>(); + StatementBlock current = new StatementBlock(sb); + for(Statement stmt : stmts) { + current.addStatement(stmt); + //cut the statement block after the current function + if( stmt instanceof AssignmentStatement + && ((AssignmentStatement)stmt).getSource() + instanceof FunctionCallIdentifier ) { + ret.add(current); + current = new StatementBlock(sb); + } + } + if( current.getNumStatements() > 0 ) + ret.add(current); + return ret; + } + public ArrayList<Statement> rewriteFunctionCallStatements (DMLProgram dmlProg, ArrayList<Statement> statements) throws LanguageException { ArrayList<Statement> newStatements = new ArrayList<>(); @@ -1063,5 +1223,4 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo public void setUpdateInPlaceVars( ArrayList<String> vars ) { _updateInPlaceVars = vars; } - -} // end class +} 
http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java b/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java index 890bad2..67fbaf0 100644 --- a/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java +++ b/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java @@ -679,16 +679,20 @@ public abstract class CommonSyntacticValidator { return; } - // If builtin functions weren't found... + // handle user-defined functions + setAssignmentStatement(ctx, info, target, + createFunctionCall(ctx, namespace, functionName, paramExpression)); + } + + protected FunctionCallIdentifier createFunctionCall(ParserRuleContext ctx, + String namespace, String functionName, ArrayList<ParameterExpression> paramExpression) { FunctionCallIdentifier functCall = new FunctionCallIdentifier(paramExpression); functCall.setFunctionName(functionName); - // Override default namespace for imported non-built-in function - String inferNamespace = (sourceNamespace != null && sourceNamespace.length() > 0 && DMLProgram.DEFAULT_NAMESPACE.equals(namespace)) ? sourceNamespace : namespace; + String inferNamespace = (sourceNamespace != null && sourceNamespace.length() > 0 + && DMLProgram.DEFAULT_NAMESPACE.equals(namespace)) ? 
sourceNamespace : namespace; functCall.setFunctionNamespace(inferNamespace); - functCall.setCtxValuesAndFilename(ctx, currentFile); - - setAssignmentStatement(ctx, info, target, functCall); + return functCall; } /** http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java index fb13289..8714968 100644 --- a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java +++ b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java @@ -253,7 +253,10 @@ public class DMLParserWrapper extends ParserWrapper dmlPgm.addStatementBlock(getStatementBlock(current)); } + //post-processing + dmlPgm.hoistFunctionCallsFromExpressions(); dmlPgm.mergeStatementBlocks(); + return dmlPgm; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java index 93b670b..4a1d6c3 100644 --- a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java +++ b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java @@ -534,11 +534,14 @@ public class DmlSyntacticValidator extends CommonSyntacticValidator implements D Action f = new Action() { @Override public void execute(Expression e) { info.expr = e; } }; + + // handle built-in functions boolean validBIF = buildForBuiltInFunction(ctx, functionName, paramExpression, f); if (validBIF) return; - - notifyErrorListeners("only builtin functions allowed as part of expression", ctx.start); + + // handle user-defined functions + info.expr = 
createFunctionCall(ctx, namespace, functionName, paramExpression); } http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java b/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java index 6865683..8a81b46 100644 --- a/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java +++ b/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java @@ -245,7 +245,10 @@ public class PyDMLParserWrapper extends ParserWrapper dmlPgm.addStatementBlock(getStatementBlock(current)); } + //post-processing + dmlPgm.hoistFunctionCallsFromExpressions(); dmlPgm.mergeStatementBlocks(); + return dmlPgm; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java b/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java index 9c1510b..858eca6 100644 --- a/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java +++ b/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java @@ -1149,11 +1149,14 @@ public class PydmlSyntacticValidator extends CommonSyntacticValidator implements Action f = new Action() { @Override public void execute(Expression e) { info.expr = e; } }; + + //handle builtin functions boolean validBIF = buildForBuiltInFunction(ctx, functionName, paramExpression, f); if (validBIF) return; - notifyErrorListeners("only builtin functions allowed as part of expression", ctx.start); + // handle user-defined functions + info.expr = createFunctionCall(ctx, namespace, functionName, paramExpression); } @Override 
http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionInExpressionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionInExpressionTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionInExpressionTest.java new file mode 100644 index 0000000..eefbf52 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionInExpressionTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.functions.misc; + + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +public class FunctionInExpressionTest extends AutomatedTestBase +{ + private final static String TEST_NAME1 = "FunInExpression1"; + private final static String TEST_NAME2 = "FunInExpression2"; + private final static String TEST_NAME3 = "FunInExpression3"; + private final static String TEST_NAME4 = "FunInExpression4"; + + private final static String TEST_DIR = "functions/misc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + FunctionInExpressionTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) ); + } + + @Test + public void testFunInExpression1() { + runFunInExpressionTest( TEST_NAME1 ); + } + + @Test + public void testFunInExpression2() { + runFunInExpressionTest( TEST_NAME2 ); + } + + @Test + public void testFunInExpression3() { + runFunInExpressionTest( TEST_NAME3 ); + } + + @Test + public void testFunInExpression4() { + runFunInExpressionTest( TEST_NAME4 ); + } + + private void runFunInExpressionTest( String testName ) + { + TestConfiguration config = getTestConfiguration(testName); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testName + ".dml"; + 
programArgs = new String[]{"-explain", "-stats", "-args", output("R") }; + + fullRScriptName = HOME + testName + ".R"; + rCmd = getRCmd(expectedDir()); + + //run script and compare output + runTest(true, false, null, -1); + + //compare results + double val = readDMLMatrixFromHDFS("R").get(new CellIndex(1,1)); + Assert.assertTrue("Wrong result: 7 vs "+val, Math.abs(val-7)<Math.pow(10, -14)); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/test/scripts/functions/misc/FunInExpression1.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/FunInExpression1.dml b/src/test/scripts/functions/misc/FunInExpression1.dml new file mode 100644 index 0000000..dff7113 --- /dev/null +++ b/src/test/scripts/functions/misc/FunInExpression1.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +foo = function(Matrix[Double] A) return(Matrix[Double] B) { + B = A + A; #inlined +} + +A = matrix(0.07, 10, 10); +R = as.matrix(sum(foo(A)/2)); + +write(R, $1); http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/test/scripts/functions/misc/FunInExpression2.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/FunInExpression2.dml b/src/test/scripts/functions/misc/FunInExpression2.dml new file mode 100644 index 0000000..e9fabeb --- /dev/null +++ b/src/test/scripts/functions/misc/FunInExpression2.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +foo = function(Matrix[Double] A) return(Matrix[Double] B) { + if( sum(A) != 0 ) + B = A + A; + else + B = A + 2; +} + +A = matrix(0.07, 10, 10); +R = as.matrix(sum(foo(A)/2)); + +write(R, $1); http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/test/scripts/functions/misc/FunInExpression3.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/FunInExpression3.dml b/src/test/scripts/functions/misc/FunInExpression3.dml new file mode 100644 index 0000000..7b30d90 --- /dev/null +++ b/src/test/scripts/functions/misc/FunInExpression3.dml @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +foo1 = function(Matrix[Double] A) return(Matrix[Double] B) { + B = A + A; #inlined +} + +foo2 = function(Matrix[Double] A) return(Matrix[Double] B) { + if( sum(A) != 0 ) + B = A + A; + else + B = A + 2; +} + +A = matrix(0.07, 10, 10); +R = as.matrix(sum((1 + foo1(1.5*A) - foo2(A/2) - 1) / 2)); + +write(R, $1); http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/test/scripts/functions/misc/FunInExpression4.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/FunInExpression4.dml b/src/test/scripts/functions/misc/FunInExpression4.dml new file mode 100644 index 0000000..25d85b8 --- /dev/null +++ b/src/test/scripts/functions/misc/FunInExpression4.dml @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +foo = function(Matrix[Double] A) return(Matrix[Double] B) { + if( sum(A) != 0 ) + B = A + A; + else + B = A + 2; +} + +A = matrix(0.07, 10, 10); +R0 = matrix(0, 11, 11); +R0[1:10,1:10] = (1 + foo(1.5*A) - foo(A/2) - 1) / 2; +while(FALSE){} + +R = as.matrix(sum(R0)); + +write(R, $1); http://git-wip-us.apache.org/repos/asf/systemml/blob/71f8c836/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java index 80db7c7..1849b51 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java @@ -29,6 +29,7 @@ import org.junit.runners.Suite; ConditionalValidateTest.class, DataTypeCastingTest.class, DataTypeChangeTest.class, + FunctionInExpressionTest.class, FunctionInliningTest.class, FunctionNamespaceTest.class, FunctionReturnTest.class,
