Repository: incubator-systemml Updated Branches: refs/heads/master 9ac1d4f86 -> 4fff6f769
[SYSTEMML-853] Python API for new MLContext This adds a new Python API that targets the new MLContext API on the Java/Scala side. Closes #211. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/4fff6f76 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/4fff6f76 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/4fff6f76 Branch: refs/heads/master Commit: 4fff6f76951c42dcf902d584a26089eda27c43a8 Parents: 9ac1d4f Author: MechCoder <[email protected]> Authored: Fri Aug 19 15:46:07 2016 -0700 Committer: Mike Dusenberry <[email protected]> Committed: Fri Aug 19 15:46:07 2016 -0700 ---------------------------------------------------------------------- .../apache/sysml/api/mlcontext/MLResults.java | 108 ++++++---- .../org/apache/sysml/api/mlcontext/Script.java | 98 +++++---- src/main/python/SystemML.py | 203 +++++++++++++++++++ src/main/python/SystemMLtests.py | 87 ++++++++ 4 files changed, 412 insertions(+), 84 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4fff6f76/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java index 582a73e..289f490 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -92,7 +92,7 @@ public class MLResults { /** * Obtain an output as a {@code Data} object. - * + * * @param outputName * the name of the output * @return the output as a {@code Data} object @@ -108,7 +108,7 @@ public class MLResults { /** * Obtain an output as a {@code MatrixObject} - * + * * @param outputName * the name of the output * @return the output as a {@code MatrixObject} @@ -124,7 +124,7 @@ public class MLResults { /** * Obtain an output as a two-dimensional {@code double} array. - * + * * @param outputName * the name of the output * @return the output as a two-dimensional {@code double} array @@ -150,7 +150,7 @@ public class MLResults { * <br>2 1 3.0 * <br>2 2 4.0 * </code> - * + * * @param outputName * the name of the output * @return the output as a {@code JavaRDD<String>} in IJV format @@ -174,7 +174,7 @@ public class MLResults { * <code>1.0,2.0 * <br>3.0,4.0 * </code> - * + * * @param outputName * the name of the output * @return the output as a {@code JavaRDD<String>} in CSV format @@ -198,7 +198,7 @@ public class MLResults { * <code>1.0,2.0 * <br>3.0,4.0 * </code> - * + * * @param outputName * the name of the output * @return the output as a {@code RDD<String>} in CSV format @@ -224,7 +224,7 @@ public class MLResults { * <br>2 1 3.0 * <br>2 2 4.0 * </code> - * + * * @param outputName * the name of the output * @return the output as a {@code RDD<String>} in IJV format @@ -248,7 +248,7 @@ public class MLResults { * <code>[0.0,1.0,2.0] * <br>[1.0,3.0,4.0] * </code> - * + * * @param outputName * the name of the output * @return the output as a {@code DataFrame} of doubles @@ -258,7 +258,7 @@ public class MLResults { DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false); return df; } - + public DataFrame getDataFrame(String outputName, boolean isVectorDF) { MatrixObject mo = getMatrixObject(outputName); DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, isVectorDF); @@ -267,7 +267,7 @@ public class MLResults { /** * Obtain an output as a {@code Matrix}. - * + * * @param outputName * the name of the output * @return the output as a {@code Matrix} @@ -277,11 +277,11 @@ public class MLResults { Matrix matrix = new Matrix(mo, sparkExecutionContext); return matrix; } - + /** * Obtain an output as a {@code BinaryBlockMatrix}. - * + * * @param outputName * the name of the output * @return the output as a {@code BinaryBlockMatrix} @@ -295,7 +295,7 @@ public class MLResults { /** * Obtain an output as a two-dimensional {@code String} array. - * + * * @param outputName * the name of the output * @return the output as a two-dimensional {@code String} array @@ -320,7 +320,7 @@ public class MLResults { /** * Obtain a {@code double} output - * + * * @param outputName * the name of the output * @return the output as a {@code double} @@ -331,8 +331,26 @@ public class MLResults { } /** + * Obtain a serializable object as output + * + * @param outputName + * the name of the output + * @return the output as a serializable object. + */ + + public Object get(String outputName) { + Data data = getData(outputName); + if (data instanceof ScalarObject) { + ScalarObject so = (ScalarObject) data; + return so.getValue(); + } else { + return data; + } + } + + /** * Obtain an output as a {@code Scalar} object. - * + * * @param outputName * the name of the output * @return the output as a {@code Scalar} object @@ -348,7 +366,7 @@ public class MLResults { /** * Obtain a {@code boolean} output - * + * * @param outputName * the name of the output * @return the output as a {@code boolean} @@ -360,7 +378,7 @@ public class MLResults { /** * Obtain a {@code long} output - * + * * @param outputName * the name of the output * @return the output as a {@code long} @@ -372,7 +390,7 @@ public class MLResults { /** * Obtain a {@code String} output - * + * * @param outputName * the name of the output * @return the output as a {@code String} @@ -384,7 +402,7 @@ public class MLResults { /** * Obtain the Script object associated with these results. - * + * * @return the DML or PYDML Script object */ public Script getScript() { @@ -393,7 +411,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @return a Scala tuple @@ -405,7 +423,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -419,7 +437,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -436,7 +454,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -456,7 +474,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -478,7 +496,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -503,7 +521,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -531,7 +549,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -561,7 +579,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -594,7 +612,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -629,7 +647,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -668,7 +686,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -709,7 +727,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -752,7 +770,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -798,7 +816,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -846,7 +864,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -898,7 +916,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -952,7 +970,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -1008,7 +1026,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -1067,7 +1085,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -1128,7 +1146,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -1192,7 +1210,7 @@ public class MLResults { /** * Obtain a Scala tuple. - * + * * @param outputName1 * the name of the first output * @param outputName2 @@ -1262,7 +1280,7 @@ public class MLResults { * specific output type. MLResults tuple support requires specifying the * object types at runtime to avoid the items in the tuple being returned as * Anys. - * + * * @param outputName * the name of the output * @return the output value cast to a specific output type @@ -1289,7 +1307,7 @@ public class MLResults { /** * Obtain the symbol table, which is essentially a {@code Map<String, Data>} * representing variables and their values as SystemML representations. - * + * * @return the symbol table */ public LocalVariableMap getSymbolTable() { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4fff6f76/src/main/java/org/apache/sysml/api/mlcontext/Script.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Script.java b/src/main/java/org/apache/sysml/api/mlcontext/Script.java index 782f2c6..28667cf 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/Script.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/Script.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -97,7 +97,7 @@ public class Script { /** * Script constructor, specifying the type of script ({@code ScriptType.DML} * or {@code ScriptType.PYDML}). - * + * * @param scriptType * {@code ScriptType.DML} or {@code ScriptType.PYDML} */ @@ -108,7 +108,7 @@ public class Script { /** * Script constructor, specifying the script content. By default, the script * type is DML. - * + * * @param scriptString * the script content as a string */ @@ -120,7 +120,7 @@ public class Script { /** * Script constructor, specifying the script content and the type of script * (DML or PYDML). - * + * * @param scriptString * the script content as a string * @param scriptType @@ -133,7 +133,7 @@ public class Script { /** * Obtain the script type. - * + * * @return {@code ScriptType.DML} or {@code ScriptType.PYDML} */ public ScriptType getScriptType() { @@ -142,7 +142,7 @@ public class Script { /** * Set the type of script (DML or PYDML). - * + * * @param scriptType * {@code ScriptType.DML} or {@code ScriptType.PYDML} */ @@ -152,7 +152,7 @@ public class Script { /** * Obtain the script string. - * + * * @return the script string */ public String getScriptString() { @@ -161,7 +161,7 @@ public class Script { /** * Set the script string. - * + * * @param scriptString * the script string * @return {@code this} Script object to allow chaining of methods @@ -173,7 +173,7 @@ public class Script { /** * Obtain the input variable names as an unmodifiable set of strings. - * + * * @return the input variable names */ public Set<String> getInputVariables() { @@ -182,7 +182,7 @@ public class Script { /** * Obtain the output variable names as an unmodifiable set of strings. - * + * * @return the output variable names */ public Set<String> getOutputVariables() { @@ -192,7 +192,7 @@ public class Script { /** * Obtain the symbol table, which is essentially a * {@code HashMap<String, Data>} representing variables and their values. - * + * * @return the symbol table */ public LocalVariableMap getSymbolTable() { @@ -201,7 +201,7 @@ public class Script { /** * Obtain an unmodifiable map of all inputs (parameters ($) and variables). - * + * * @return all inputs to the script */ public Map<String, Object> getInputs() { @@ -210,7 +210,7 @@ public class Script { /** * Obtain an unmodifiable map of input matrix metadata. - * + * * @return input matrix metadata */ public Map<String, MatrixMetadata> getInputMatrixMetadata() { @@ -219,7 +219,7 @@ public class Script { /** * Pass a map of inputs to the script. - * + * * @param inputs * map of inputs (parameters ($) and variables). * @return {@code this} Script object to allow chaining of methods @@ -232,9 +232,13 @@ public class Script { return this; } + public Script input(Map<String, Object> inputs) { + return in(inputs); + } + /** * Pass a Scala Map of inputs to the script. - * + * * @param inputs * Scala Map of inputs (parameters ($) and variables). * @return {@code this} Script object to allow chaining of methods @@ -246,12 +250,16 @@ public class Script { return this; } + public Script input(scala.collection.Map<String, Object> inputs) { + return in(inputs); + } + /** * Pass a Scala Seq of inputs to the script. The inputs are either two-value * or three-value tuples, where the first value is the variable name, the * second value is the variable value, and the third optional value is the * metadata. - * + * * @param inputs * Scala Seq of inputs (parameters ($) and variables). * @return {@code this} Script object to allow chaining of methods @@ -274,9 +282,13 @@ public class Script { return this; } + public Script input(scala.collection.Seq<Object> inputs) { + return in(inputs); + } + /** * Obtain an unmodifiable map of all input parameters ($). - * + * * @return input parameters ($) */ public Map<String, Object> getInputParameters() { @@ -285,7 +297,7 @@ public class Script { /** * Register an input (parameter ($) or variable). - * + * * @param name * name of the input * @param value @@ -296,10 +308,14 @@ public class Script { return in(name, value, null); } + public Script input(String name, Object value) { + return in(name, value); + } + /** * Register an input (parameter ($) or variable) with optional matrix * metadata. - * + * * @param name * name of the input * @param value @@ -336,9 +352,13 @@ public class Script { return this; } + public Script input(String name, Object value, MatrixMetadata matrixMetadata) { + return in(name, value, matrixMetadata); + } + /** * Register an output variable. - * + * * @param outputName * name of the output variable * @return {@code this} Script object to allow chaining of methods @@ -350,7 +370,7 @@ public class Script { /** * Register output variables. - * + * * @param outputNames * names of the output variables * @return {@code this} Script object to allow chaining of methods @@ -411,7 +431,7 @@ public class Script { /** * Obtain the results of the script execution. - * + * * @return the results of the script execution. */ public MLResults results() { @@ -420,7 +440,7 @@ public class Script { /** * Obtain the results of the script execution. - * + * * @return the results of the script execution. */ public MLResults getResults() { @@ -429,7 +449,7 @@ public class Script { /** * Set the results of the script execution. - * + * * @param results * the results of the script execution. */ @@ -439,7 +459,7 @@ public class Script { /** * Obtain the script executor used by this Script. - * + * * @return the ScriptExecutor used by this Script. */ public ScriptExecutor getScriptExecutor() { @@ -448,7 +468,7 @@ public class Script { /** * Set the ScriptExecutor used by this Script. - * + * * @param scriptExecutor * the script executor */ @@ -458,7 +478,7 @@ public class Script { /** * Is the script type DML? - * + * * @return {@code true} if the script type is DML, {@code false} otherwise */ public boolean isDML() { @@ -467,7 +487,7 @@ public class Script { /** * Is the script type PYDML? - * + * * @return {@code true} if the script type is PYDML, {@code false} otherwise */ public boolean isPYDML() { @@ -477,7 +497,7 @@ public class Script { /** * Generate the script execution string, which adds read/load/write/save * statements to the beginning and end of the script to execute. - * + * * @return the script execution string */ public String getScriptExecutionString() { @@ -545,7 +565,7 @@ public class Script { * script type, inputs, outputs, input parameters, input variables, output * variables, the symbol table, the script string, and the script execution * string. - * + * * @return information about this script as a String */ public String info() { @@ -576,7 +596,7 @@ public class Script { /** * Display the script inputs. - * + * * @return the script inputs */ public String displayInputs() { @@ -585,7 +605,7 @@ public class Script { /** * Display the script outputs. - * + * * @return the script outputs as a String */ public String displayOutputs() { @@ -594,7 +614,7 @@ public class Script { /** * Display the script input parameters. - * + * * @return the script input parameters as a String */ public String displayInputParameters() { @@ -603,7 +623,7 @@ public class Script { /** * Display the script input variables. - * + * * @return the script input variables as a String */ public String displayInputVariables() { @@ -612,7 +632,7 @@ public class Script { /** * Display the script output variables. - * + * * @return the script output variables as a String */ public String displayOutputVariables() { @@ -621,7 +641,7 @@ public class Script { /** * Display the script symbol table. - * + * * @return the script symbol table as a String */ public String displaySymbolTable() { @@ -630,7 +650,7 @@ public class Script { /** * Obtain the script name. - * + * * @return the script name */ public String getName() { @@ -639,7 +659,7 @@ public class Script { /** * Set the script name. - * + * * @param name * the script name * @return {@code this} Script object to allow chaining of methods http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4fff6f76/src/main/python/SystemML.py ---------------------------------------------------------------------- diff --git a/src/main/python/SystemML.py b/src/main/python/SystemML.py new file mode 100644 index 0000000..85731ed --- /dev/null +++ b/src/main/python/SystemML.py @@ -0,0 +1,203 @@ +#!/usr/bin/python +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +import os + +from py4j.java_gateway import JavaObject +from py4j.java_collections import ListConverter, JavaArray, JavaList +from pyspark import SparkContext, RDD +from pyspark.mllib.common import _java2py, _py2java +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.sql import DataFrame + + +class MLResults(object): + """ + Wrapper around the Java ML Results object. + + Parameters + ---------- + results: JavaObject + A Java MLResults object as returned by calling ml.execute() + + sc: SparkContext + SparkContext + """ + def __init__(self, results, sc): + self._java_results = results + self.sc = sc + + def __repr__(self): + return "MLResults" + + def get(self, *outputs): + """ + Parameters + ---------- + outputs: string, list of strings + Output variables as defined inside the DML script. + """ + outs = [_java2py(self.sc, self._java_results.get(out)) for out in outputs] + if len(outs) == 1: + return outs[0] + return outs + + +class Script(object): + """ + Instance of a DML/PyDML Script. + + Parameters + ---------- + path: string + Can be either a file path to a DML script or a DML script itself. + """ + def __init__(self, scriptString, scriptType="dml"): + self.scriptString = scriptString + self.scriptType = scriptType + self._input = {} + self._output = [] + + def input(self, *args, **kwargs): + """ + Parameters + ---------- + args: name, value tuple + where name is a string and currently supported value formats + are double, string, rdds and list of such object. + + kwargs: dict of name, value pairs + To know what formats are supported for name and value, look above. + """ + if args and len(args) != 2: + raise ValueError("Expected name, value pair.") + elif args: + self._input[args[0]] = args[1] + for name, value in kwargs.items(): + self._input[name] = value + return self + + def out(self, *names): + """ + Parameters + ---------- + outputs: string, list of strings + Output variables as defined inside the DML script. + """ + self._output.extend(names) + return self + + +def pydml(scriptString): + """ + Create a pydml script object based on a string. + + Parameters + ---------- + scriptString: string + Can be a path to a pydml script or a pydml script itself. + + Returns + ------- + script: Script instance + Instance of a script object. + """ + if not isinstance(scriptString, str): + raise ValueError("scriptString should be a string, got %s" % type(scriptString)) + return Script(scriptString, scriptType="pydml") + + +def dml(scriptString): + """ + Create a dml script object based on a string. + + Parameters + ---------- + scriptString: string + Can be a path to a dml script or a dml script itself. + + Returns + ------- + script: Script instance + Instance of a script object. + """ + if not isinstance(scriptString, str): + raise ValueError("scriptString should be a string, got %s" % type(scriptString)) + return Script(scriptString, scriptType="dml") + + +class MLContext(object): + """ + Wrapper around the new SystemML MLContext. + + Parameters + ---------- + sc: SparkContext + SparkContext + """ + def __init__(self, sc): + if not isinstance(sc, SparkContext): + raise ValueError("Expected sc to be a SparkContext, got " % sc) + self._sc = sc + self._ml = sc._jvm.org.apache.sysml.api.mlcontext.MLContext(sc._jsc) + + def __repr__(self): + return "MLContext" + + def execute(self, script): + """ + Execute a DML / PyDML script. + + Parameters + ---------- + script: Script instance + Script instance defined with the appropriate input and output variables. + + Returns + ------- + ml_results: MLResults + MLResults instance. + """ + if not isinstance(script, Script): + raise ValueError("Expected script to be an instance of Script") + scriptString = script.scriptString + if script.scriptType == "dml": + if scriptString.endswith(".dml"): + if os.path.exists(scriptString): + script_java = self._sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile(scriptString) + else: + raise ValueError("path: %s does not exist" % scriptString) + else: + script_java = self._sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dml(scriptString) + elif script.scriptType == "pydml": + if scriptString.endswith(".pydml"): + if os.path.exists(scriptString): + script_java = self._sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile(scriptString) + else: + raise ValueError("path: %s does not exist" % scriptString) + else: + script_java = self._sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydml(scriptString) + + for key, val in script._input.items(): + script_java.input(key, _py2java(self._sc, val)) + for val in script._output: + script_java.out(val) + return MLResults(self._ml.execute(script_java), self._sc) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4fff6f76/src/main/python/SystemMLtests.py ---------------------------------------------------------------------- diff --git a/src/main/python/SystemMLtests.py b/src/main/python/SystemMLtests.py new file mode 100644 index 0000000..5dcae4a --- /dev/null +++ b/src/main/python/SystemMLtests.py @@ -0,0 +1,87 @@ +#!/usr/bin/python +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +import unittest + +from pyspark.sql import SQLContext +from pyspark.context import SparkContext + +from SystemML import dml +from SystemML import pydml +from SystemML import MLContext + +sc = SparkContext() +ml = MLContext(sc) + +class TestAPI(unittest.TestCase): + + def test_output_string(self): + script = dml("x1 = 'Hello World'").out("x1") + self.assertEqual(ml.execute(script).get("x1"), "Hello World") + + def test_output_list(self): + script = """ + x1 = 0.2 + x2 = x1 + 1 + x3 = x1 + 2 + """ + script = dml(script).out("x1", "x2", "x3") + self.assertEqual(ml.execute(script).get("x1", "x2"), [0.2, 1.2]) + self.assertEqual(ml.execute(script).get("x1", "x3"), [0.2, 2.2]) + + def test_input_single(self): + script = """ + x2 = x1 + 1 + x3 = x1 + 2 + """ + script = dml(script).input("x1", 5).out("x2", "x3") + self.assertEqual(ml.execute(script).get("x2", "x3"), [6, 7]) + + def test_input(self): + script = """ + x3 = x1 + x2 + """ + script = dml(script).input(x1=5, x2=3).out("x3") + self.assertEqual(ml.execute(script).get("x3"), 8) + + def test_rdd(self): + sums = """ + s1 = sum(m1) + s2 = sum(m2) + s3 = 'whatever' + """ + rdd1 = sc.parallelize(["1.0,2.0", "3.0,4.0"]) + rdd2 = sc.parallelize(["5.0,6.0", "7.0,8.0"]) + script = dml(sums).input(m1=rdd1).input(m2=rdd2).out("s1", "s2", "s3") + self.assertEqual( + ml.execute(script).get("s1", "s2", "s3"), [10.0, 26.0, "whatever"]) + + def test_pydml(self): + script = "A = full('1 2 3 4 5 6 7 8 9', rows=3, cols=3)\nx = toString(A)" + script = pydml(script).out("x") + self.assertEqual( + ml.execute(script).get("x"), + '1.000 2.000 3.000\n4.000 5.000 6.000\n7.000 8.000 9.000\n' + ) + + +if __name__ == "__main__": + unittest.main()
