Repository: incubator-systemml Updated Branches: refs/heads/master f4939632c -> 90f56da29
[SYSTEMML-654] Support DML functions overriding built-in functions This patch adds support for overriding built-in functions. The approach was to track new internal and external function definition names per script and skip built-in function call handling if the function name was also defined by the user (or imported when converting PyDML syntax). Also added parse error detection/reporting if the user attempts to define two functions with the same name in the same script. Previously, the second definition would overwrite the first. Closes #147. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/90f56da2 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/90f56da2 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/90f56da2 Branch: refs/heads/master Commit: 90f56da29d02d4a7b59d1c46a1b5b12993177233 Parents: f493963 Author: Glenn Weidner <[email protected]> Authored: Thu May 12 11:34:40 2016 -0700 Committer: Mike Dusenberry <[email protected]> Committed: Thu May 12 11:36:01 2016 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/parser/DataExpression.java | 2 +- .../parser/common/CommonSyntacticValidator.java | 20 +- .../sysml/parser/dml/DMLParserWrapper.java | 6 +- .../sysml/parser/dml/DmlPreprocessor.java | 374 +++++++++++++++++++ .../sysml/parser/dml/DmlSyntacticValidator.java | 5 +- .../sysml/parser/pydml/PyDMLParserWrapper.java | 6 +- .../sysml/parser/pydml/PydmlPreprocessor.java | 374 +++++++++++++++++++ .../parser/pydml/PydmlSyntacticValidator.java | 13 +- .../functions/misc/FunctionNamespaceTest.java | 49 +++ src/test/scripts/functions/misc/Functions11.dml | 66 ++++ .../scripts/functions/misc/Functions11.pydml | 62 +++ src/test/scripts/functions/misc/Functions12.dml | 54 +++ .../scripts/functions/misc/Functions12.pydml | 52 +++ src/test/scripts/functions/misc/Functions13.dml | 39 ++ 
.../scripts/functions/misc/Functions13.pydml | 33 ++ src/test/scripts/functions/misc/FunctionsL1.dml | 55 +++ .../scripts/functions/misc/FunctionsL1.pydml | 43 +++ src/test/scripts/functions/misc/FunctionsL2.dml | 44 +++ .../scripts/functions/misc/FunctionsL2.pydml | 36 ++ 19 files changed, 1315 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/main/java/org/apache/sysml/parser/DataExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DataExpression.java b/src/main/java/org/apache/sysml/parser/DataExpression.java index 2970fa0..be2f569 100644 --- a/src/main/java/org/apache/sysml/parser/DataExpression.java +++ b/src/main/java/org/apache/sysml/parser/DataExpression.java @@ -200,7 +200,7 @@ public class DataExpression extends DataIdentifier String pname = currExpr.getName(); Expression pexpr = currExpr.getExpr(); if (pname == null){ - dataExpr.raiseValidateError("for Rand Statment all arguments must be named parameters"); + dataExpr.raiseValidateError("for Rand Statement all arguments must be named parameters"); } dataExpr.addRandExprParam(pname, pexpr); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java b/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java index 78ba3d6..91ec37e 100644 --- a/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java +++ b/src/main/java/org/apache/sysml/parser/common/CommonSyntacticValidator.java @@ -21,6 +21,7 @@ package org.apache.sysml.parser.common; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; 
import java.util.Map; import java.util.Set; import java.util.regex.Pattern; @@ -64,23 +65,26 @@ public abstract class CommonSyntacticValidator { protected String _workingDir = "."; //current working directory protected Map<String,String> argVals = null; protected String sourceNamespace = null; - // track imported scripts to prevent infinite recursion + // Track imported scripts to prevent infinite recursion protected static ThreadLocal<HashMap<String, String>> _scripts = new ThreadLocal<HashMap<String, String>>() { @Override protected HashMap<String, String> initialValue() { return new HashMap<String, String>(); } }; - // mapping of namespaces to full paths as defined only from source statements in this script (i.e., currentFile) + // Map namespaces to full paths as defined only from source statements in this script (i.e., currentFile) protected HashMap<String, String> sources; + // Names of new internal and external functions defined in this script (i.e., currentFile) + protected Set<String> functions; public static void init() { _scripts.get().clear(); } - public CommonSyntacticValidator(CustomErrorListener errorListener, Map<String,String> argVals, String sourceNamespace) { + public CommonSyntacticValidator(CustomErrorListener errorListener, Map<String,String> argVals, String sourceNamespace, Set<String> prepFunctions) { this.errorListener = errorListener; currentFile = errorListener.getCurrentFileName(); this.argVals = argVals; this.sourceNamespace = sourceNamespace; sources = new HashMap<String, String>(); + functions = (null != prepFunctions) ? 
prepFunctions : new HashSet<String>(); } protected void notifyErrorListeners(String message, int line, int charPositionInLine) { @@ -611,7 +615,11 @@ public abstract class CommonSyntacticValidator { int line = ctx.start.getLine(); int col = ctx.start.getCharPositionInLine(); try { - + if (functions.contains(functionName)) { + // It is a user function definition (which takes precedence if name same as built-in) + return false; + } + Expression lsf = handleLanguageSpecificFunction(ctx, functionName, paramExpressions); if (lsf != null){ setFileLineColumn(lsf, ctx); @@ -662,7 +670,7 @@ public abstract class CommonSyntacticValidator { } // For builtin functions without LHS - if(namespace.equals(DMLProgram.DEFAULT_NAMESPACE)) { + if(namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && !functions.contains(functionName)) { if (printStatements.contains(functionName)){ setPrintStatement(ctx, functionName, paramExpression, info); return; @@ -688,7 +696,7 @@ public abstract class CommonSyntacticValidator { } // For builtin functions with LHS - if(namespace.equals(DMLProgram.DEFAULT_NAMESPACE)){ + if(namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && !functions.contains(functionName)){ final DataIdentifier ftarget = target; Action f = new Action() { @Override public void execute(Expression e) { setAssignmentStatement(ctx, info , ftarget, e); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java index 4109add..25c6c6d 100644 --- a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java +++ b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java @@ -174,7 +174,11 @@ public class DMLParserWrapper extends AParserWrapper ParseTree tree = ast; // And also do syntactic 
validation ParseTreeWalker walker = new ParseTreeWalker(); - DmlSyntacticValidator validator = new DmlSyntacticValidator(errorListener, argVals, sourceNamespace); + // Get list of function definitions which take precedence over built-in functions if same name + DmlPreprocessor prep = new DmlPreprocessor(errorListener); + walker.walk(prep, tree); + // Syntactic validation + DmlSyntacticValidator validator = new DmlSyntacticValidator(errorListener, argVals, sourceNamespace, prep.getFunctionDefs()); walker.walk(validator, tree); errorListener.unsetCurrentFileName(); this.parseIssues = errorListener.getParseIssues(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java new file mode 100644 index 0000000..1fc66f3 --- /dev/null +++ b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java @@ -0,0 +1,374 @@ +package org.apache.sysml.parser.dml; + +import java.util.HashSet; +import java.util.Set; + +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.Token; +import org.antlr.v4.runtime.tree.ErrorNode; +import org.antlr.v4.runtime.tree.TerminalNode; +import org.apache.sysml.parser.common.CustomErrorListener; +import org.apache.sysml.parser.dml.DmlParser.AddSubExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.AssignmentStatementContext; +import org.apache.sysml.parser.dml.DmlParser.AtomicExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.BooleanAndExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.BooleanNotExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.BooleanOrExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.BuiltinFunctionExpressionContext; +import 
org.apache.sysml.parser.dml.DmlParser.CommandlineParamExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.CommandlinePositionExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ConstDoubleIdExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ConstFalseExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ConstIntIdExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ConstStringIdExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ConstTrueExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.DataIdExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ExternalFunctionDefExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ForStatementContext; +import org.apache.sysml.parser.dml.DmlParser.FunctionCallAssignmentStatementContext; +import org.apache.sysml.parser.dml.DmlParser.FunctionCallMultiAssignmentStatementContext; +import org.apache.sysml.parser.dml.DmlParser.IfStatementContext; +import org.apache.sysml.parser.dml.DmlParser.IfdefAssignmentStatementContext; +import org.apache.sysml.parser.dml.DmlParser.ImportStatementContext; +import org.apache.sysml.parser.dml.DmlParser.IndexedExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.InternalFunctionDefExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.IterablePredicateColonExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.IterablePredicateSeqExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.MatrixDataTypeCheckContext; +import org.apache.sysml.parser.dml.DmlParser.MatrixMulExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.Ml_typeContext; +import org.apache.sysml.parser.dml.DmlParser.ModIntDivExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.MultDivExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ParForStatementContext; +import org.apache.sysml.parser.dml.DmlParser.ParameterizedExpressionContext; +import 
org.apache.sysml.parser.dml.DmlParser.PathStatementContext; +import org.apache.sysml.parser.dml.DmlParser.PowerExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ProgramrootContext; +import org.apache.sysml.parser.dml.DmlParser.RelationalExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.SimpleDataIdentifierExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.StrictParameterizedExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.StrictParameterizedKeyValueStringContext; +import org.apache.sysml.parser.dml.DmlParser.TypedArgNoAssignContext; +import org.apache.sysml.parser.dml.DmlParser.UnaryExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ValueTypeContext; +import org.apache.sysml.parser.dml.DmlParser.WhileStatementContext; + +/** + * Minimal pre-processing of user function definitions which take precedence over built-in + * functions in cases where names conflict. This pre-processing takes place outside of + * DmlSyntacticValidator since the function definition can be located after the function + * is used in a statement. 
+ */ +public class DmlPreprocessor implements DmlListener { + + protected final CustomErrorListener errorListener; + // Names of user internal and external functions definitions + protected Set<String> functions; + + public DmlPreprocessor(CustomErrorListener errorListener) { + this.errorListener = errorListener; + functions = new HashSet<String>(); + } + + public Set<String> getFunctionDefs() { + return functions; + } + + @Override + public void enterExternalFunctionDefExpression(ExternalFunctionDefExpressionContext ctx) { + validateFunctionName(ctx.name.getText(), ctx); + } + + @Override + public void exitExternalFunctionDefExpression(ExternalFunctionDefExpressionContext ctx) {} + + @Override + public void enterInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) { + validateFunctionName(ctx.name.getText(), ctx); + } + + @Override + public void exitInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {} + + protected void validateFunctionName(String name, ParserRuleContext ctx) { + if (!functions.contains(name)) { + functions.add(name); + } + else { + notifyErrorListeners("Function Name Conflict: '" + name + "' already defined in " + errorListener.getCurrentFileName(), ctx.start); + } + } + + protected void notifyErrorListeners(String message, Token op) { + errorListener.validationError(op.getLine(), op.getCharPositionInLine(), message); + } + + // ----------------------------------------------------------------- + // Not overridden + // ----------------------------------------------------------------- + + @Override + public void visitTerminal(TerminalNode node) {} + + @Override + public void visitErrorNode(ErrorNode node) {} + + @Override + public void enterEveryRule(ParserRuleContext ctx) {} + + @Override + public void exitEveryRule(ParserRuleContext ctx) {} + + @Override + public void enterFunctionCallMultiAssignmentStatement(FunctionCallMultiAssignmentStatementContext ctx) {} + + @Override + public void 
exitFunctionCallMultiAssignmentStatement(FunctionCallMultiAssignmentStatementContext ctx) {} + + @Override + public void enterMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) {} + + @Override + public void exitMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) {} + + @Override + public void enterStrictParameterizedKeyValueString(StrictParameterizedKeyValueStringContext ctx) {} + + @Override + public void exitStrictParameterizedKeyValueString(StrictParameterizedKeyValueStringContext ctx) {} + + @Override + public void enterPathStatement(PathStatementContext ctx) {} + + @Override + public void exitPathStatement(PathStatementContext ctx) {} + + @Override + public void enterConstTrueExpression(ConstTrueExpressionContext ctx) {} + + @Override + public void exitConstTrueExpression(ConstTrueExpressionContext ctx) {} + + @Override + public void enterTypedArgNoAssign(TypedArgNoAssignContext ctx) {} + + @Override + public void exitTypedArgNoAssign(TypedArgNoAssignContext ctx) {} + + @Override + public void enterWhileStatement(WhileStatementContext ctx) {} + + @Override + public void exitWhileStatement(WhileStatementContext ctx) {} + + @Override + public void enterConstStringIdExpression(ConstStringIdExpressionContext ctx) {} + + @Override + public void exitConstStringIdExpression(ConstStringIdExpressionContext ctx) {} + + @Override + public void enterDataIdExpression(DataIdExpressionContext ctx) {} + + @Override + public void exitDataIdExpression(DataIdExpressionContext ctx) {} + + @Override + public void enterAtomicExpression(AtomicExpressionContext ctx) {} + + @Override + public void exitAtomicExpression(AtomicExpressionContext ctx) {} + + @Override + public void enterPowerExpression(PowerExpressionContext ctx) {} + + @Override + public void exitPowerExpression(PowerExpressionContext ctx) {} + + @Override + public void enterFunctionCallAssignmentStatement(FunctionCallAssignmentStatementContext ctx) {} + + @Override + public void 
exitFunctionCallAssignmentStatement(FunctionCallAssignmentStatementContext ctx) {} + + @Override + public void enterMatrixMulExpression(MatrixMulExpressionContext ctx) {} + + @Override + public void exitMatrixMulExpression(MatrixMulExpressionContext ctx) {} + + @Override + public void enterModIntDivExpression(ModIntDivExpressionContext ctx) {} + + @Override + public void exitModIntDivExpression(ModIntDivExpressionContext ctx) {} + + @Override + public void enterSimpleDataIdentifierExpression(SimpleDataIdentifierExpressionContext ctx) {} + + @Override + public void exitSimpleDataIdentifierExpression(SimpleDataIdentifierExpressionContext ctx) {} + + @Override + public void enterBuiltinFunctionExpression(BuiltinFunctionExpressionContext ctx) {} + + @Override + public void exitBuiltinFunctionExpression(BuiltinFunctionExpressionContext ctx) {} + + @Override + public void enterConstIntIdExpression(ConstIntIdExpressionContext ctx) {} + + @Override + public void exitConstIntIdExpression(ConstIntIdExpressionContext ctx) {} + + @Override + public void enterForStatement(ForStatementContext ctx) {} + + @Override + public void exitForStatement(ForStatementContext ctx) {} + + @Override + public void enterValueType(ValueTypeContext ctx) {} + + @Override + public void exitValueType(ValueTypeContext ctx) {} + + @Override + public void enterParameterizedExpression(ParameterizedExpressionContext ctx) {} + + @Override + public void exitParameterizedExpression(ParameterizedExpressionContext ctx) {} + + @Override + public void enterConstFalseExpression(ConstFalseExpressionContext ctx) {} + + @Override + public void exitConstFalseExpression(ConstFalseExpressionContext ctx) {} + + @Override + public void enterBooleanOrExpression(BooleanOrExpressionContext ctx) {} + + @Override + public void exitBooleanOrExpression(BooleanOrExpressionContext ctx) {} + + @Override + public void enterAssignmentStatement(AssignmentStatementContext ctx) {} + + @Override + public void 
exitAssignmentStatement(AssignmentStatementContext ctx) {} + + @Override + public void enterIterablePredicateColonExpression(IterablePredicateColonExpressionContext ctx) {} + + @Override + public void exitIterablePredicateColonExpression(IterablePredicateColonExpressionContext ctx) {} + + @Override + public void enterParForStatement(ParForStatementContext ctx) {} + + @Override + public void exitParForStatement(ParForStatementContext ctx) {} + + @Override + public void enterStrictParameterizedExpression(StrictParameterizedExpressionContext ctx) {} + + @Override + public void exitStrictParameterizedExpression(StrictParameterizedExpressionContext ctx) {} + + @Override + public void enterCommandlineParamExpression(CommandlineParamExpressionContext ctx) {} + + @Override + public void exitCommandlineParamExpression(CommandlineParamExpressionContext ctx) {} + + @Override + public void enterMultDivExpression(MultDivExpressionContext ctx) {} + + @Override + public void exitMultDivExpression(MultDivExpressionContext ctx) {} + + @Override + public void enterAddSubExpression(AddSubExpressionContext ctx) {} + + @Override + public void exitAddSubExpression(AddSubExpressionContext ctx) {} + + @Override + public void enterImportStatement(ImportStatementContext ctx) {} + + @Override + public void exitImportStatement(ImportStatementContext ctx) {} + + @Override + public void enterProgramroot(ProgramrootContext ctx) {} + + @Override + public void exitProgramroot(ProgramrootContext ctx) {} + + @Override + public void enterIterablePredicateSeqExpression(IterablePredicateSeqExpressionContext ctx) {} + + @Override + public void exitIterablePredicateSeqExpression(IterablePredicateSeqExpressionContext ctx) {} + + @Override + public void enterIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {} + + @Override + public void exitIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {} + + @Override + public void enterBooleanAndExpression(BooleanAndExpressionContext ctx) {} 
+ + @Override + public void exitBooleanAndExpression(BooleanAndExpressionContext ctx) {} + + @Override + public void enterIndexedExpression(IndexedExpressionContext ctx) {} + + @Override + public void exitIndexedExpression(IndexedExpressionContext ctx) {} + + @Override + public void enterBooleanNotExpression(BooleanNotExpressionContext ctx) {} + + @Override + public void exitBooleanNotExpression(BooleanNotExpressionContext ctx) {} + + @Override + public void enterIfStatement(IfStatementContext ctx) {} + + @Override + public void exitIfStatement(IfStatementContext ctx) {} + + @Override + public void enterRelationalExpression(RelationalExpressionContext ctx) {} + + @Override + public void exitRelationalExpression(RelationalExpressionContext ctx) {} + + @Override + public void enterCommandlinePositionExpression(CommandlinePositionExpressionContext ctx) {} + + @Override + public void exitCommandlinePositionExpression(CommandlinePositionExpressionContext ctx) {} + + @Override + public void enterConstDoubleIdExpression(ConstDoubleIdExpressionContext ctx) {} + + @Override + public void exitConstDoubleIdExpression(ConstDoubleIdExpressionContext ctx) {} + + @Override + public void enterUnaryExpression(UnaryExpressionContext ctx) {} + + @Override + public void exitUnaryExpression(UnaryExpressionContext ctx) {} + + @Override + public void enterMl_type(Ml_typeContext ctx) {} + + @Override + public void exitMl_type(Ml_typeContext ctx) {} + +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java index 81538bc..6bbaf10 100644 --- a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java +++ b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java 
@@ -112,8 +112,8 @@ import org.apache.sysml.parser.dml.DmlParser.WhileStatementContext; public class DmlSyntacticValidator extends CommonSyntacticValidator implements DmlListener { - public DmlSyntacticValidator(CustomErrorListener errorListener, Map<String,String> argVals, String sourceNamespace) { - super(errorListener, argVals, sourceNamespace); + public DmlSyntacticValidator(CustomErrorListener errorListener, Map<String,String> argVals, String sourceNamespace, Set<String> prepFunctions) { + super(errorListener, argVals, sourceNamespace, prepFunctions); } @Override public String namespaceResolutionOp() { return "::"; } @@ -789,7 +789,6 @@ public class DmlSyntacticValidator extends CommonSyntacticValidator implements D // set function name functionStmt.setName(ctx.name.getText()); - if(ctx.body.size() > 0) { // handle function body // Create arraylist of one statement block http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java b/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java index ed5ee5b..e98df38 100644 --- a/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java +++ b/src/main/java/org/apache/sysml/parser/pydml/PyDMLParserWrapper.java @@ -161,7 +161,11 @@ public class PyDMLParserWrapper extends AParserWrapper ParseTree tree = ast; // And also do syntactic validation ParseTreeWalker walker = new ParseTreeWalker(); - PydmlSyntacticValidator validator = new PydmlSyntacticValidator(errorListener, argVals, sourceNamespace); + // Get list of function definitions which take precedence over built-in functions if same name + PydmlPreprocessor prep = new PydmlPreprocessor(errorListener); + walker.walk(prep, tree); + // Syntactic validation + PydmlSyntacticValidator validator = new 
PydmlSyntacticValidator(errorListener, argVals, sourceNamespace, prep.getFunctionDefs()); walker.walk(validator, tree); errorListener.unsetCurrentFileName(); this.parseIssues = errorListener.getParseIssues(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/main/java/org/apache/sysml/parser/pydml/PydmlPreprocessor.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/pydml/PydmlPreprocessor.java b/src/main/java/org/apache/sysml/parser/pydml/PydmlPreprocessor.java new file mode 100644 index 0000000..2407b45 --- /dev/null +++ b/src/main/java/org/apache/sysml/parser/pydml/PydmlPreprocessor.java @@ -0,0 +1,374 @@ +package org.apache.sysml.parser.pydml; + +import java.util.HashSet; +import java.util.Set; + +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.Token; +import org.antlr.v4.runtime.tree.ErrorNode; +import org.antlr.v4.runtime.tree.TerminalNode; +import org.apache.sysml.parser.common.CustomErrorListener; +import org.apache.sysml.parser.pydml.PydmlParser.AddSubExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.AssignmentStatementContext; +import org.apache.sysml.parser.pydml.PydmlParser.AtomicExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.BooleanAndExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.BooleanNotExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.BooleanOrExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.BuiltinFunctionExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.CommandlineParamExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.CommandlinePositionExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.ConstDoubleIdExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.ConstFalseExpressionContext; +import 
org.apache.sysml.parser.pydml.PydmlParser.ConstIntIdExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.ConstStringIdExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.ConstTrueExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.DataIdExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.ExternalFunctionDefExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.ForStatementContext; +import org.apache.sysml.parser.pydml.PydmlParser.FunctionCallAssignmentStatementContext; +import org.apache.sysml.parser.pydml.PydmlParser.FunctionCallMultiAssignmentStatementContext; +import org.apache.sysml.parser.pydml.PydmlParser.IfStatementContext; +import org.apache.sysml.parser.pydml.PydmlParser.IfdefAssignmentStatementContext; +import org.apache.sysml.parser.pydml.PydmlParser.IgnoreNewLineContext; +import org.apache.sysml.parser.pydml.PydmlParser.ImportStatementContext; +import org.apache.sysml.parser.pydml.PydmlParser.IndexedExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.InternalFunctionDefExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.IterablePredicateColonExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.IterablePredicateSeqExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.MatrixDataTypeCheckContext; +import org.apache.sysml.parser.pydml.PydmlParser.Ml_typeContext; +import org.apache.sysml.parser.pydml.PydmlParser.ModIntDivExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.MultDivExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.ParForStatementContext; +import org.apache.sysml.parser.pydml.PydmlParser.ParameterizedExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.PathStatementContext; +import org.apache.sysml.parser.pydml.PydmlParser.PowerExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.ProgramrootContext; +import 
org.apache.sysml.parser.pydml.PydmlParser.RelationalExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.SimpleDataIdentifierExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.StrictParameterizedExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.StrictParameterizedKeyValueStringContext; +import org.apache.sysml.parser.pydml.PydmlParser.TypedArgNoAssignContext; +import org.apache.sysml.parser.pydml.PydmlParser.UnaryExpressionContext; +import org.apache.sysml.parser.pydml.PydmlParser.ValueDataTypeCheckContext; +import org.apache.sysml.parser.pydml.PydmlParser.WhileStatementContext; + +/** + * Minimal pre-processing of user function definitions which take precedence over built-in + * functions in cases where names conflict. This pre-processing takes place outside of + * PymlSyntacticValidator since the function definition can be located after the function + * is used in a statement. + */ +public class PydmlPreprocessor implements PydmlListener { + + protected final CustomErrorListener errorListener; + // Names of user internal and external functions definitions + protected Set<String> functions; + + public PydmlPreprocessor(CustomErrorListener errorListener) { + this.errorListener = errorListener; + functions = new HashSet<String>(); + } + + public Set<String> getFunctionDefs() { + return functions; + } + + @Override + public void enterExternalFunctionDefExpression(ExternalFunctionDefExpressionContext ctx) { + validateFunctionName(ctx.name.getText(), ctx); + } + + @Override + public void exitExternalFunctionDefExpression(ExternalFunctionDefExpressionContext ctx) {} + + @Override + public void enterInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) { + validateFunctionName(ctx.name.getText(), ctx); + } + + @Override + public void exitInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {} + + protected void validateFunctionName(String name, ParserRuleContext ctx) { + if 
(!functions.contains(name)) { + functions.add(name); + } + else { + notifyErrorListeners("Function Name Conflict: '" + name + "' already defined in " + errorListener.getCurrentFileName(), ctx.start); + } + } + + protected void notifyErrorListeners(String message, Token op) { + errorListener.validationError(op.getLine(), op.getCharPositionInLine(), message); + } + + // ----------------------------------------------------------------- + // Not overridden + // ----------------------------------------------------------------- + + @Override + public void visitTerminal(TerminalNode node) {} + + @Override + public void visitErrorNode(ErrorNode node) {} + + @Override + public void enterEveryRule(ParserRuleContext ctx) {} + + @Override + public void exitEveryRule(ParserRuleContext ctx) {} + + @Override + public void enterFunctionCallMultiAssignmentStatement(FunctionCallMultiAssignmentStatementContext ctx) {} + + @Override + public void exitFunctionCallMultiAssignmentStatement(FunctionCallMultiAssignmentStatementContext ctx) {} + + @Override + public void enterIgnoreNewLine(IgnoreNewLineContext ctx) {} + + @Override + public void exitIgnoreNewLine(IgnoreNewLineContext ctx) {} + + @Override + public void enterMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) {} + + @Override + public void exitMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) {} + + @Override + public void enterStrictParameterizedKeyValueString(StrictParameterizedKeyValueStringContext ctx) {} + + @Override + public void exitStrictParameterizedKeyValueString(StrictParameterizedKeyValueStringContext ctx) {} + + @Override + public void enterPathStatement(PathStatementContext ctx) {} + + @Override + public void exitPathStatement(PathStatementContext ctx) {} + + @Override + public void enterConstTrueExpression(ConstTrueExpressionContext ctx) {} + + @Override + public void exitConstTrueExpression(ConstTrueExpressionContext ctx) {} + + @Override + public void enterTypedArgNoAssign(TypedArgNoAssignContext ctx) {} 
+ + @Override + public void exitTypedArgNoAssign(TypedArgNoAssignContext ctx) {} + + @Override + public void enterWhileStatement(WhileStatementContext ctx) {} + + @Override + public void exitWhileStatement(WhileStatementContext ctx) {} + + @Override + public void enterConstStringIdExpression(ConstStringIdExpressionContext ctx) {} + + @Override + public void exitConstStringIdExpression(ConstStringIdExpressionContext ctx) {} + + @Override + public void enterDataIdExpression(DataIdExpressionContext ctx) {} + + @Override + public void exitDataIdExpression(DataIdExpressionContext ctx) {} + + @Override + public void enterAtomicExpression(AtomicExpressionContext ctx) {} + + @Override + public void exitAtomicExpression(AtomicExpressionContext ctx) {} + + @Override + public void enterPowerExpression(PowerExpressionContext ctx) {} + + @Override + public void exitPowerExpression(PowerExpressionContext ctx) {} + + @Override + public void enterFunctionCallAssignmentStatement(FunctionCallAssignmentStatementContext ctx) {} + + @Override + public void exitFunctionCallAssignmentStatement(FunctionCallAssignmentStatementContext ctx) {} + + @Override + public void enterModIntDivExpression(ModIntDivExpressionContext ctx) {} + + @Override + public void exitModIntDivExpression(ModIntDivExpressionContext ctx) {} + + @Override + public void enterSimpleDataIdentifierExpression(SimpleDataIdentifierExpressionContext ctx) {} + + @Override + public void exitSimpleDataIdentifierExpression(SimpleDataIdentifierExpressionContext ctx) {} + + @Override + public void enterBuiltinFunctionExpression(BuiltinFunctionExpressionContext ctx) {} + + @Override + public void exitBuiltinFunctionExpression(BuiltinFunctionExpressionContext ctx) {} + + @Override + public void enterConstIntIdExpression(ConstIntIdExpressionContext ctx) {} + + @Override + public void exitConstIntIdExpression(ConstIntIdExpressionContext ctx) {} + + @Override + public void enterForStatement(ForStatementContext ctx) {} + + @Override + 
public void exitForStatement(ForStatementContext ctx) {} + + @Override + public void enterParameterizedExpression(ParameterizedExpressionContext ctx) {} + + @Override + public void exitParameterizedExpression(ParameterizedExpressionContext ctx) {} + + @Override + public void enterConstFalseExpression(ConstFalseExpressionContext ctx) {} + + @Override + public void exitConstFalseExpression(ConstFalseExpressionContext ctx) {} + + @Override + public void enterBooleanOrExpression(BooleanOrExpressionContext ctx) {} + + @Override + public void exitBooleanOrExpression(BooleanOrExpressionContext ctx) {} + + @Override + public void enterAssignmentStatement(AssignmentStatementContext ctx) {} + + @Override + public void exitAssignmentStatement(AssignmentStatementContext ctx) {} + + @Override + public void enterIterablePredicateColonExpression(IterablePredicateColonExpressionContext ctx) {} + + @Override + public void exitIterablePredicateColonExpression(IterablePredicateColonExpressionContext ctx) {} + + @Override + public void enterParForStatement(ParForStatementContext ctx) {} + + @Override + public void exitParForStatement(ParForStatementContext ctx) {} + + @Override + public void enterStrictParameterizedExpression(StrictParameterizedExpressionContext ctx) {} + + @Override + public void exitStrictParameterizedExpression(StrictParameterizedExpressionContext ctx) {} + + @Override + public void enterCommandlineParamExpression(CommandlineParamExpressionContext ctx) {} + + @Override + public void exitCommandlineParamExpression(CommandlineParamExpressionContext ctx) {} + + @Override + public void enterMultDivExpression(MultDivExpressionContext ctx) {} + + @Override + public void exitMultDivExpression(MultDivExpressionContext ctx) {} + + @Override + public void enterAddSubExpression(AddSubExpressionContext ctx) {} + + @Override + public void exitAddSubExpression(AddSubExpressionContext ctx) {} + + @Override + public void enterImportStatement(ImportStatementContext ctx) {} + + 
@Override + public void exitImportStatement(ImportStatementContext ctx) {} + + @Override + public void enterProgramroot(ProgramrootContext ctx) {} + + @Override + public void exitProgramroot(ProgramrootContext ctx) {} + + @Override + public void enterIterablePredicateSeqExpression(IterablePredicateSeqExpressionContext ctx) {} + + @Override + public void exitIterablePredicateSeqExpression(IterablePredicateSeqExpressionContext ctx) {} + + @Override + public void enterIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {} + + @Override + public void exitIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {} + + @Override + public void enterBooleanAndExpression(BooleanAndExpressionContext ctx) {} + + @Override + public void exitBooleanAndExpression(BooleanAndExpressionContext ctx) {} + + @Override + public void enterIndexedExpression(IndexedExpressionContext ctx) {} + + @Override + public void exitIndexedExpression(IndexedExpressionContext ctx) {} + + @Override + public void enterBooleanNotExpression(BooleanNotExpressionContext ctx) {} + + @Override + public void exitBooleanNotExpression(BooleanNotExpressionContext ctx) {} + + @Override + public void enterIfStatement(IfStatementContext ctx) {} + + @Override + public void exitIfStatement(IfStatementContext ctx) {} + + @Override + public void enterRelationalExpression(RelationalExpressionContext ctx) {} + + @Override + public void exitRelationalExpression(RelationalExpressionContext ctx) {} + + @Override + public void enterCommandlinePositionExpression(CommandlinePositionExpressionContext ctx) {} + + @Override + public void exitCommandlinePositionExpression(CommandlinePositionExpressionContext ctx) {} + + @Override + public void enterConstDoubleIdExpression(ConstDoubleIdExpressionContext ctx) {} + + @Override + public void exitConstDoubleIdExpression(ConstDoubleIdExpressionContext ctx) {} + + @Override + public void enterUnaryExpression(UnaryExpressionContext ctx) {} + + @Override + public void 
exitUnaryExpression(UnaryExpressionContext ctx) {} + + @Override + public void enterValueDataTypeCheck(ValueDataTypeCheckContext ctx) {} + + @Override + public void exitValueDataTypeCheck(ValueDataTypeCheckContext ctx) {} + + @Override + public void enterMl_type(Ml_typeContext ctx) {} + + @Override + public void exitMl_type(Ml_typeContext ctx) {} + +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java b/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java index 3e2215c..b070314 100644 --- a/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java +++ b/src/main/java/org/apache/sysml/parser/pydml/PydmlSyntacticValidator.java @@ -122,8 +122,8 @@ import org.apache.sysml.parser.pydml.PydmlParser.WhileStatementContext; */ public class PydmlSyntacticValidator extends CommonSyntacticValidator implements PydmlListener { - public PydmlSyntacticValidator(CustomErrorListener errorListener, Map<String,String> argVals, String sourceNamespace) { - super(errorListener, argVals, sourceNamespace); + public PydmlSyntacticValidator(CustomErrorListener errorListener, Map<String,String> argVals, String sourceNamespace, Set<String> prepFunctions) { + super(errorListener, argVals, sourceNamespace, prepFunctions); } @Override public String namespaceResolutionOp() { return "."; } @@ -600,7 +600,9 @@ public class PydmlSyntacticValidator extends CommonSyntacticValidator implements */ private ConvertedDMLSyntax convertPythonBuiltinFunctionToDMLSyntax(ParserRuleContext ctx, String namespace, String functionName, ArrayList<ParameterExpression> paramExpression, Token fnName) { - + if (sources.containsValue(namespace) || functions.contains(functionName)) { + return new ConvertedDMLSyntax(namespace, functionName, 
paramExpression); + } String fileName = currentFile; int line = ctx.start.getLine(); @@ -1344,8 +1346,7 @@ public class PydmlSyntacticValidator extends CommonSyntacticValidator implements // set function name functionStmt.setName(ctx.name.getText()); - - + if(ctx.body.size() > 0) { // handle function body // Create arraylist of one statement block @@ -1379,7 +1380,7 @@ public class PydmlSyntacticValidator extends CommonSyntacticValidator implements // set function name functionStmt.setName(ctx.name.getText()); - + // set other parameters HashMap<String, String> otherParams = new HashMap<String,String>(); boolean atleastOneClassName = false; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionNamespaceTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionNamespaceTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionNamespaceTest.java index 40cc423..640f973 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionNamespaceTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionNamespaceTest.java @@ -44,6 +44,9 @@ public class FunctionNamespaceTest extends AutomatedTestBase private final static String TEST_NAME8 = "Functions8"; private final static String TEST_NAME9 = "Functions9"; private final static String TEST_NAME10 = "Functions10"; + private final static String TEST_NAME11 = "Functions11"; + private final static String TEST_NAME12 = "Functions12"; + private final static String TEST_NAME13 = "Functions13"; private final static String TEST_DIR = "functions/misc/"; private final static String TEST_CLASS_DIR = TEST_DIR + FunctionNamespaceTest.class.getSimpleName() + "/"; @@ -65,6 +68,9 @@ public class FunctionNamespaceTest extends AutomatedTestBase addTestConfiguration(TEST_NAME8, 
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME8)); addTestConfiguration(TEST_NAME9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME9)); addTestConfiguration(TEST_NAME10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME10)); + addTestConfiguration(TEST_NAME11, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME11)); + addTestConfiguration(TEST_NAME12, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME12)); + addTestConfiguration(TEST_NAME13, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME13)); } @Test @@ -145,6 +151,24 @@ public class FunctionNamespaceTest extends AutomatedTestBase } @Test + public void testFunctionBuiltinOverride() + { + runFunctionNamespaceTest(TEST_NAME11, ScriptType.DML); + } + + @Test + public void testFunctionMultiOverride() + { + runFunctionNamespaceTest(TEST_NAME12, ScriptType.DML); + } + + @Test + public void testFunctionErrorOverride() + { + runFunctionNamespaceTest(TEST_NAME13, ScriptType.DML); + } + + @Test public void testPyFunctionDefaultNS() { runFunctionNamespaceTest(TEST_NAME0, ScriptType.PYDML); @@ -221,6 +245,24 @@ public class FunctionNamespaceTest extends AutomatedTestBase { runFunctionNamespaceTest(TEST_NAME10, ScriptType.PYDML); } + + @Test + public void testPyFunctionBuiltinOverride() + { + runFunctionNamespaceTest(TEST_NAME11, ScriptType.PYDML); + } + + @Test + public void testPyFunctionMultiOverride() + { + runFunctionNamespaceTest(TEST_NAME12, ScriptType.PYDML); + } + + @Test + public void testPyFunctionErrorOverride() + { + runFunctionNamespaceTest(TEST_NAME13, ScriptType.PYDML); + } private void runFunctionNamespaceTest(String TEST_NAME, ScriptType scriptType) { @@ -257,6 +299,13 @@ public class FunctionNamespaceTest extends AutomatedTestBase Assert.fail("Expected parse issue not detected."); } } + else if (TEST_NAME13.equals(TEST_NAME)) + { + if (stdErrString != null && !stdErrString.contains("Function Name Conflict")) + { + Assert.fail("Expected parse issue not detected."); + } + } else { Assert.fail("Unexpected parse error 
or DML script error: " + stdErrString); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/test/scripts/functions/misc/Functions11.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/Functions11.dml b/src/test/scripts/functions/misc/Functions11.dml new file mode 100644 index 0000000..c9938de --- /dev/null +++ b/src/test/scripts/functions/misc/Functions11.dml @@ -0,0 +1,66 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +# Import function definition override of built-in sum +source("./src/test/scripts/functions/misc/FunctionsL1.dml") as Functions + +# Local function definition override of built-in min returning scalar +min = function(integer x) return (integer y) +{ + print("override min") + Z = matrix(x, rows=4, cols=2) + y = x*2 +} + +# Use local min after definition +result = min(1) +print("min is " + result) + +M1 = matrix("1 2 3 4", rows=2, cols=2) +M2 = matrix("5 6 7 8", rows=2, cols=2) + +# Built-in min not directly accessible in this script due to local override +#result = min(M1) + +# Use imported min +result = Functions::min(M1) + +# Use imported function with overrides +[min, max] = Functions::minMax(M2) + +# Use imported sum +result = Functions::sum(M1) + +# Built-in sum accessible since imported override +result = sum(M2) +print("Built-in sum is " + result) + +# Use local override before function definition +Z = rand(2) +print("rand sum is " + sum(Z)) + +# Local function definition override of built-in rand returning matrix +rand = function(int x) return (matrix[double] Z) +{ + print("override rand") + Z = matrix(x, rows=4, cols=2) + y = x*2 +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/test/scripts/functions/misc/Functions11.pydml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/Functions11.pydml b/src/test/scripts/functions/misc/Functions11.pydml new file mode 100644 index 0000000..d3511bd --- /dev/null +++ b/src/test/scripts/functions/misc/Functions11.pydml @@ -0,0 +1,62 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Import function definition override of built-in sum +source("./src/test/scripts/functions/misc/FunctionsL1.pydml") as Functions + +# Local function definition override of built-in min returning scalar +def min(x: int) -> (y: int): + print("override min") + Z = full(x, rows=4, cols=2) + y = x*2 + +# Use local min after definition +result = min(1) +print("min is " + result) + +M1 = full("1 2 3 4", rows=2, cols=2) +M2 = full("5 6 7 8", rows=2, cols=2) + +# Built-in min not directly accessible in this script due to local override +#result = min(M1) + +# Use imported min +result = Functions.min(M1) + +# Use imported function with overrides +[min, max] = Functions.minMax(M2) + +# Use imported sum +result = Functions.sum(M1) + +# Built-in sum accessible since imported override +result = sum(M2) +print("Built-in sum is " + result) + +# Use local override before function definition +Z = rand(2) +print("rand sum is " + sum(Z)) + +# Local function definition override of built-in rand returning matrix +def rand(x: int) -> (Z: matrix[float]): + print("override rand") + Z = full(x, rows=4, cols=2) + y = x*2 http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/test/scripts/functions/misc/Functions12.dml ---------------------------------------------------------------------- diff --git 
a/src/test/scripts/functions/misc/Functions12.dml b/src/test/scripts/functions/misc/Functions12.dml new file mode 100644 index 0000000..d6fe024 --- /dev/null +++ b/src/test/scripts/functions/misc/Functions12.dml @@ -0,0 +1,54 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +# Override specific built-in function definition types +source("./src/test/scripts/functions/misc/FunctionsL2.dml") as Functions + +M1 = matrix("1 2 3 4", rows=2, cols=2) +M2 = matrix("5 6 7 8", rows=2, cols=2) + +# Use imported external function t +now = Functions::t() +print("Time is " + now) + +# Built-in transpose accessible since imported override +result = t(M2) +nothing = Functions::printMatrix(result) + +# Use imported qr multiple return function +[min, max] = Functions::qr(M1) + +# Use built-in qr +[Q, R] = qr(M2) +nothing = Functions::printMatrix(Q) +nothing = Functions::printMatrix(R) + +# Use local override before function definition +[y, Z] = rand(2) +print("rand is " + y + ", " + sum(Z)) + +# Local function definition override of built-in rand returning matrix and scalar +rand = function(int x) return (int y, matrix[double] Z) +{ + print("override rand") + y = x*2 + Z = matrix(y, rows=4, cols=2) +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/test/scripts/functions/misc/Functions12.pydml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/Functions12.pydml b/src/test/scripts/functions/misc/Functions12.pydml new file mode 100644 index 0000000..4b293c6 --- /dev/null +++ b/src/test/scripts/functions/misc/Functions12.pydml @@ -0,0 +1,52 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Override specific built-in function definition types +source("./src/test/scripts/functions/misc/FunctionsL2.pydml") as Functions + +M1 = full("1 2 3 4", rows=2, cols=2) +M2 = full("5 6 7 8", rows=2, cols=2) + +# Use imported external function t +now = Functions.t() +print("Time is " + now) + +# Built-in transpose accessible since imported override +result = t(M2) +nothing = Functions.printMatrix(result) + +# Use imported qr multiple return function +[min, max] = Functions.qr(M1) + +# Use built-in qr +[Q, R] = qr(M2) +nothing = Functions.printMatrix(Q) +nothing = Functions.printMatrix(R) + +# Use local override before function definition +[y, Z] = rand(2) +print("rand is " + y + ", " + sum(Z)) + +# Local function definition override of built-in rand returning matrix and scalar +def rand(x: int) -> (y: int, Z: matrix[float]): + print("override rand") + y = x*2 + Z = full(y, rows=4, cols=2) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/test/scripts/functions/misc/Functions13.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/Functions13.dml b/src/test/scripts/functions/misc/Functions13.dml new file mode 100644 index 0000000..77e0a28 --- /dev/null +++ b/src/test/scripts/functions/misc/Functions13.dml @@ -0,0 +1,39 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +min = function(integer i) +{ + print("min" + i) +} + +write = function(String message) return () +{ + print(message) +} + +nothing = min(1) +nothing = write("goodbye") + +# Report parse issue if attempt to redefine function in same file +min = function(integer i) +{ + print("max" + i) +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/test/scripts/functions/misc/Functions13.pydml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/Functions13.pydml b/src/test/scripts/functions/misc/Functions13.pydml new file mode 100644 index 0000000..0978af2 --- /dev/null +++ b/src/test/scripts/functions/misc/Functions13.pydml @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +def min(i: int): + print("min" + i) + +def save(message: str) -> (): + print(message) + +nothing = min(1) +nothing = save("goodbye") + +# Report parse issue if attempt to redefine function in same file +def min(i: int): + print("max" + i) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/test/scripts/functions/misc/FunctionsL1.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/FunctionsL1.dml b/src/test/scripts/functions/misc/FunctionsL1.dml new file mode 100644 index 0000000..1c02e8b --- /dev/null +++ b/src/test/scripts/functions/misc/FunctionsL1.dml @@ -0,0 +1,55 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +sum = function(matrix[double] X) return (double total) +{ + total = 0 + for (i in 1:nrow(X)) + { + for (j in 1:ncol(X)) + { + total = total + as.scalar(X[i,j]) + } + } + print("Override sum is " + total) +} + +min = function(matrix[double] X) return (double minimum) +{ + MinRow = rowMins(X) + MinCol = colMins(MinRow) + minimum = as.scalar(MinCol[1,1]) + print("Minimum is " + minimum) +} + +minMax = function(matrix[double] M) return (double minVal, double maxVal) +{ + # Access local overrides (defined before or after) instead of built-ins + minVal = min(M) + maxVal = max(M) +} + +max = function(matrix[double] X) return (double maximum) +{ + MaxRow = rowMaxs(X) + MaxCol = colMaxs(MaxRow) + maximum = as.scalar(MaxCol[1,1]) + print("Maximum is " + maximum) +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/test/scripts/functions/misc/FunctionsL1.pydml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/FunctionsL1.pydml b/src/test/scripts/functions/misc/FunctionsL1.pydml new file mode 100644 index 0000000..fa16e7c --- /dev/null +++ b/src/test/scripts/functions/misc/FunctionsL1.pydml @@ -0,0 +1,43 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +def sum(X: matrix[float]) -> (total: float): + total = 0 + for (i in 0:nrow(X)-1): + for (j in 0:ncol(X)-1): + total = total + scalar(X[i,j]) + print("Override sum is " + total) + +def min(X: matrix[float]) -> (minimum: float): + MinRow = rowMins(X) + MinCol = colMins(MinRow) + minimum = scalar(MinCol[0,0]) + print("Minimum is " + minimum) + +def minMax(M: matrix[float]) -> (minVal: float, maxVal: float): + # Access local overrides (defined before or after) instead of built-ins + minVal = min(M) + maxVal = max(M) + +def max(X: matrix[float]) -> (maximum: float): + MaxRow = rowMaxs(X) + MaxCol = colMaxs(MaxRow) + maximum = scalar(MaxCol[0,0]) + print("Maximum is " + maximum) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/test/scripts/functions/misc/FunctionsL2.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/FunctionsL2.dml b/src/test/scripts/functions/misc/FunctionsL2.dml new file mode 100644 index 0000000..27cc8ad --- /dev/null +++ b/src/test/scripts/functions/misc/FunctionsL2.dml @@ -0,0 +1,44 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# External function definition override (t built-in matrix transpose) +t = externalFunction() return (double B) + implemented in (classname="org.apache.sysml.udf.lib.TimeWrapper", exectype="mem") + +# Multiple return function definition override (qr built-in matrix QR decomposition) +qr = function(matrix[double] M) return (double minVal, double maxVal) { + minVal = min(M) + maxVal = max(M) + print("Minimum is " + minVal) + print("Maximum is " + maxVal) +} + +printMatrix = function(matrix[double] X) return () +{ + for (i in 1:nrow(X)) + { + for (j in 1:ncol(X)) + { + xij = as.scalar(X[i,j]) + print("[" + i + "," + j + "] " + xij) + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/90f56da2/src/test/scripts/functions/misc/FunctionsL2.pydml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/FunctionsL2.pydml b/src/test/scripts/functions/misc/FunctionsL2.pydml new file mode 100644 index 0000000..a2f577f --- /dev/null +++ b/src/test/scripts/functions/misc/FunctionsL2.pydml @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# External function definition override (t built-in matrix transpose) +defExternal t() -> (B: float) implemented in (classname="org.apache.sysml.udf.lib.TimeWrapper", exectype="mem") + +# Multiple return function definition override (qr built-in matrix QR decomposition) +def qr(M: matrix[float]) -> (minVal: float, maxVal: float): + minVal = min(M) + maxVal = max(M) + print("Minimum is " + minVal) + print("Maximum is " + maxVal) + +def printMatrix(X: matrix[float]) -> (): + for (i in 0:nrow(X)-1): + for (j in 0:ncol(X)-1): + xij = scalar(X[i,j]) + print("[" + i + "," + j + "] " + xij)
