This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new ff4f23d [SYSTEMDS-2823] Rework namespace handling of built-in
functions
ff4f23d is described below
commit ff4f23d7854e2f204674d0fe750b6f85c00bd9e2
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Apr 2 21:42:58 2021 +0200
[SYSTEMDS-2823] Rework namespace handling of built-in functions
Motivated by the bug report SYSTEMDS-2823, this patch makes a major
rework of the namespace handling of builtin functions, which fixes many
subtle issues and provides very clean semantics:
All dml-bodied builtin functions, no matter if called from the default
namespace or from a named namespace, are placed in and invoked from a new
dedicated namespace '.builtinNS'. The builtin functions are only
parsed and added if the called function does not exist in the invoking
default/named namespace - so, user functions always take precedence.
Furthermore, this patch fixes issues with private functions in builtin
scripts that are called from named namespaces, allows for
imports/sourcing in builtin functions, avoids name conflicts of user
functions that call both custom functions and builtin functions of the
same name (potentially in a deep hierarchy), and avoids redundant
function parsing and compilation.
---
.../java/org/apache/sysds/common/Builtins.java | 5 ++-
.../java/org/apache/sysds/parser/DMLProgram.java | 24 ++++++++++-
.../sysds/parser/FunctionCallIdentifier.java | 13 +++---
.../org/apache/sysds/parser/StatementBlock.java | 20 +++++----
.../apache/sysds/parser/dml/DMLParserWrapper.java | 19 ++++++---
.../sysds/parser/dml/DmlSyntacticValidator.java | 49 ++++++++--------------
.../instructions/cp/EvalNaryCPInstruction.java | 24 ++++++-----
.../functions/builtin/BuiltinTomeklinkTest.java | 14 +++----
.../functions/builtin/BuiltinWinsorizeTest.java | 28 +++++++++----
.../scripts/functions/builtin/winsorizeFoo.dml | 29 +++++++++++++
.../scripts/functions/builtin/winsorizeMain.dml | 26 ++++++++++++
11 files changed, 172 insertions(+), 79 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index 6610661..34b5cc5 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -374,6 +374,7 @@ public enum Builtins {
}
public static String getInternalFName(String name, DataType dt) {
- return (dt.isMatrix() ? "m_" : "s_") + name;
+ return !contains(name, true, false) ? name : // private builtin
+ (dt.isMatrix() ? "m_" : "s_") + name; // public
builtin
}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysds/parser/DMLProgram.java
b/src/main/java/org/apache/sysds/parser/DMLProgram.java
index d1720df..ea9f306 100644
--- a/src/main/java/org/apache/sysds/parser/DMLProgram.java
+++ b/src/main/java/org/apache/sysds/parser/DMLProgram.java
@@ -30,6 +30,7 @@ import org.apache.sysds.runtime.controlprogram.Program;
public class DMLProgram
{
public static final String DEFAULT_NAMESPACE = ".defaultNS";
+ public static final String BUILTIN_NAMESPACE = ".builtinNS";
public static final String INTERNAL_NAMESPACE = "_internal"; // used
for multi-return builtin functions
private ArrayList<StatementBlock> _blocks;
@@ -42,7 +43,7 @@ public class DMLProgram
public DMLProgram(String namespace) {
this();
- _namespaces.put(namespace, new FunctionDictionary<>());
+ createNamespace(namespace);
}
public Map<String,FunctionDictionary<FunctionStatementBlock>>
getNamespaces(){
@@ -56,6 +57,19 @@ public class DMLProgram
public int getNumStatementBlocks(){
return _blocks.size();
}
+
+ public static boolean isInternalNamespace(String namespace) {
+ return DEFAULT_NAMESPACE.equals(namespace)
+ || BUILTIN_NAMESPACE.equals(namespace)
+ || INTERNAL_NAMESPACE.equals(namespace);
+ }
+
+ public FunctionDictionary<FunctionStatementBlock>
createNamespace(String namespace) {
+ // create on demand, necessary to avoid overwriting existing
functions
+ if( !_namespaces.containsKey(namespace) )
+ _namespaces.put(namespace, new FunctionDictionary<>());
+ return _namespaces.get(namespace);
+ }
/**
*
@@ -122,6 +136,14 @@ public class DMLProgram
return _namespaces.get(DEFAULT_NAMESPACE);
}
+ public FunctionDictionary<FunctionStatementBlock>
getBuiltinFunctionDictionary() {
+ return _namespaces.get(BUILTIN_NAMESPACE);
+ }
+
+ public FunctionDictionary<FunctionStatementBlock>
getFunctionDictionary(String namespace) {
+ return _namespaces.get(namespace);
+ }
+
public void addFunctionStatementBlock(String fname,
FunctionStatementBlock fsb) {
addFunctionStatementBlock(DEFAULT_NAMESPACE, fname, fsb);
}
diff --git a/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java
b/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java
index 533c3e0..173a17e 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionCallIdentifier.java
@@ -97,7 +97,11 @@ public class FunctionCallIdentifier extends DataIdentifier
raiseValidateError("namespace " + _namespace + " is not
defined ", conditional);
}
FunctionStatementBlock fblock =
dmlp.getFunctionStatementBlock(_namespace, _name);
- if (fblock == null && !Builtins.contains(_name, true, false) ){
+ if( fblock == null ) { //handle private builtin function
+ fblock =
dmlp.getFunctionStatementBlock(DMLProgram.BUILTIN_NAMESPACE, _name);
+ _namespace = DMLProgram.BUILTIN_NAMESPACE;
+ }
+ if (fblock == null && !Builtins.contains(_name, true, false)){
raiseValidateError("function " + _name + " is undefined
in namespace " + _namespace, conditional);
}
@@ -128,11 +132,10 @@ public class FunctionCallIdentifier extends DataIdentifier
}
// Step 5: replace dml-bodied builtin function calls after type
inference
- if( Builtins.contains(_name, true, false)
- && _namespace.equals(DMLProgram.DEFAULT_NAMESPACE) ) {
+ if( Builtins.contains(_name, true, false) && fblock == null ) {
DataType dt =
_paramExprs.get(0).getExpr().getOutput().getDataType();
_name = Builtins.getInternalFName(_name, dt);
- _namespace = DMLProgram.DEFAULT_NAMESPACE;
+ _namespace = DMLProgram.BUILTIN_NAMESPACE;
fblock = dmlp.getFunctionStatementBlock(_namespace,
_name);
if( fblock == null ) {
raiseValidateError("Builtin function '"+_name+
"': script loaded "
@@ -212,5 +215,3 @@ public class FunctionCallIdentifier extends DataIdentifier
return true;
}
}
-
-
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 5d19575..c4876d4 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -250,7 +250,8 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
FunctionCallIdentifier fcall =
(FunctionCallIdentifier) sourceExpr;
FunctionStatementBlock fblock =
dmlProg.getFunctionStatementBlock(fcall.getNamespace(),fcall.getName());
if (fblock == null) {
- if( Builtins.contains(fcall.getName(),
true, false) )
+ if( Builtins.contains(fcall.getName(),
true, false)
+ ||
DMLProgram.isInternalNamespace(fcall.getNamespace()))
return false;
throw new
LanguageException(sourceExpr.printErrorLocation() + "function "
+ fcall.getName() + " is
undefined in namespace " + fcall.getNamespace());
@@ -604,16 +605,17 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
di.setValueType(fexpr.getValueType());
tmp.add(new AssignmentStatement(di, fexpr, di));
//add hoisted dml-bodied builtin function to
program (if not already loaded)
- if( Builtins.contains(fexpr.getName(), true,
false)
- &&
!prog.getDefaultFunctionDictionary().containsFunction(
-
Builtins.getInternalFName(fexpr.getName(), DataType.SCALAR))
- &&
!prog.getDefaultFunctionDictionary().containsFunction(
-
Builtins.getInternalFName(fexpr.getName(), DataType.MATRIX))) {
+ FunctionDictionary<FunctionStatementBlock>
fdict = prog.getBuiltinFunctionDictionary();
+ if( Builtins.contains(fexpr.getName(), true,
false) && (fdict == null ||
+
(!fdict.containsFunction(Builtins.getInternalFName(fexpr.getName(),
DataType.SCALAR))
+ &&
!fdict.containsFunction(Builtins.getInternalFName(fexpr.getName(),
DataType.MATRIX)))) )
+ {
+ fdict =
prog.createNamespace(DMLProgram.BUILTIN_NAMESPACE);
Map<String,FunctionStatementBlock> fsbs
= DmlSyntacticValidator
-
.loadAndParseBuiltinFunction(fexpr.getName(), fexpr.getNamespace());
+
.loadAndParseBuiltinFunction(fexpr.getName(), DMLProgram.BUILTIN_NAMESPACE);
for(
Entry<String,FunctionStatementBlock> fsb : fsbs.entrySet() ) {
- if(
!prog.getDefaultFunctionDictionary().containsFunction(fsb.getKey()) )
-
prog.getDefaultFunctionDictionary().addFunction(fsb.getKey(), fsb.getValue());
+ if(
!fdict.containsFunction(fsb.getKey()) )
+
fdict.addFunction(fsb.getKey(), fsb.getValue());
fsb.getValue().setDMLProg(prog);
}
}
diff --git a/src/main/java/org/apache/sysds/parser/dml/DMLParserWrapper.java
b/src/main/java/org/apache/sysds/parser/dml/DMLParserWrapper.java
index 551569c..4b5e282 100644
--- a/src/main/java/org/apache/sysds/parser/dml/DMLParserWrapper.java
+++ b/src/main/java/org/apache/sysds/parser/dml/DMLParserWrapper.java
@@ -24,6 +24,7 @@ import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
+import java.util.Map.Entry;
import org.antlr.v4.runtime.ANTLRInputStream;
import org.antlr.v4.runtime.BailErrorStrategy;
@@ -187,19 +188,20 @@ public class DMLParserWrapper extends ParserWrapper
if (atLeastOneWarning) {
LOG.warn(CustomErrorListener.generateParseIssuesMessage(dmlScript,
parseIssues));
}
- dmlPgm = createDMLProgram(ast, sourceNamespace);
+ dmlPgm = createDMLProgram(ast, validator, sourceNamespace);
return dmlPgm;
}
- private static DMLProgram createDMLProgram(ProgramrootContext ast,
String sourceNamespace)
+ private static DMLProgram createDMLProgram(ProgramrootContext ast,
+ DmlSyntacticValidator validator, String sourceNamespace)
{
DMLProgram dmlPgm = new DMLProgram();
- String namespace = (sourceNamespace != null &&
sourceNamespace.length() > 0)
- ? sourceNamespace : DMLProgram.DEFAULT_NAMESPACE;
+ String namespace = (sourceNamespace != null &&
sourceNamespace.length() > 0) ?
+ sourceNamespace : DMLProgram.DEFAULT_NAMESPACE;
dmlPgm.getNamespaces().put(namespace, new
FunctionDictionary<>());
- // add all functions from the main script file
+ // add all functions from the parsed script file
for(FunctionStatementContext fn : ast.functionBlocks) {
FunctionStatementBlock functionStmtBlk = new
FunctionStatementBlock();
functionStmtBlk.addStatement(fn.info.stmt);
@@ -211,6 +213,13 @@ public class DMLParserWrapper extends ParserWrapper
return null;
}
}
+
+ // add all builtin functions collected while parsing script file
+ FunctionDictionary<FunctionStatementBlock> fbuiltins =
validator.getParsedBuiltinFunctions();
+ if( !fbuiltins.getFunctions().isEmpty() )
+ dmlPgm.createNamespace(DMLProgram.BUILTIN_NAMESPACE);
+ for( Entry<String, FunctionStatementBlock> e :
fbuiltins.getFunctions().entrySet() )
+
dmlPgm.addFunctionStatementBlock(DMLProgram.BUILTIN_NAMESPACE, e.getKey(),
e.getValue());
// add statements from main script file, as well as
// functions from imports and dml-bodied builtin functions
diff --git
a/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
b/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
index 0f33803..2c4e8a3 100644
--- a/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
+++ b/src/main/java/org/apache/sysds/parser/dml/DmlSyntacticValidator.java
@@ -187,6 +187,10 @@ public class DmlSyntacticValidator implements DmlListener {
return "FALSE";
}
+ public FunctionDictionary<FunctionStatementBlock>
getParsedBuiltinFunctions() {
+ return builtinFuns;
+ }
+
protected ArrayList<ParameterExpression>
getParameterExpressionList(List<ParameterizedExpressionContext> paramExprs) {
ArrayList<ParameterExpression> retVal = new ArrayList<>();
for(ParameterizedExpressionContext ctx : paramExprs) {
@@ -546,7 +550,6 @@ public class DmlSyntacticValidator implements DmlListener {
ctx.info.expr = createFunctionCall(ctx, namespace,
functionName, paramExpression);
}
-
@Override
public void exitFunctionCallMultiAssignmentStatement(
FunctionCallMultiAssignmentStatementContext ctx) {
@@ -590,28 +593,31 @@ public class DmlSyntacticValidator implements DmlListener
{
setMultiAssignmentStatement(targetList, e, ctx,
ctx.info);
return;
}
- handleDMLBodiedBuiltinFunction(functionName, namespace,
ctx);
+ handleDMLBodiedBuiltinFunction(functionName,
DMLProgram.BUILTIN_NAMESPACE, ctx);
}
// Override default namespace for imported non-built-in function
- String inferNamespace = (sourceNamespace != null &&
sourceNamespace.length() > 0 && DMLProgram.DEFAULT_NAMESPACE.equals(namespace))
? sourceNamespace : namespace;
+ String inferNamespace = (sourceNamespace != null &&
sourceNamespace.length() > 0
+ && DMLProgram.DEFAULT_NAMESPACE.equals(namespace)) ?
sourceNamespace : namespace;
functCall.setFunctionNamespace(inferNamespace);
setMultiAssignmentStatement(targetList, functCall, ctx,
ctx.info);
}
-
+
private void handleDMLBodiedBuiltinFunction(String functionName, String
namespace, ParserRuleContext ctx) {
- if( Builtins.contains(functionName, true, false) ) {
+ if( Builtins.contains(functionName, true, false)
+ && !builtinFuns.containsFunction(functionName) )
+ {
//load and add builtin DML-bodied functions
String filePath = Builtins.getFilePath(functionName);
FunctionDictionary<FunctionStatementBlock> prog =
- parseAndAddImportedFunctions(namespace,
filePath, ctx).getDefaultFunctionDictionary();
+ parseAndAddImportedFunctions(namespace,
filePath, ctx).getBuiltinFunctionDictionary();
if( prog != null ) //robustness for existing functions
for( Entry<String,FunctionStatementBlock> f :
prog.getFunctions().entrySet() )
builtinFuns.addFunction(f.getKey(),
f.getValue());
}
}
-
+
public static Map<String,FunctionStatementBlock>
loadAndParseBuiltinFunction(String name, String namespace) {
if( !Builtins.contains(name, true, false) ) {
throw new DMLRuntimeException("Function "
@@ -624,15 +630,14 @@ public class DmlSyntacticValidator implements DmlListener
{
String filePath = Builtins.getFilePath(name);
FunctionDictionary<FunctionStatementBlock> dict = tmp
.parseAndAddImportedFunctions(namespace, filePath, null)
- .getDefaultFunctionDictionary();
+ .getBuiltinFunctionDictionary();
//construct output map of all functions
return dict.getFunctions();
}
-
// -----------------------------------------------------------------
- // Control Statements - Guards & Loops
+ // Control Statements - Guards & Loops
// -----------------------------------------------------------------
private static StatementBlock getStatementBlock(Statement current) {
@@ -999,25 +1004,7 @@ public class DmlSyntacticValidator implements DmlListener
{
@Override public void enterProgramroot(ProgramrootContext ctx) {}
@Override
- public void exitProgramroot(ProgramrootContext ctx) {
- //take over dml-bodied builtin functions into list of script
functions
- for( Entry<String,FunctionStatementBlock> e :
builtinFuns.getFunctions().entrySet() ) {
- FunctionStatementContext fn = new
FunctionStatementContext();
- fn.info = new StatementInfo();
- fn.info.stmt = e.getValue().getStatement(0);
- fn.info.functionName = e.getKey();
- //existing user-function overrides builtin function
- if( !containsFunction(ctx, e.getKey()) )
- ctx.functionBlocks.add(fn);
- }
- }
-
- private static boolean containsFunction(ProgramrootContext ctx, String
fname) {
- for( FunctionStatementContext fn : ctx.functionBlocks )
- if( fn.info.functionName.equals(fname) )
- return true;
- return false;
- }
+ public void exitProgramroot(ProgramrootContext ctx) {}
@Override public void enterDataIdExpression(DataIdExpressionContext
ctx) {}
@@ -1139,7 +1126,7 @@ public class DmlSyntacticValidator implements DmlListener
{
protected void validateNamespace(String namespace, String filePath,
ParserRuleContext ctx) {
// error out if different scripts from different file paths are
bound to the same namespace
- if( !DMLProgram.DEFAULT_NAMESPACE.equals(namespace) ) {
+ if( !DMLProgram.isInternalNamespace(namespace) ) {
if( sources.containsKey(namespace) &&
!sources.get(namespace).equals(filePath) )
notifyErrorListeners("Namespace Conflict: '" +
namespace
+ "' already defined as " +
sources.get(namespace), ctx.start);
@@ -1647,7 +1634,7 @@ public class DmlSyntacticValidator implements DmlListener
{
setAssignmentStatement(ctx, info, target, e);
return;
}
- handleDMLBodiedBuiltinFunction(functionName, namespace,
ctx);
+ handleDMLBodiedBuiltinFunction(functionName,
DMLProgram.BUILTIN_NAMESPACE, ctx);
}
// handle user-defined functions
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index eb75f71..9781d88 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -60,6 +60,7 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
public void processInstruction(ExecutionContext ec) {
//1. get the namespace and func
String funcName = ec.getScalarInput(inputs[0]).getStringValue();
+ String nsName = null; //default namespace
if( funcName.contains(Program.KEY_DELIM) )
throw new DMLRuntimeException("Eval calls to
'"+funcName+"', i.e., a function outside "
+ "the default "+ "namespace, are not supported
yet. Please call the function directly.");
@@ -77,14 +78,15 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
DataType dt1 = boundInputs[0].getDataType().isList() ?
DataType.MATRIX : boundInputs[0].getDataType();
String funcName2 = Builtins.getInternalFName(funcName, dt1);
- if( !ec.getProgram().containsFunctionProgramBlock(null,
funcName)) {
- if(
!ec.getProgram().containsFunctionProgramBlock(null,funcName2) )
+ if( !ec.getProgram().containsFunctionProgramBlock(nsName,
funcName)) {
+ nsName = DMLProgram.BUILTIN_NAMESPACE;
+ if(
!ec.getProgram().containsFunctionProgramBlock(nsName, funcName2) )
compileFunctionProgramBlock(funcName, dt1,
ec.getProgram());
funcName = funcName2;
}
//obtain function block (but unoptimized version of existing
functions for correctness)
- FunctionProgramBlock fpb =
ec.getProgram().getFunctionProgramBlock(null, funcName, false);
+ FunctionProgramBlock fpb =
ec.getProgram().getFunctionProgramBlock(nsName, funcName, false);
//4. expand list arguments if needed
CPOperand[] boundInputs2 = null;
@@ -105,8 +107,8 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
boundInputs = boundInputs2;
}
- //5. call the function
- FunctionCallCPInstruction fcpi = new
FunctionCallCPInstruction(null, funcName,
+ //5. call the function (to unoptimized function)
+ FunctionCallCPInstruction fcpi = new
FunctionCallCPInstruction(nsName, funcName,
false, boundInputs, fpb.getInputParamNames(),
boundOutputNames, "eval func");
fcpi.processInstruction(ec);
@@ -136,8 +138,9 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
private static void compileFunctionProgramBlock(String name, DataType
dt, Program prog) {
//load builtin file and parse function statement block
+ String nsName = DMLProgram.BUILTIN_NAMESPACE;
Map<String,FunctionStatementBlock> fsbs = DmlSyntacticValidator
- .loadAndParseBuiltinFunction(name,
DMLProgram.DEFAULT_NAMESPACE);
+ .loadAndParseBuiltinFunction(name, nsName);
if( fsbs.isEmpty() )
throw new DMLRuntimeException("Failed to compile
function '"+name+"'.");
@@ -147,8 +150,9 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
DMLProgram dmlp = (prog.getDMLProg() != null) ?
prog.getDMLProg() :
fsbs.get(Builtins.getInternalFName(name,
dt)).getDMLProg();
for( Entry<String,FunctionStatementBlock> fsb : fsbs.entrySet()
) {
- if(
!dmlp.getDefaultFunctionDictionary().containsFunction(fsb.getKey()) ) {
- dmlp.addFunctionStatementBlock(fsb.getKey(),
fsb.getValue());
+ dmlp.createNamespace(nsName); // create namespace on
demand
+ if(
!dmlp.getBuiltinFunctionDictionary().containsFunction(fsb.getKey()) ) {
+ dmlp.addFunctionStatementBlock(nsName,
fsb.getKey(), fsb.getValue());
}
fsb.getValue().setDMLProg(dmlp);
}
@@ -183,8 +187,8 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
if( !prog.containsFunctionProgramBlock(null,
fsb.getKey(), false) ) {
FunctionProgramBlock fpb =
(FunctionProgramBlock) dmlt
.createRuntimeProgramBlock(prog,
fsb.getValue(), ConfigurationManager.getDMLConfig());
- prog.addFunctionProgramBlock(null,
fsb.getKey(), fpb, true); // optimized
- prog.addFunctionProgramBlock(null,
fsb.getKey(), fpb, false); // unoptimized -> eval
+ prog.addFunctionProgramBlock(nsName,
fsb.getKey(), fpb, true); // optimized
+ prog.addFunctionProgramBlock(nsName,
fsb.getKey(), fpb, false); // unoptimized -> eval
}
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinTomeklinkTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinTomeklinkTest.java
index a251242..761b2d1 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinTomeklinkTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinTomeklinkTest.java
@@ -37,7 +37,7 @@ public class BuiltinTomeklinkTest extends AutomatedTestBase
private final static double eps = 1e-3;
private final static int rows = 53;
- private final static int cols = 6;
+ private final static int cols = 6;
@Override
public void setUp() {
@@ -49,7 +49,7 @@ public class BuiltinTomeklinkTest extends AutomatedTestBase
runTomeklinkTest(ExecType.CP);
}
- @Test
+ @Test
public void testTomeklinkSP() {
runTomeklinkTest(ExecType.SPARK);
}
@@ -66,16 +66,16 @@ public class BuiltinTomeklinkTest extends AutomatedTestBase
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-args", input("A"),
input("B"), output("C")};
- fullRScriptName = HOME + TEST_NAME + ".R";
+ fullRScriptName = HOME + TEST_NAME + ".R";
rCmd = "Rscript" + " " + fullRScriptName + " " +
inputDir() + " " + expectedDir();
//generate actual dataset
- double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.7, 1);
+ double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.7,
1);
writeInputMatrixWithMTD("A", A, true);
- double[][] B = getRandomMatrix(rows, 1, 0, 1, 0.5, 1);
- B = TestUtils.round(B);
- writeInputMatrixWithMTD("B", B, true);
+ double[][] B = getRandomMatrix(rows, 1, 0, 1, 0.5, 1);
+ B = TestUtils.round(B);
+ writeInputMatrixWithMTD("B", B, true);
runTest(true, false, null, -1);
runRScript(true);
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinWinsorizeTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinWinsorizeTest.java
index cfe4392..7638d00 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinWinsorizeTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinWinsorizeTest.java
@@ -31,7 +31,8 @@ import org.apache.sysds.test.TestUtils;
public class BuiltinWinsorizeTest extends AutomatedTestBase
{
- private final static String TEST_NAME = "winsorize";
+ private final static String TEST_NAME1 = "winsorize";
+ private final static String TEST_NAME2 = "winsorizeMain";
private final static String TEST_DIR = "functions/builtin/";
private static final String TEST_CLASS_DIR = TEST_DIR +
BuiltinWinsorizeTest.class.getSimpleName() + "/";
@@ -41,31 +42,42 @@ public class BuiltinWinsorizeTest extends AutomatedTestBase
@Override
public void setUp() {
- addTestConfiguration(TEST_NAME,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"}));
+ addTestConfiguration(TEST_NAME1,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1,new String[]{"B"}));
+ addTestConfiguration(TEST_NAME2,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2,new String[]{"B"}));
}
@Test
public void testWinsorizeDefaultCP() {
- runWinsorizeTest(true, ExecType.CP);
+ runWinsorizeTest(TEST_NAME1, true, ExecType.CP);
}
@Test
public void testWinsorizeDefaultSP() {
- runWinsorizeTest(true, ExecType.SPARK);
+ runWinsorizeTest(TEST_NAME1, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testWinsorizeSourcedFooCP() {
+ runWinsorizeTest(TEST_NAME2, true, ExecType.CP);
+ }
+
+ @Test
+ public void testWinsorizeSourcedFooSP() {
+ runWinsorizeTest(TEST_NAME2, true, ExecType.SPARK);
}
- private void runWinsorizeTest(boolean defaultProb, ExecType instType)
+ private void runWinsorizeTest(String testname, boolean defaultProb,
ExecType instType)
{
ExecMode platformOld = setExecMode(instType);
try
{
- loadTestConfiguration(getTestConfiguration(TEST_NAME));
+ loadTestConfiguration(getTestConfiguration(testname));
String HOME = SCRIPT_DIR + TEST_DIR;
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[]{"-args", input("A"),
output("B") };
- fullRScriptName = HOME + TEST_NAME + ".R";
+ fullRScriptName = HOME + TEST_NAME1 + ".R";
rCmd = "Rscript" + " " + fullRScriptName + " " +
inputDir() + " " + expectedDir();
//generate actual dataset
diff --git a/src/test/scripts/functions/builtin/winsorizeFoo.dml
b/src/test/scripts/functions/builtin/winsorizeFoo.dml
new file mode 100644
index 0000000..78472bd
--- /dev/null
+++ b/src/test/scripts/functions/builtin/winsorizeFoo.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+foo = function(Matrix[Double] X, Boolean verbose)
+ return(Matrix[Double] R)
+{
+ while(FALSE){} #no inlining
+ if( verbose )
+ print( min(X)+" "+max(X) )
+ R = winsorize(X, verbose);
+}
diff --git a/src/test/scripts/functions/builtin/winsorizeMain.dml
b/src/test/scripts/functions/builtin/winsorizeMain.dml
new file mode 100644
index 0000000..a84d07e
--- /dev/null
+++ b/src/test/scripts/functions/builtin/winsorizeMain.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("./src/test/scripts/functions/builtin/winsorizeFoo.dml") as ns
+
+X = read($1);
+Y = ns::foo(X, FALSE);
+write(Y, $2)