Repository: systemml Updated Branches: refs/heads/master 1b1c3fea3 -> e270960ca
[SYSTEMML-2084,2317-20] Language and compiler support paramserv builtin Closes #764. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e270960c Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e270960c Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e270960c Branch: refs/heads/master Commit: e270960ca41c7c0197373b53960cae6e7aca84ab Parents: 1b1c3fe Author: EdgarLGB <[email protected]> Authored: Sat May 19 13:56:26 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat May 19 14:35:45 2018 -0700 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/Hop.java | 5 +- .../sysml/hops/ParameterizedBuiltinOp.java | 19 ++- .../sysml/hops/ipa/FunctionCallGraph.java | 39 ++++-- .../hops/ipa/IPAPassRemoveUnusedFunctions.java | 2 +- .../sysml/hops/rewrite/HopRewriteUtils.java | 4 + .../apache/sysml/lops/ParameterizedBuiltin.java | 9 +- .../org/apache/sysml/parser/DMLTranslator.java | 3 +- .../org/apache/sysml/parser/Expression.java | 1 + .../ParameterizedBuiltinFunctionExpression.java | 128 +++++++++++++++---- .../java/org/apache/sysml/parser/Statement.java | 36 +++++- .../parfor/opt/OptimizerRuleBased.java | 4 +- .../functionobjects/ParameterizedBuiltin.java | 6 +- .../instructions/CPInstructionParser.java | 5 +- .../cp/ParameterizedBuiltinCPInstruction.java | 27 ++-- .../cp/ParamservBuiltinCPInstruction.java | 41 ++++++ .../ParameterizedBuiltinSPInstruction.java | 77 +++++------ .../test/integration/AutomatedTestBase.java | 29 ++++- .../functions/paramserv/ParamservFuncTest.java | 100 +++++++++++++++ .../functions/paramserv/paramserv-all-args.dml | 43 +++++++ .../functions/paramserv/paramserv-ipa-test.dml | 47 +++++++ .../functions/paramserv/paramserv-miss-args.dml | 42 ++++++ .../paramserv-without-optional-args.dml | 48 +++++++ .../paramserv/paramserv-wrong-args.dml | 41 ++++++ 
.../paramserv/paramserv-wrong-args2.dml | 42 ++++++ .../paramserv/paramserv-wrong-named-args.dml | 41 ++++++ .../paramserv/paramserv-wrong-type-args.dml | 40 ++++++ .../functions/paramserv/ZPackageSuite.java | 36 ++++++ 27 files changed, 806 insertions(+), 109 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index a42b7ab..7b0ac5b 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -1088,7 +1088,7 @@ public abstract class Hop implements ParseInfo // Operations that require a variable number of operands public enum OpOpN { - PRINTF, CBIND, RBIND, EVAL, LIST, + PRINTF, CBIND, RBIND, EVAL, LIST } public enum AggOp { @@ -1117,7 +1117,7 @@ public abstract class Hop implements ParseInfo INVALID, CDF, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI, TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA, - TOSTRING, LIST, + TOSTRING, LIST, PARAMSERV } public enum FileFormatTypes { @@ -1428,6 +1428,7 @@ public abstract class Hop implements ParseInfo HopsParameterizedBuiltinLops.put(ParamBuiltinOp.TRANSFORMMETA, ParameterizedBuiltin.OperationTypes.TRANSFORMMETA); HopsParameterizedBuiltinLops.put(ParamBuiltinOp.TOSTRING, ParameterizedBuiltin.OperationTypes.TOSTRING); HopsParameterizedBuiltinLops.put(ParamBuiltinOp.LIST, ParameterizedBuiltin.OperationTypes.LIST); + HopsParameterizedBuiltinLops.put(ParamBuiltinOp.PARAMSERV, ParameterizedBuiltin.OperationTypes.PARAMSERV); } protected static final HashMap<Hop.OpOp2, String> HopsOpOp2String; http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java index e287b20..82e9f12 100644 --- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java +++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java @@ -193,6 +193,7 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop case TRANSFORMCOLMAP: case TRANSFORMMETA: case TOSTRING: + case PARAMSERV: case LIST: { ExecType et = optFindExecType(); ParameterizedBuiltin pbilop = new ParameterizedBuiltin(inputlops, @@ -1063,15 +1064,19 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop checkAndSetInvalidCPDimsAndSize(); } - //force CP for in-memory only transform builtins - if( (_op == ParamBuiltinOp.TRANSFORMAPPLY && REMOTE==ExecType.MR) - || _op == ParamBuiltinOp.TRANSFORMDECODE && REMOTE==ExecType.MR - || _op == ParamBuiltinOp.TRANSFORMCOLMAP || _op == ParamBuiltinOp.TRANSFORMMETA - || _op == ParamBuiltinOp.TOSTRING || _op == ParamBuiltinOp.LIST - || _op == ParamBuiltinOp.CDF || _op == ParamBuiltinOp.INVCDF) { + // 1. Force CP for in-memory only transform builtins. + // 2. 
For paramserv function, always be CP mode so that + // the parameter server could have a central instruction + // to determine the local or remote workers + if ((_op == ParamBuiltinOp.TRANSFORMAPPLY && REMOTE == ExecType.MR) + || _op == ParamBuiltinOp.TRANSFORMDECODE && REMOTE == ExecType.MR + || _op == ParamBuiltinOp.TRANSFORMCOLMAP || _op == ParamBuiltinOp.TRANSFORMMETA + || _op == ParamBuiltinOp.TOSTRING || _op == ParamBuiltinOp.LIST + || _op == ParamBuiltinOp.CDF || _op == ParamBuiltinOp.INVCDF + || _op == ParamBuiltinOp.PARAMSERV) { _etype = ExecType.CP; } - + //mark for recompile (forever) setRequiresRecompileIfNecessary(); http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java index 4268784..c4e11db 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java +++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedList; import java.util.List; import java.util.Map.Entry; import java.util.Set; @@ -64,8 +65,10 @@ public class FunctionCallGraph //subset of direct or indirect recursive functions private final HashSet<String> _fRecursive; - - private final boolean _containsEval; + + // a boolean value to indicate if exists the second order function (e.g. 
eval, paramserv) + // and the UDFs that are marked secondorder="true" + private final boolean _containsSecondOrder; /** * Constructs the function call graph for all functions @@ -78,7 +81,7 @@ public class FunctionCallGraph _fCalls = new HashMap<>(); _fCallsSB = new HashMap<>(); _fRecursive = new HashSet<>(); - _containsEval = constructFunctionCallGraph(prog); + _containsSecondOrder = constructFunctionCallGraph(prog); } /** @@ -92,7 +95,7 @@ public class FunctionCallGraph _fCalls = new HashMap<>(); _fCallsSB = new HashMap<>(); _fRecursive = new HashSet<>(); - _containsEval = constructFunctionCallGraph(sb); + _containsSecondOrder = constructFunctionCallGraph(sb); } /** @@ -240,13 +243,13 @@ public class FunctionCallGraph /** * Indicates if the function call graph, i.e., functions that are transitively - * reachable from the main program, contains a second-order eval call, which - * prohibits the removal of unused functions. - * - * @return true if the function call graph contains an eval call. + * reachable from the main program, contains a second-order builtin function call + * (e.g., eval, paramserv), which prohibits the removal of unused functions. + * + * @return true if the function call graph contains a second-order builtin function call. 
*/ - public boolean containsEvalCall() { - return _containsEval; + public boolean containsSecondOrderCall() { + return _containsSecondOrder; } private boolean constructFunctionCallGraph(DMLProgram prog) { @@ -311,6 +314,20 @@ public class FunctionCallGraph ArrayList<Hop> hopsDAG = sb.getHops(); if( hopsDAG == null || hopsDAG.isEmpty() ) return false; //nothing to do + + // BFS traverse the dag to find paramserv operator + // which can occur anyway in the entire dag + LinkedList<Hop> queue = new LinkedList<>(hopsDAG); + while (!queue.isEmpty()) { + Hop h = queue.poll(); + if (h.isVisited()) + continue; + if (HopRewriteUtils.isParameterBuiltinOp(h, Hop.ParamBuiltinOp.PARAMSERV)) + return true; + if (!h.getInput().isEmpty()) + queue.addAll(h.getInput()); + h.setVisited(); + } //function ops can only occur as root nodes of the dag for( Hop h : hopsDAG ) { @@ -366,7 +383,7 @@ public class FunctionCallGraph } } else if( HopRewriteUtils.isData(h, DataOpTypes.TRANSIENTWRITE) - && HopRewriteUtils.isNary(h.getInput().get(0), OpOpN.EVAL) ) { + && HopRewriteUtils.isNary(h.getInput().get(0), OpOpN.EVAL) ) { //NOTE: after RewriteSplitDagDataDependentOperators, eval operators //will always appear as childs to root nodes which allows for an //efficient existence check without DAG traversal. 
http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java index 1420ca6..4feb06e 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java +++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java @@ -40,7 +40,7 @@ public class IPAPassRemoveUnusedFunctions extends IPAPass @Override public boolean isApplicable(FunctionCallGraph fgraph) { return InterProceduralAnalysis.REMOVE_UNUSED_FUNCTIONS - && !fgraph.containsEvalCall(); + && !fgraph.containsSecondOrderCall(); } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 6da1b7a..bfc9d40 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -1018,6 +1018,10 @@ public class HopRewriteUtils public static boolean isSumSq(Hop hop) { return (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp()==AggOp.SUM_SQ); } + + public static boolean isParameterBuiltinOp(Hop hop, ParamBuiltinOp type) { + return hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp) hop).getOp().equals(type); + } public static boolean isNary(Hop hop, OpOpN type) { return hop instanceof NaryOp && ((NaryOp)hop).getOp()==type; http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/lops/ParameterizedBuiltin.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/ParameterizedBuiltin.java b/src/main/java/org/apache/sysml/lops/ParameterizedBuiltin.java index 898d875..c1c97bb 100644 --- a/src/main/java/org/apache/sysml/lops/ParameterizedBuiltin.java +++ b/src/main/java/org/apache/sysml/lops/ParameterizedBuiltin.java @@ -38,7 +38,7 @@ public class ParameterizedBuiltin extends Lop public enum OperationTypes { CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI, TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA, - TOSTRING, LIST + TOSTRING, LIST, PARAMSERV } private OperationTypes _operation; @@ -233,6 +233,13 @@ public class ParameterizedBuiltin extends Lop sb.append(compileGenericParamMap(_inputParams)); break; } + + case PARAMSERV: { + sb.append("paramserv"); + sb.append(OPERAND_DELIMITOR); + sb.append(compileGenericParamMap(_inputParams)); + break; + } default: throw new LopsException(this.printErrorLocation() + "In ParameterizedBuiltin Lop, Unknown operation: " + _operation); http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index cc7a211..ff02012 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -2085,6 +2085,7 @@ public class DMLTranslator case TRANSFORMDECODE: case TRANSFORMCOLMAP: case TRANSFORMMETA: + case PARAMSERV: currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.valueOf(source.getOpCode().name()), paramHops); break; @@ -2111,7 +2112,7 @@ public class DMLTranslator currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), 
target.getValueType(), ParamBuiltinOp.LIST, paramHops); break; - + default: throw new ParseException(source.printErrorLocation() + "processParameterizedBuiltinFunctionExpression() -- Unknown operation: " + source.getOpCode()); http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java index 66f08c5..fd3f855 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -161,6 +161,7 @@ public abstract class Expression implements ParseInfo TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMENCODE, TRANSFORMCOLMAP, TRANSFORMMETA, TOSTRING, // The "toString" method for DML; named arguments accepted to format output LIST, // named argument lists; unnamed lists become builtin function + PARAMSERV, INVALID } http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java index ffc8bc6..3d74f8d 100644 --- a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java @@ -24,6 +24,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; +import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -82,6 +83,8 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier // toString opcodeMap.put("toString", 
Expression.ParameterizedBuiltinFunctionOp.TOSTRING); opcodeMap.put("list", Expression.ParameterizedBuiltinFunctionOp.LIST); + + opcodeMap.put("paramserv", Expression.ParameterizedBuiltinFunctionOp.PARAMSERV); } public static HashMap<Expression.ParameterizedBuiltinFunctionOp, ParamBuiltinOp> pbHopMap; @@ -265,7 +268,11 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier case LIST: validateNamedList(output, conditional); break; - + + case PARAMSERV: + validateParamserv(output, conditional); + break; + default: //always unconditional (because unsupported operation) //handle common issue of transformencode if( getOpCode()==ParameterizedBuiltinFunctionOp.TRANSFORMENCODE ) @@ -307,7 +314,70 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier raiseValidateError("Unsupported parameterized function "+ getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS); } } - + + + private void validateParamserv(DataIdentifier output, boolean conditional) { + String fname = getOpCode().name(); + // validate the first five arguments + if (getVarParams().size() < 1) { + raiseValidateError("Should provide more arguments for function " + fname, false, LanguageErrorCodes.INVALID_PARAMETERS); + } + //check for invalid parameters + Set<String> valid = UtilFunctions.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS, Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS, Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING); + checkInvalidParameters(getOpCode(), getVarParams(), valid); + + // check existence and correctness of parameters + checkDataType(fname, Statement.PS_MODEL, DataType.LIST, conditional); // check the model which is the only non-parameterized argument + checkDataType(fname, 
Statement.PS_FEATURES, DataType.MATRIX, conditional); + checkDataType(fname, Statement.PS_LABELS, DataType.MATRIX, conditional); + checkDataType(fname, Statement.PS_VAL_FEATURES, DataType.MATRIX, conditional); + checkDataType(fname, Statement.PS_VAL_LABELS, DataType.MATRIX, conditional); + checkDataValueType(false, fname, Statement.PS_UPDATE_FUN, DataType.SCALAR, ValueType.STRING, conditional); + checkDataValueType(false, fname, Statement.PS_AGGREGATION_FUN, DataType.SCALAR, ValueType.STRING, conditional); + Set<String> modes = Arrays.stream(Statement.PSModeType.values()).map(Enum::name) + .collect(Collectors.toSet()); + checkStringParam(false, fname, Statement.PS_MODE, modes, conditional); + Set<String> utypes = Arrays.stream(Statement.PSUpdateType.values()).map(Enum::name) + .collect(Collectors.toSet()); + checkStringParam(false, fname, Statement.PS_UPDATE_TYPE, utypes, conditional); + Set<String> frequencies = Arrays.stream(Statement.PSFrequency.values()).map(Enum::name).collect(Collectors.toSet()); + checkStringParam(false, fname, Statement.PS_FREQUENCY, frequencies, conditional); + checkDataValueType(false, fname, Statement.PS_EPOCHS, DataType.SCALAR, ValueType.INT, conditional); + checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT, conditional); + checkDataValueType(false, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT, conditional); + Set<String> schemes = Arrays.stream(Statement.PSScheme.values()).map(Enum::name).collect(Collectors.toSet()); + checkStringParam(false, fname, Statement.PS_SCHEME, schemes, conditional); + checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional); + Set<String> checkpointings = Arrays.stream(Statement.PSCheckpointing.values()).map(Enum::name).collect(Collectors.toSet()); + checkStringParam(true, fname, Statement.PS_CHECKPOINTING, checkpointings, conditional); + + // set output characteristics + 
output.setDataType(DataType.LIST); + output.setValueType(ValueType.UNKNOWN); + output.setDimensions(getVarParam(Statement.PS_MODEL).getOutput().getDim1(), 1); + output.setBlockDimensions(-1, -1); + } + + private void checkStringParam(boolean optional, String fname, String pname, Set<String> validOptions, boolean conditional) { + Expression param = getVarParam(pname); + if (param == null) { + if (optional) { + return; + } + raiseValidateError(String.format("Function %s should provide parameter '%s'", fname, pname), conditional); + } + if (!(param.getOutput().getDataType().isScalar() && param.getOutput().getValueType().equals(ValueType.STRING))) { + raiseValidateError( + String.format("Function %s should provide a string value for %s parameter.", fname, pname), + conditional); + } + StringIdentifier si = (StringIdentifier) param; + if (!validOptions.contains(si.getValue())) { + raiseValidateError(String.format("Function %s does not support value '%s' as the '%s' parameter.", fname, + si.getValue(), pname), conditional, LanguageErrorCodes.INVALID_PARAMETERS); + } + } + // example: A = transformapply(target=X, meta=M, spec=s) private void validateTransformApply(DataIdentifier output, boolean conditional) { @@ -316,7 +386,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier checkDataType("transformapply", TF_FN_PARAM_MTD2, DataType.FRAME, conditional); //validate specification - checkDataValueType("transformapply", TF_FN_PARAM_SPEC, DataType.SCALAR, ValueType.STRING, conditional); + checkDataValueType(false, "transformapply", TF_FN_PARAM_SPEC, DataType.SCALAR, ValueType.STRING, conditional); validateTransformSpec(TF_FN_PARAM_SPEC, conditional); //set output dimensions @@ -332,7 +402,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier checkDataType("transformdecode", TF_FN_PARAM_MTD2, DataType.FRAME, conditional); //validate specification - checkDataValueType("transformdecode", TF_FN_PARAM_SPEC, DataType.SCALAR, 
ValueType.STRING, conditional); + checkDataValueType(false, "transformdecode", TF_FN_PARAM_SPEC, DataType.SCALAR, ValueType.STRING, conditional); validateTransformSpec(TF_FN_PARAM_SPEC, conditional); //set output dimensions @@ -348,7 +418,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier checkDataType("transformcolmap", TF_FN_PARAM_DATA, DataType.FRAME, conditional); //validate specification - checkDataValueType("transformcolmap", TF_FN_PARAM_SPEC, DataType.SCALAR, ValueType.STRING, conditional); + checkDataValueType(false,"transformcolmap", TF_FN_PARAM_SPEC, DataType.SCALAR, ValueType.STRING, conditional); validateTransformSpec(TF_FN_PARAM_SPEC, conditional); //set output dimensions @@ -360,11 +430,11 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier private void validateTransformMeta(DataIdentifier output, boolean conditional) { //validate specification - checkDataValueType("transformmeta", TF_FN_PARAM_SPEC, DataType.SCALAR, ValueType.STRING, conditional); + checkDataValueType(false,"transformmeta", TF_FN_PARAM_SPEC, DataType.SCALAR, ValueType.STRING, conditional); validateTransformSpec(TF_FN_PARAM_SPEC, conditional); //validate meta data path - checkDataValueType("transformmeta", TF_FN_PARAM_MTD, DataType.SCALAR, ValueType.STRING, conditional); + checkDataValueType(false,"transformmeta", TF_FN_PARAM_MTD, DataType.SCALAR, ValueType.STRING, conditional); //set output dimensions output.setDataType(DataType.FRAME); @@ -378,7 +448,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier checkDataType("transformencode", TF_FN_PARAM_DATA, DataType.FRAME, conditional); //validate specification - checkDataValueType("transformencode", TF_FN_PARAM_SPEC, DataType.SCALAR, ValueType.STRING, conditional); + checkDataValueType(false, "transformencode", TF_FN_PARAM_SPEC, DataType.SCALAR, ValueType.STRING, conditional); validateTransformSpec(TF_FN_PARAM_SPEC, conditional); //set output dimensions @@ 
-408,11 +478,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier //check for invalid parameters Set<String> valid = UtilFunctions.asSet("target", "diag", "values"); - Set<String> invalid = _varParams.keySet().stream() - .filter(k -> !valid.contains(k)).collect(Collectors.toSet()); - if( !invalid.isEmpty() ) - raiseValidateError("Invalid parameters for " + op.name() + ": " - + Arrays.toString(invalid.toArray(new String[0])), false); + checkInvalidParameters(op, getVarParams(), valid); //check existence and correctness of arguments checkTargetParam(getVarParam("target"), conditional); @@ -663,7 +729,19 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier +"["+param.getOutput().getValueType()+"].", conditional, LanguageErrorCodes.INVALID_PARAMETERS); } } - + + private void checkInvalidParameters(ParameterizedBuiltinFunctionOp op, HashMap<String, Expression> params, + Set<String> valid) { + Set<String> invalid = params.keySet().stream().filter(k -> !valid.contains(k)).collect(Collectors.toSet()); + if (!invalid.isEmpty()) { + List<String> invalidMsg = invalid.stream().map(k -> { + String val = params.get(k).getText(); + return k == null ? val : k + "=" + val; + }).collect(Collectors.toList()); + raiseValidateError(String.format("Invalid parameters for %s: %s", op.name(), invalidMsg), false); + } + } + private void validateDistributionFunctions(DataIdentifier output, boolean conditional) { // CDF and INVCDF expects one unnamed parameter, it must be renamed as "quantile" // (i.e., we must compute P(X <= x) where x is called as "quantile" ) @@ -784,17 +862,23 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier if( data==null ) raiseValidateError("Named parameter '" + pname + "' missing. 
Please specify the input.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); else if( data.getOutput().getDataType() != dt ) - raiseValidateError("Input to "+fname+"::"+pname+" must be of type '"+dt.toString()+"'. It is of type '"+data.getOutput().getDataType()+"'.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); + raiseValidateError("Input to "+fname+"::"+pname+" must be of type '"+dt.toString()+"'. It should not be of type '"+data.getOutput().getDataType()+"'.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); } - private void checkDataValueType( String fname, String pname, DataType dt, ValueType vt, boolean conditional ) - { + private void checkDataValueType(boolean optional, String fname, String pname, DataType dt, ValueType vt, + boolean conditional) { Expression data = getVarParam(pname); - if( data==null ) - raiseValidateError("Named parameter '" + pname + "' missing. Please specify the input.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); - else if( data.getOutput().getDataType() != dt || data.getOutput().getValueType() != vt ) - raiseValidateError("Input to "+fname+"::"+pname+" must be of type '"+dt.toString()+"', '"+vt.toString()+"'. " - + "It is of type '"+data.getOutput().getDataType().toString()+"', '"+data.getOutput().getValueType().toString()+"'.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); + if (data == null) { + if (optional) { + return; + } + raiseValidateError(String.format("Named parameter '%s' is missing. 
Please specify the input.", pname), + conditional, LanguageErrorCodes.INVALID_PARAMETERS); + } else if (data.getOutput().getDataType() != dt || data.getOutput().getValueType() != vt) + raiseValidateError(String.format("Input to %s::%s must be of type '%s', '%s'. It should not be of type '%s', '%s'.", + fname, pname, dt.toString(), vt.toString(), data.getOutput().getDataType().toString(), + data.getOutput().getValueType().toString()), conditional, + LanguageErrorCodes.INVALID_PARAMETERS); } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/parser/Statement.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Statement.java b/src/main/java/org/apache/sysml/parser/Statement.java index 4a54df4..4853a47 100644 --- a/src/main/java/org/apache/sysml/parser/Statement.java +++ b/src/main/java/org/apache/sysml/parser/Statement.java @@ -61,7 +61,41 @@ public abstract class Statement implements ParseInfo public static final String GAGG_FN_CM = "centralmoment"; public static final String GAGG_FN_CM_ORDER = "order"; public static final String GAGG_NUM_GROUPS = "ngroups"; - + + // String constants related to parameter server builtin function + public static final String PS_MODEL = "model"; + public static final String PS_FEATURES = "features"; + public static final String PS_LABELS = "labels"; + public static final String PS_VAL_FEATURES = "val_features"; + public static final String PS_VAL_LABELS = "val_labels"; + public static final String PS_UPDATE_FUN = "upd"; + public static final String PS_AGGREGATION_FUN = "agg"; + public static final String PS_MODE = "mode"; + public enum PSModeType { + LOCAL, REMOTE_SPARK + } + public static final String PS_UPDATE_TYPE = "utype"; + public enum PSUpdateType { + BSP, ASP, SSP + } + public static final String PS_FREQUENCY = "freq"; + public enum PSFrequency { + BATCH, EPOCH + } + public static final String PS_EPOCHS =
"epochs"; + public static final String PS_BATCH_SIZE = "batchsize"; + public static final String PS_PARALLELISM = "k"; + public static final String PS_SCHEME = "scheme"; + public enum PSScheme { + DISJOINT_CONTIGUOUS, DISJOINT_ROUND_ROBIN, DISJOINT_RANDOM, OVERLAP_RESHUFFLE + } + public static final String PS_HYPER_PARAMS = "hyperparams"; + public static final String PS_CHECKPOINTING = "checkpointing"; + public enum PSCheckpointing { + NONE, EPOCH, EPOCH10 + } + + public abstract boolean controlStatement(); public abstract VariableSet variablesRead(); http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java index a7f5834..043e4ed 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java @@ -1313,9 +1313,9 @@ public class OptimizerRuleBased extends Optimizer Hop h = OptTreeConverter.getAbstractPlanMapping().getMappedHop(c.getID()); if( ConfigurationManager.isParallelMatrixOperations() && h instanceof MultiThreadedHop //abop, datagenop, qop, paramop - && !( h instanceof ParameterizedBuiltinOp //only paramop-grpagg + && !( h instanceof ParameterizedBuiltinOp //paramop-grpagg, rexpand, paramserv && !HopRewriteUtils.isValidOp(((ParameterizedBuiltinOp)h).getOp(), - ParamBuiltinOp.GROUPEDAGG, ParamBuiltinOp.REXPAND)) + ParamBuiltinOp.GROUPEDAGG, ParamBuiltinOp.REXPAND, ParamBuiltinOp.PARAMSERV)) && !( h instanceof UnaryOp //only unaryop-cumulativeagg && !((UnaryOp)h).isCumulativeUnaryOperation() && !((UnaryOp)h).isExpensiveUnaryOperation()) 
http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/runtime/functionobjects/ParameterizedBuiltin.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/ParameterizedBuiltin.java b/src/main/java/org/apache/sysml/runtime/functionobjects/ParameterizedBuiltin.java index 302f680..7f4345e 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/ParameterizedBuiltin.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/ParameterizedBuiltin.java @@ -45,7 +45,7 @@ public class ParameterizedBuiltin extends ValueFunction public enum ParameterizedBuiltinCode { CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI, - TRANSFORMAPPLY, TRANSFORMDECODE } + TRANSFORMAPPLY, TRANSFORMDECODE, PARAMSERV } public enum ProbabilityDistributionCode { INVALID, NORMAL, EXP, CHISQ, F, T } @@ -64,6 +64,7 @@ public class ParameterizedBuiltin extends ValueFunction String2ParameterizedBuiltinCode.put( "rexpand", ParameterizedBuiltinCode.REXPAND); String2ParameterizedBuiltinCode.put( "transformapply", ParameterizedBuiltinCode.TRANSFORMAPPLY); String2ParameterizedBuiltinCode.put( "transformdecode", ParameterizedBuiltinCode.TRANSFORMDECODE); + String2ParameterizedBuiltinCode.put( "paramserv", ParameterizedBuiltinCode.PARAMSERV); } static public HashMap<String, ProbabilityDistributionCode> String2DistCode; @@ -178,6 +179,9 @@ public class ParameterizedBuiltin extends ValueFunction case TRANSFORMDECODE: return new ParameterizedBuiltin(ParameterizedBuiltinCode.TRANSFORMDECODE); + + case PARAMSERV: + return new ParameterizedBuiltin(ParameterizedBuiltinCode.PARAMSERV); default: throw new DMLRuntimeException("Invalid parameterized builtin code: " + code); http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index 16db227..82d4418 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -186,6 +186,7 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "list", CPType.BuiltinNary); // Parameterized Builtin Functions + String2CPInstructionType.put("paramserv", CPType.ParameterizedBuiltin); String2CPInstructionType.put( "nvlist", CPType.ParameterizedBuiltin); String2CPInstructionType.put( "cdf", CPType.ParameterizedBuiltin); String2CPInstructionType.put( "invcdf", CPType.ParameterizedBuiltin); @@ -362,8 +363,8 @@ public class CPInstructionParser extends InstructionParser case External: return FunctionCallCPInstruction.parseInstruction(str); - - case ParameterizedBuiltin: + + case ParameterizedBuiltin: return ParameterizedBuiltinCPInstruction.parseInstruction(str); case MultiReturnParameterizedBuiltin: http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 8fac54c..b2506b8 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -127,26 +127,23 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction Operator op = 
GroupedAggregateInstruction.parseGroupedAggOperator(fnStr, paramsMap.get("order")); return new ParameterizedBuiltinCPInstruction(op, paramsMap, out, opcode, str); - } - else if( opcode.equalsIgnoreCase("rmempty") - || opcode.equalsIgnoreCase("replace") + } else if (opcode.equalsIgnoreCase("rmempty") + || opcode.equalsIgnoreCase("replace") || opcode.equalsIgnoreCase("rexpand") || opcode.equalsIgnoreCase("lowertri") - || opcode.equalsIgnoreCase("uppertri")) - { + || opcode.equalsIgnoreCase("uppertri")) { func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode); return new ParameterizedBuiltinCPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str); - } - else if ( opcode.equals("transformapply") - || opcode.equals("transformdecode") - || opcode.equals("transformcolmap") - || opcode.equals("transformmeta") - || opcode.equals("toString") - || opcode.equals("nvlist") ) - { + } else if (opcode.equals("transformapply") + || opcode.equals("transformdecode") + || opcode.equals("transformcolmap") + || opcode.equals("transformmeta") + || opcode.equals("toString") + || opcode.equals("nvlist")) { return new ParameterizedBuiltinCPInstruction(null, paramsMap, out, opcode, str); - } - else { + } else if ("paramserv".equals(opcode)) { + return new ParamservBuiltinCPInstruction(null, paramsMap, out, opcode, str); + } else { throw new DMLRuntimeException("Unknown opcode (" + opcode + ") for ParameterizedBuiltin Instruction."); } http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java new file mode 100644 index 0000000..ddc56ae --- /dev/null +++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.cp; + +import java.util.LinkedHashMap; + +import org.apache.sysml.parser.Statement; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.matrix.operators.Operator; + +public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction { + + protected ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out, + String opcode, String istr) { + super(op, paramsMap, out, opcode, istr); + } + + @Override + public void processInstruction(ExecutionContext ec) { + ListObject model = (ListObject) ec.getVariable(getParam(Statement.PS_MODEL)); + ListObject outList = model.slice(0, model.getLength() - 1); + ec.setVariable(output.getName(), outList); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java index 425dbd3..f9d7ef3 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java @@ -89,8 +89,8 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction // removeEmpty-specific attributes private boolean _bRmEmptyBC = false; - private ParameterizedBuiltinSPInstruction(Operator op, HashMap<String, String> paramsMap, CPOperand out, - String opcode, String istr, boolean bRmEmptyBC) { + ParameterizedBuiltinSPInstruction(Operator op, HashMap<String, String> paramsMap, CPOperand out, String opcode, + String istr, boolean bRmEmptyBC) { super(SPType.ParameterizedBuiltin, op, null, null, out, opcode, istr); params = paramsMap; _bRmEmptyBC = bRmEmptyBC; @@ -142,7 +142,6 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction // determine the appropriate value function ValueFunction func = null; - if ( opcode.equalsIgnoreCase("groupedagg")) { // check for mandatory arguments String fnStr = paramsMap.get("fn"); @@ -152,26 +151,20 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction if ( paramsMap.get("order") == null ) throw new DMLRuntimeException("Mandatory \"order\" must be specified when fn=\"centralmoment\" in groupedAggregate."); } - Operator op = GroupedAggregateInstruction.parseGroupedAggOperator(fnStr, paramsMap.get("order")); return new ParameterizedBuiltinSPInstruction(op, paramsMap, out, opcode, str, false); - } - else if( opcode.equalsIgnoreCase("rmempty") ) - { - boolean bRmEmptyBC = false; - if(parts.length > 6) - bRmEmptyBC = Boolean.parseBoolean(parts[5]); - + } + else if (opcode.equalsIgnoreCase("rmempty")) { func = 
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode); - return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, bRmEmptyBC); + return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, + parts.length > 6 ? Boolean.parseBoolean(parts[5]) : false); } - else if( opcode.equalsIgnoreCase("rexpand") - || opcode.equalsIgnoreCase("replace") - || opcode.equalsIgnoreCase("lowertri") - || opcode.equalsIgnoreCase("uppertri") - || opcode.equalsIgnoreCase("transformapply") - || opcode.equalsIgnoreCase("transformdecode")) - { + else if (opcode.equalsIgnoreCase("rexpand") + || opcode.equalsIgnoreCase("replace") + || opcode.equalsIgnoreCase("lowertri") + || opcode.equalsIgnoreCase("uppertri") + || opcode.equalsIgnoreCase("transformapply") + || opcode.equalsIgnoreCase("transformdecode")) { func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode); return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, false); } @@ -190,7 +183,7 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction //opcode guaranteed to be a valid opcode (see parsing) if( opcode.equalsIgnoreCase("mapgroupedagg") ) - { + { //get input rdd handle String targetVar = params.get(Statement.GAGG_TARGET); String groupsVar = params.get(Statement.GAGG_GROUPS); @@ -214,21 +207,21 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction //multi-block aggregation else { //execute map grouped aggregate - JavaPairRDD<MatrixIndexes, MatrixBlock> out = - target.flatMapToPair(new RDDMapGroupedAggFunction(groups, _optr, - ngroups, mc1.getRowsPerBlock(), mc1.getColsPerBlock())); + JavaPairRDD<MatrixIndexes, MatrixBlock> out = + target.flatMapToPair(new RDDMapGroupedAggFunction(groups, _optr, + ngroups, mc1.getRowsPerBlock(), mc1.getColsPerBlock())); out = RDDAggregateUtils.sumByKeyStable(out, false); //updated characteristics and 
handle outputs mcOut.set(ngroups, mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock(), -1); - sec.setRDDHandleForVariable(output.getName(), out); + sec.setRDDHandleForVariable(output.getName(), out); sec.addLineageRDD( output.getName(), targetVar ); - sec.addLineageBroadcast( output.getName(), groupsVar ); + sec.addLineageBroadcast( output.getName(), groupsVar ); } } - else if ( opcode.equalsIgnoreCase("groupedagg") ) - { + else if ( opcode.equalsIgnoreCase("groupedagg") ) + { boolean broadcastGroups = Boolean.parseBoolean(params.get("broadcast")); //get input rdd handle @@ -256,7 +249,7 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction } groupWeightedCells = groups.join(target).join(weights) - .flatMapToPair(new ExtractGroupNWeights()); + .flatMapToPair(new ExtractGroupNWeights()); } else //input vector or matrix { @@ -267,7 +260,7 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction if( broadcastGroups ) { PartitionedBroadcast<MatrixBlock> pbm = sec.getBroadcastForVariable(groupsVar); groupWeightedCells = target - .flatMapToPair(new ExtractGroupBroadcast(pbm, mc1.getColsPerBlock(), ngroups, _optr)); + .flatMapToPair(new ExtractGroupBroadcast(pbm, mc1.getColsPerBlock(), ngroups, _optr)); } else { //general case @@ -278,7 +271,7 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction } groupWeightedCells = groups.join(target) - .flatMapToPair(new ExtractGroupJoin(mc1.getColsPerBlock(), ngroups, _optr)); + .flatMapToPair(new ExtractGroupJoin(mc1.getColsPerBlock(), ngroups, _optr)); } } @@ -293,20 +286,20 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction if(_optr instanceof CMOperator && ((CMOperator) _optr).isPartialAggregateOperator() || _optr instanceof AggregateOperator ) { out = groupWeightedCells.reduceByKey(new PerformGroupByAggInCombiner(_optr)) - .mapValues(new CreateMatrixCell(brlen, _optr)); + .mapValues(new 
CreateMatrixCell(brlen, _optr)); } else { // Use groupby key because partial aggregation is not supported out = groupWeightedCells.groupByKey() - .mapValues(new PerformGroupByAggInReducer(_optr)) - .mapValues(new CreateMatrixCell(brlen, _optr)); + .mapValues(new PerformGroupByAggInReducer(_optr)) + .mapValues(new CreateMatrixCell(brlen, _optr)); } // Step 4: Set output characteristics and rdd handle setOutputCharacteristicsForGroupedAgg(mc1, mcOut, out); //store output rdd handle - sec.setRDDHandleForVariable(output.getName(), out); + sec.setRDDHandleForVariable(output.getName(), out); sec.addLineageRDD( output.getName(), params.get(Statement.GAGG_TARGET) ); sec.addLineage( output.getName(), groupsVar, broadcastGroups ); if ( params.get(Statement.GAGG_WEIGHTS) != null ) { @@ -425,7 +418,7 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction //repartition input vector for higher degree of parallelism //(avoid scenarios where few input partitions create huge outputs) - MatrixCharacteristics mcTmp = new MatrixCharacteristics(dirRows?lmaxVal:mcIn.getRows(), + MatrixCharacteristics mcTmp = new MatrixCharacteristics(dirRows?lmaxVal:mcIn.getRows(), dirRows?mcIn.getRows():lmaxVal, (int)brlen, (int)bclen, mcIn.getRows()); int numParts = (int)Math.min(SparkUtils.getNumPreferredPartitions(mcTmp, in), mcIn.getNumBlocks()); if( numParts > in.getNumPartitions()*2 ) @@ -434,7 +427,7 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction //execute rexpand rows/cols operation (no shuffle required because outputs are //block-aligned with the input, i.e., one input block generates n output blocks) JavaPairRDD<MatrixIndexes,MatrixBlock> out = in - .flatMapToPair(new RDDRExpandFunction(maxVal, dirRows, cast, ignore, brlen, bclen)); + .flatMapToPair(new RDDRExpandFunction(maxVal, dirRows, cast, ignore, brlen, bclen)); //store output rdd handle sec.setRDDHandleForVariable(output.getName(), out); @@ -454,7 +447,7 @@ public class 
ParameterizedBuiltinSPInstruction extends ComputationSPInstruction MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(params.get("target")); MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName()); String[] colnames = !TfMetaUtils.isIDSpec(params.get("spec")) ? - in.lookup(1L).get(0).getColumnNames() : null; + in.lookup(1L).get(0).getColumnNames() : null; //compute omit offset map for block shifts TfOffsetMap omap = null; @@ -462,19 +455,19 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction omap = new TfOffsetMap(SparkUtils.toIndexedLong(in.mapToPair( new RDDTransformApplyOffsetFunction(params.get("spec"), colnames)).collect())); } - + //create encoder broadcast (avoiding replication per task) Encoder encoder = EncoderFactory.createEncoder(params.get("spec"), colnames, - fo.getSchema(), (int)fo.getNumColumns(), meta); + fo.getSchema(), (int)fo.getNumColumns(), meta); mcOut.setDimension(mcIn.getRows()-((omap!=null)?omap.getNumRmRows():0), encoder.getNumCols()); Broadcast<Encoder> bmeta = sec.getSparkContext().broadcast(encoder); Broadcast<TfOffsetMap> bomap = (omap!=null) ? 
sec.getSparkContext().broadcast(omap) : null; //execute transform apply JavaPairRDD<Long,FrameBlock> tmp = in - .mapToPair(new RDDTransformApplyFunction(bmeta, bomap)); + .mapToPair(new RDDTransformApplyFunction(bmeta, bomap)); JavaPairRDD<MatrixIndexes,MatrixBlock> out = FrameRDDConverterUtils - .binaryBlockToMatrixBlock(tmp, mcOut, mcOut); + .binaryBlockToMatrixBlock(tmp, mcOut, mcOut); //set output and maintain lineage/output characteristics sec.setRDDHandleForVariable(output.getName(), out); @@ -749,7 +742,7 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction { //get all inputs MatrixIndexes ix = arg0._1(); - MatrixBlock target = arg0._2(); + MatrixBlock target = arg0._2(); MatrixBlock groups = _pbm.getBlock((int)ix.getRowIndex(), 1); //execute map grouped aggregate operations http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java index 48f6428..47ea66e 100644 --- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java @@ -1105,7 +1105,7 @@ public abstract class AutomatedTestBase * -1 there is no limit. */ protected void runTest(boolean exceptionExpected, Class<?> expectedException, int maxMRJobs) { - runTest(false, exceptionExpected, expectedException, maxMRJobs); + runTest(false, exceptionExpected, expectedException, null, maxMRJobs); } /** @@ -1125,6 +1125,29 @@ public abstract class AutomatedTestBase * -1 there is no limit. 
*/ protected void runTest(boolean newWay, boolean exceptionExpected, Class<?> expectedException, int maxMRJobs) { + runTest(newWay, exceptionExpected, expectedException, null, maxMRJobs); + } + + /** + * <p> + * Runs a test for which the exception expectation, the expected exception + * class, and a required substring of its error message can be specified. + * If SystemML executes more MR jobs than specified in maxMRJobs this test + * will fail. + * </p> + * @param newWay + * if true, runs the test in the new way + * @param exceptionExpected + * whether the test is expected to raise an exception + * @param expectedException + * the exception class that is expected to be raised + * @param errMessage + * substring the exception message must contain (null skips this check) + * @param maxMRJobs + * specifies a maximum limit for the number of MR jobs. If set to + * -1 there is no limit. + */ + protected void runTest(boolean newWay, boolean exceptionExpected, Class<?> expectedException, String errMessage, int maxMRJobs) { String executionFile = sourceDirectory + selectedTest + ".dml"; @@ -1227,6 +1250,10 @@ public abstract class AutomatedTestBase if (exceptionExpected) fail("expected exception which has not been raised: " + expectedException); } catch (Exception e) { + if (exceptionExpected && e.getClass().equals(expectedException) && errMessage != null + && !e.getMessage().contains(errMessage)) { + fail("expected exception message has not been raised: " + errMessage); + } if (!exceptionExpected || (expectedException != null && !(e.getClass().equals(expectedException)))) { e.printStackTrace(); StringBuilder errorMessage = new StringBuilder(); http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java new file mode 100644
index 0000000..1b227f1 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.paramserv; + +import org.apache.sysml.api.DMLException; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.junit.Test; + +public class ParamservFuncTest extends AutomatedTestBase { + + private static final String TEST_NAME1 = "paramserv-all-args"; + private static final String TEST_NAME2 = "paramserv-without-optional-args"; + private static final String TEST_NAME3 = "paramserv-miss-args"; + private static final String TEST_NAME4 = "paramserv-wrong-type-args"; + private static final String TEST_NAME5 = "paramserv-wrong-args"; + private static final String TEST_NAME6 = "paramserv-wrong-args2"; + private static final String TEST_NAME7 = "paramserv-ipa-test"; + + private static final String TEST_DIR = "functions/paramserv/"; + private static final String TEST_CLASS_DIR = TEST_DIR + ParamservFuncTest.class.getSimpleName() + "/"; + + private final String HOME = SCRIPT_DIR + 
TEST_DIR; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {})); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {})); + addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {})); + addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {})); + addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {})); + addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {})); + } + + @Test + public void testParamservWithAllArgs() { + runDMLTest(TEST_NAME1, true, false, null, null); + } + + @Test + public void testParamservWithoutOptionalArgs() { + runDMLTest(TEST_NAME2, true, false, null, null); + } + + @Test + public void testParamservMissArgs() { + final String errmsg = "Named parameter 'features' missing. Please specify the input."; + runDMLTest(TEST_NAME3, true, true, DMLException.class, errmsg); + } + + @Test + public void testParamservWrongTypeArgs() { + final String errmsg = "Input to PARAMSERV::model must be of type 'LIST'. 
It should not be of type 'MATRIX'"; + runDMLTest(TEST_NAME4, true, true, DMLException.class, errmsg); + } + + @Test + public void testParamservWrongArgs() { + final String errmsg = "Function PARAMSERV does not support value 'NSP' as the 'utype' parameter."; + runDMLTest(TEST_NAME5, true, true, DMLException.class, errmsg); + } + + @Test + public void testParamservWrongArgs2() { + final String errmsg = "Invalid parameters for PARAMSERV: [modelList, val_featur=X_val]"; + runDMLTest(TEST_NAME6, true, true, DMLException.class, errmsg); + } + + @Test + public void testParamservIpaTest() { + runDMLTest(TEST_NAME7, true, false, null, "1"); + } + + private void runDMLTest(String testname, boolean newWay, boolean exceptionExpected, Class<?> exceptionClass, + String errmsg) { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + programArgs = new String[] { "-explain" }; + fullDMLScriptName = HOME + testname + ".dml"; + runTest(newWay, exceptionExpected, exceptionClass, errmsg, -1); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/test/scripts/functions/paramserv/paramserv-all-args.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-all-args.dml b/src/test/scripts/functions/paramserv/paramserv-all-args.dml new file mode 100644 index 0000000..bcb3ac3 --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-all-args.dml @@ -0,0 +1,43 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +e1 = "element1" +paramsList = list(e1) +X = matrix(1, rows=2, cols=3) +Y = matrix(2, rows=2, cols=3) +X_val = matrix(3, rows=2, cols=3) +Y_val = matrix(4, rows=2, cols=3) + +gradients = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +aggregation = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +e2 = "element2" +hps = list(e2) + +# Use paramserv function +paramsList2 = paramserv(model=paramsList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=hps, checkpointing="NONE") + +print(length(paramsList2)) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml b/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml new file mode 100644 index 0000000..5aed767 --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml @@ -0,0 +1,47 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +e1 = "element1" +paramsList = list(e1) +X = matrix(1, rows=2, cols=3) +Y = matrix(2, rows=2, cols=3) +X_val = matrix(3, rows=2, cols=3) +Y_val = matrix(4, rows=2, cols=3) + +gradients = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +aggregation = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +e2 = "element2" +hps = list(e2) + +# Use paramserv function +paramsList2 = list(1, 2, 3) + +if (length(paramsList2) == 3) { + paramsList2 = paramserv(model=paramsList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=hps, checkpointing="NONE") +} + +print(length(paramsList2)) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/test/scripts/functions/paramserv/paramserv-miss-args.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-miss-args.dml b/src/test/scripts/functions/paramserv/paramserv-miss-args.dml new file mode 100644 index 0000000..f3a2c91 --- /dev/null +++ 
b/src/test/scripts/functions/paramserv/paramserv-miss-args.dml @@ -0,0 +1,42 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +e1 = "element1" +modelList = list(e1) +X = matrix(1, rows=2, cols=3) +Y = matrix(2, rows=2, cols=3) +X_val = matrix(3, rows=2, cols=3) +Y_val = matrix(4, rows=2, cols=3) + +gradients = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +aggregation = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +e2 = "element2" +params = list(e2) + +# Use paramserv function +# Omit the required "features" named argument +modelList2 = paramserv(model=modelList, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE") \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml ---------------------------------------------------------------------- diff
--git a/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml b/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml new file mode 100644 index 0000000..c504303 --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml @@ -0,0 +1,48 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +e1 = "element1" +modelList = list(e1) +X = matrix(1, rows=2, cols=3) +Y = matrix(2, rows=2, cols=3) +X_val = matrix(3, rows=2, cols=3) +Y_val = matrix(4, rows=2, cols=3) + +gradients = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +aggregation = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +e2 = "element2" +params = list(e2) + +# Use paramserv function +# Remove the optional "hyperparams" +modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="REMOTE_SPARK", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_ROUND_ROBIN", checkpointing="EPOCH") + +# Remove the optional "batchsize" +modelList3 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="ASP", freq="BATCH", epochs=100, k=7, scheme="DISJOINT_RANDOM", hyperparams=params, checkpointing="NONE") + +# Remove the optional "checkpointing" +modelList4 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="SSP", freq="EPOCH", batchsize=64, epochs=100, k=7, scheme="OVERLAP_RESHUFFLE", hyperparams=params) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml b/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml new file mode 100644 index 0000000..13a05c9 --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml @@ -0,0 +1,41 @@ +#------------------------------------------------------------- +# +# Licensed to 
the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +e1 = "element1" +modelList = list(e1) +X = matrix(1, rows=2, cols=3) +Y = matrix(2, rows=2, cols=3) +X_val = matrix(3, rows=2, cols=3) +Y_val = matrix(4, rows=2, cols=3) + +gradients = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +aggregation = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +e2 = "element2" +params = list(e2) + +# Use paramserv function +modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="NSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE") \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/test/scripts/functions/paramserv/paramserv-wrong-args2.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-wrong-args2.dml b/src/test/scripts/functions/paramserv/paramserv-wrong-args2.dml new file mode 100644 index 0000000..4002d45 --- /dev/null +++ 
b/src/test/scripts/functions/paramserv/paramserv-wrong-args2.dml @@ -0,0 +1,42 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +e1 = "element1" +modelList = list(e1) +X = matrix(1, rows=2, cols=3) +Y = matrix(2, rows=2, cols=3) +X_val = matrix(3, rows=2, cols=3) +Y_val = matrix(4, rows=2, cols=3) + +gradients = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +aggregation = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +e2 = "element2" +params = list(e2) + +# Use paramserv function +# Missing the named "model" argument (passed positionally) and one misspelled argument name ("val_featur") +modelList = paramserv(modelList, features=X, labels=Y, val_featur=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE") \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/test/scripts/functions/paramserv/paramserv-wrong-named-args.dml
---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-wrong-named-args.dml b/src/test/scripts/functions/paramserv/paramserv-wrong-named-args.dml new file mode 100644 index 0000000..b00dcff --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-wrong-named-args.dml @@ -0,0 +1,41 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +e1 = "element1" +model = list(e1) +X = matrix(1, rows=2, cols=3) +Y = matrix(2, rows=2, cols=3) +X_val = matrix(3, rows=2, cols=3) +Y_val = matrix(4, rows=2, cols=3) + +gradients = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +aggregation = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +e2 = "element2" +params = list(e2) + +# Use paramserv function +model2 = paramserv(model, labels=Y, val_features=X_val, val_label=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="disjoint_contiguous", hyperparams=params, checkpointing="NONE") \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/test/scripts/functions/paramserv/paramserv-wrong-type-args.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-wrong-type-args.dml b/src/test/scripts/functions/paramserv/paramserv-wrong-type-args.dml new file mode 100644 index 0000000..4b09a49 --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-wrong-type-args.dml @@ -0,0 +1,40 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +modelList = matrix(3, rows=1, cols=2) +X = matrix(1, rows=2, cols=3) +Y = matrix(2, rows=2, cols=3) +X_val = matrix(3, rows=2, cols=3) +Y_val = matrix(4, rows=2, cols=3) + +gradients = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +aggregation = function (matrix[double] input) return (matrix[double] output) { + output = input +} + +e2 = "element2" +params = list(e2) + +# Use paramserv function +modelList2 = paramserv(model=modelList, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE") \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/e270960c/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java new file mode 100644 index 0000000..ad3d526 --- /dev/null +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.paramserv; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** Group together the tests in this package into a single suite so that the Maven build + * won't run two of them at once. */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + ParamservFuncTest.class +}) + + +/** This class is just a holder for the above JUnit annotations. */ +public class ZPackageSuite { + +}
