http://git-wip-us.apache.org/repos/asf/systemml/blob/9e7ee19a/src/main/python/systemml/mlcontext.py ---------------------------------------------------------------------- diff --git a/src/main/python/systemml/mlcontext.py b/src/main/python/systemml/mlcontext.py index c945b2b..9956815 100644 --- a/src/main/python/systemml/mlcontext.py +++ b/src/main/python/systemml/mlcontext.py @@ -1,4 +1,4 @@ -#------------------------------------------------------------- +# ------------------------------------------------------------- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -17,19 +17,29 @@ # specific language governing permissions and limitations # under the License. # -#------------------------------------------------------------- +# ------------------------------------------------------------- # Methods to create Script object -script_factory_methods = [ 'dml', 'pydml', 'dmlFromResource', 'pydmlFromResource', 'dmlFromFile', 'pydmlFromFile', 'dmlFromUrl', 'pydmlFromUrl' ] +script_factory_methods = [ + 'dml', + 'pydml', + 'dmlFromResource', + 'pydmlFromResource', + 'dmlFromFile', + 'pydmlFromFile', + 'dmlFromUrl', + 'pydmlFromUrl'] # Utility methods -util_methods = [ '_java2py', 'getHopDAG' ] -__all__ = ['MLResults', 'MLContext', 'Script', 'Matrix' ] + script_factory_methods + util_methods +util_methods = ['_java2py', 'getHopDAG'] +__all__ = ['MLResults', 'MLContext', 'Script', 'Matrix'] + \ + script_factory_methods + util_methods import os import numpy as np import pandas as pd -import threading, time - +import threading +import time + try: import py4j.java_gateway from py4j.java_gateway import JavaObject @@ -38,12 +48,15 @@ try: import pyspark.mllib.common from pyspark.sql import SparkSession except ImportError: - raise ImportError('Unable to import `pyspark`. Hint: Make sure you are running with PySpark.') + raise ImportError( + 'Unable to import `pyspark`. Hint: Make sure you are running with PySpark.') from .converters import * from .classloader import * -def getHopDAG(ml, script, lines=None, conf=None, apply_rewrites=True, with_subgraph=False): + +def getHopDAG(ml, script, lines=None, conf=None, + apply_rewrites=True, with_subgraph=False): """ Compile a DML / PyDML script. @@ -51,39 +64,42 @@ def getHopDAG(ml, script, lines=None, conf=None, apply_rewrites=True, with_subgr ---------- ml: MLContext instance MLContext instance. - + script: Script instance Script instance defined with the appropriate input and output variables. - + lines: list of integers Optional: only display the hops that have begin and end line number equals to the given integers. - + conf: SparkConf instance Optional spark configuration - + apply_rewrites: boolean If True, perform static rewrites, perform intra-/inter-procedural analysis to propagate size information into functions and apply dynamic rewrites - + with_subgraph: boolean - If False, the dot graph will be created without subgraphs for statement blocks. - + If False, the dot graph will be created without subgraphs for statement blocks. 
+ Returns ------- hopDAG: string - hop DAG in dot format + hop DAG in dot format """ if not isinstance(script, Script): raise ValueError("Expected script to be an instance of Script") scriptString = script.scriptString script_java = script.script_java - lines = [ int(x) for x in lines ] if lines is not None else [int(-1)] + lines = [int(x) for x in lines] if lines is not None else [int(-1)] sc = get_spark_context() if conf is not None: - hopDAG = sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG(ml._ml, script_java, lines, conf._jconf, apply_rewrites, with_subgraph) + hopDAG = sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG( + ml._ml, script_java, lines, conf._jconf, apply_rewrites, with_subgraph) else: - hopDAG = sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG(ml._ml, script_java, lines, apply_rewrites, with_subgraph) + hopDAG = sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG( + ml._ml, script_java, lines, apply_rewrites, with_subgraph) return hopDAG + def dml(scriptString): """ Create a dml script object based on a string. @@ -99,9 +115,12 @@ def dml(scriptString): Instance of a script object. """ if not isinstance(scriptString, str): - raise ValueError("scriptString should be a string, got %s" % type(scriptString)) + raise ValueError( + "scriptString should be a string, got %s" % + type(scriptString)) return Script(scriptString, scriptType="dml") + def dmlFromResource(resourcePath): """ Create a dml script object based on a resource path. @@ -117,7 +136,9 @@ def dmlFromResource(resourcePath): Instance of a script object. """ if not isinstance(resourcePath, str): - raise ValueError("resourcePath should be a string, got %s" % type(resourcePath)) + raise ValueError( + "resourcePath should be a string, got %s" % + type(resourcePath)) return Script(resourcePath, scriptType="dml", isResource=True) @@ -136,9 +157,12 @@ def pydml(scriptString): Instance of a script object. """ if not isinstance(scriptString, str): - raise ValueError("scriptString should be a string, got %s" % type(scriptString)) + raise ValueError( + "scriptString should be a string, got %s" % + type(scriptString)) return Script(scriptString, scriptType="pydml") + def pydmlFromResource(resourcePath): """ Create a pydml script object based on a resource path. @@ -154,9 +178,12 @@ def pydmlFromResource(resourcePath): Instance of a script object. """ if not isinstance(resourcePath, str): - raise ValueError("resourcePath should be a string, got %s" % type(resourcePath)) + raise ValueError( + "resourcePath should be a string, got %s" % + type(resourcePath)) return Script(resourcePath, scriptType="pydml", isResource=True) + def dmlFromFile(filePath): """ Create a dml script object based on a file path. @@ -172,9 +199,13 @@ def dmlFromFile(filePath): Instance of a script object. """ if not isinstance(filePath, str): - raise ValueError("filePath should be a string, got %s" % type(filePath)) - return Script(filePath, scriptType="dml", isResource=False, scriptFormat="file") - + raise ValueError( + "filePath should be a string, got %s" % + type(filePath)) + return Script(filePath, scriptType="dml", + isResource=False, scriptFormat="file") + + def pydmlFromFile(filePath): """ Create a pydml script object based on a file path. @@ -190,9 +221,12 @@ def pydmlFromFile(filePath): Instance of a script object. 
""" if not isinstance(filePath, str): - raise ValueError("filePath should be a string, got %s" % type(filePath)) - return Script(filePath, scriptType="pydml", isResource=False, scriptFormat="file") - + raise ValueError( + "filePath should be a string, got %s" % + type(filePath)) + return Script(filePath, scriptType="pydml", + isResource=False, scriptFormat="file") + def dmlFromUrl(url): """ @@ -212,6 +246,7 @@ def dmlFromUrl(url): raise ValueError("url should be a string, got %s" % type(url)) return Script(url, scriptType="dml", isResource=False, scriptFormat="url") + def pydmlFromUrl(url): """ Create a pydml script object based on a url. @@ -228,7 +263,9 @@ def pydmlFromUrl(url): """ if not isinstance(url, str): raise ValueError("url should be a string, got %s" % type(url)) - return Script(url, scriptType="pydml", isResource=False, scriptFormat="url") + return Script(url, scriptType="pydml", + isResource=False, scriptFormat="url") + def _java2py(sc, obj): """ Convert Java object to Python. """ @@ -265,6 +302,7 @@ class Matrix(object): sc: SparkContext SparkContext """ + def __init__(self, javaMatrix, sc): self._java_matrix = javaMatrix self._sc = sc @@ -297,7 +335,8 @@ class Matrix(object): NumPy Array A NumPy Array representing the Matrix object. """ - np_array = convertToNumPyArr(self._sc, self._java_matrix.toMatrixBlock()) + np_array = convertToNumPyArr( + self._sc, self._java_matrix.toMatrixBlock()) return np_array @@ -313,6 +352,7 @@ class MLResults(object): sc: SparkContext SparkContext """ + def __init__(self, results, sc): self._java_results = results self._sc = sc @@ -327,7 +367,8 @@ class MLResults(object): outputs: string, list of strings Output variables as defined inside the DML script. """ - outs = [_java2py(self._sc, self._java_results.get(out)) for out in outputs] + outs = [_java2py(self._sc, self._java_results.get(out)) + for out in outputs] if len(outs) == 1: return outs[0] return outs @@ -347,70 +388,87 @@ class Script(object): isResource: boolean If true, scriptString is a path to a resource on the classpath - + scriptFormat: string Optional script format, either "auto" or "url" or "file" or "resource" or "string" """ - def __init__(self, scriptString, scriptType="dml", isResource=False, scriptFormat="auto"): + + def __init__(self, scriptString, scriptType="dml", + isResource=False, scriptFormat="auto"): self.sc = get_spark_context() self.scriptString = scriptString self.scriptType = scriptType self.isResource = isResource if scriptFormat != "auto": if scriptFormat == "url" and self.scriptType == "dml": - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromUrl(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromUrl( + scriptString) elif scriptFormat == "url" and self.scriptType == "pydml": - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromUrl(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromUrl( + scriptString) elif scriptFormat == "file" and self.scriptType == "dml": - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile( + scriptString) elif scriptFormat == "file" and self.scriptType == "pydml": - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile(scriptString) + self.script_java = 
self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile( + scriptString) elif isResource and self.scriptType == "dml": - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromResource(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromResource( + scriptString) elif isResource and self.scriptType == "pydml": - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromResource(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromResource( + scriptString) elif scriptFormat == "string" and self.scriptType == "dml": - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dml(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dml( + scriptString) elif scriptFormat == "string" and self.scriptType == "pydml": - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydml(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydml( + scriptString) else: raise ValueError('Unsupported script format' + scriptFormat) elif self.scriptType == "dml": if scriptString.endswith(".dml"): if scriptString.startswith("http"): - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromUrl(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromUrl( + scriptString) elif os.path.exists(scriptString): - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile( + scriptString) elif self.isResource == True: - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromResource(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromResource( + scriptString) else: raise ValueError("path: %s does not exist" % scriptString) else: - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dml(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dml( + scriptString) elif self.scriptType == "pydml": if scriptString.endswith(".pydml"): if scriptString.startswith("http"): - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromUrl(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromUrl( + scriptString) elif os.path.exists(scriptString): - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile( + scriptString) elif self.isResource == True: - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromResource(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromResource( + scriptString) else: raise ValueError("path: %s does not exist" % scriptString) else: - self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydml(scriptString) + self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydml( + scriptString) - def getScriptString(self): """ Obtain the script string (in unicode). 
""" return self.script_java.getScriptString() - + def setScriptString(self, scriptString): """ Set the script string. - + Parameters ---------- scriptString: string @@ -438,80 +496,80 @@ class Script(object): """ self.script_java.clearIOS() return self - + def clearIO(self): """ Clear the inputs and outputs, but not the symbol table. """ self.script_java.clearIO() return self - + def clearAll(self): """ Clear the script string, inputs, outputs, and symbol table. """ self.script_java.clearAll() return self - + def clearInputs(self): """ Clear the inputs. """ self.script_java.clearInputs() return self - + def clearOutputs(self): """ Clear the outputs. """ self.script_java.clearOutputs() return self - + def clearSymbolTable(self): """ Clear the symbol table. """ self.script_java.clearSymbolTable() return self - + def results(self): """ Obtain the results of the script execution. """ return MLResults(self.script_java.results(), self.sc) - + def getResults(self): """ Obtain the results of the script execution. """ return MLResults(self.script_java.getResults(), self.sc) - + def setResults(self, results): """ Set the results of the script execution. """ self.script_java.setResults(results._java_results) return self - + def isDML(self): """ Is the script type DML? """ return self.script_java.isDML() - + def isPYDML(self): """ Is the script type DML? """ return self.script_java.isPYDML() - + def getScriptExecutionString(self): """ Generate the script execution string, which adds read/load/write/save statements to the beginning and end of the script to execute. """ - return self.script_java.getScriptExecutionString() - + return self.script_java.getScriptExecutionString() + def __repr__(self): return "Script" @@ -528,56 +586,56 @@ class Script(object): Display the script inputs. """ return self.script_java.displayInputs() - + def displayOutputs(self): """ Display the script outputs. """ return self.script_java.displayOutputs() - + def displayInputParameters(self): """ Display the script input parameters. """ return self.script_java.displayInputParameters() - + def displayInputVariables(self): """ Display the script input variables. """ return self.script_java.displayInputVariables() - + def displayOutputVariables(self): """ Display the script output variables. """ return self.script_java.displayOutputVariables() - + def displaySymbolTable(self): """ Display the script symbol table. """ return self.script_java.displaySymbolTable() - + def getName(self): """ Obtain the script name. """ return self.script_java.getName() - + def setName(self, name): """ Set the script name. """ self.script_java.setName(name) return self - + def getScriptType(self): """ Obtain the script type. """ return self.scriptType - + def input(self, *args, **kwargs): """ Parameters @@ -606,9 +664,11 @@ class Script(object): if isinstance(val, py4j.java_gateway.JavaObject): py4j.java_gateway.get_method(self.script_java, "in")(key, val) else: - py4j.java_gateway.get_method(self.script_java, "in")(key, _py2java(self.sc, val)) - - + py4j.java_gateway.get_method( + self.script_java, "in")( + key, _py2java( + self.sc, val)) + def output(self, *names): """ Parameters @@ -630,17 +690,20 @@ class MLContext(object): sc: SparkContext or SparkSession An instance of pyspark.SparkContext or pyspark.sql.SparkSession. 
""" + def __init__(self, sc): if isinstance(sc, pyspark.sql.session.SparkSession): sc = sc._sc elif not isinstance(sc, SparkContext): - raise ValueError("Expected sc to be a SparkContext or SparkSession, got " % str(type(sc))) + raise ValueError( + "Expected sc to be a SparkContext or SparkSession, got " % str( + type(sc))) self._sc = sc self._ml = createJavaObject(sc, 'mlcontext') def __repr__(self): return "MLContext" - + def execute(self, script): """ Execute a DML / PyDML script. @@ -688,7 +751,7 @@ class MLContext(object): """ self._ml.setGPU(bool(enable)) return self - + def setForceGPU(self, enable): """ Whether or not to force the usage of GPU operators. @@ -699,7 +762,7 @@ class MLContext(object): """ self._ml.setForceGPU(bool(enable)) return self - + def setStatisticsMaxHeavyHitters(self, maxHeavyHitters): """ The maximum number of heavy hitters that are printed as part of the statistics.
http://git-wip-us.apache.org/repos/asf/systemml/blob/9e7ee19a/src/main/python/systemml/mllearn/__init__.py ---------------------------------------------------------------------- diff --git a/src/main/python/systemml/mllearn/__init__.py b/src/main/python/systemml/mllearn/__init__.py index 907ffb2..eda2197 100644 --- a/src/main/python/systemml/mllearn/__init__.py +++ b/src/main/python/systemml/mllearn/__init__.py @@ -1,4 +1,4 @@ -#------------------------------------------------------------- +# ------------------------------------------------------------- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -17,7 +17,7 @@ # specific language governing permissions and limitations # under the License. # -#------------------------------------------------------------- +# ------------------------------------------------------------- """ =================== http://git-wip-us.apache.org/repos/asf/systemml/blob/9e7ee19a/src/main/python/systemml/mllearn/estimators.py ---------------------------------------------------------------------- diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py index 52fe9cc..0b3de41 100644 --- a/src/main/python/systemml/mllearn/estimators.py +++ b/src/main/python/systemml/mllearn/estimators.py @@ -1,4 +1,4 @@ -#------------------------------------------------------------- +# ------------------------------------------------------------- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -17,9 +17,15 @@ # specific language governing permissions and limitations # under the License. # -#------------------------------------------------------------- +# ------------------------------------------------------------- -__all__ = ['LinearRegression', 'LogisticRegression', 'SVM', 'NaiveBayes', 'Caffe2DML', 'Keras2DML'] +__all__ = [ + 'LinearRegression', + 'LogisticRegression', + 'SVM', + 'NaiveBayes', + 'Caffe2DML', + 'Keras2DML'] import numpy as np from pyspark.ml import Estimator @@ -37,15 +43,17 @@ import math from ..converters import * from ..classloader import * + def assemble(sparkSession, pdf, inputCols, outputCol): tmpDF = sparkSession.createDataFrame(pdf, list(pdf.columns)) assembler = VectorAssembler(inputCols=list(inputCols), outputCol=outputCol) return assembler.transform(tmpDF) + class BaseSystemMLEstimator(Estimator): features_col = 'features' label_col = 'label' - + def set_features_col(self, colName): """ Sets the default column name for features of PySpark DataFrame. @@ -76,7 +84,7 @@ class BaseSystemMLEstimator(Estimator): """ self.estimator.setGPU(enable) return self - + def setForceGPU(self, enable): """ Whether or not to force the usage of GPU operators. @@ -87,7 +95,7 @@ class BaseSystemMLEstimator(Estimator): """ self.estimator.setForceGPU(enable) return self - + def setExplain(self, explain): """ Explanation about the program. Mainly intended for developers. @@ -98,11 +106,11 @@ class BaseSystemMLEstimator(Estimator): """ self.estimator.setExplain(explain) return self - + def setExplainLevel(self, explainLevel): """ Set explain level. Mainly intended for developers. 
- + Parameters ---------- explainLevel: string @@ -111,7 +119,7 @@ class BaseSystemMLEstimator(Estimator): """ self.estimator.setExplainLevel(explainLevel) return self - + def setStatistics(self, statistics): """ Whether or not to output statistics (such as execution time, elapsed time) @@ -123,7 +131,7 @@ class BaseSystemMLEstimator(Estimator): """ self.estimator.setStatistics(statistics) return self - + def setStatisticsMaxHeavyHitters(self, maxHeavyHitters): """ The maximum number of heavy hitters that are printed as part of the statistics. @@ -134,7 +142,7 @@ class BaseSystemMLEstimator(Estimator): """ self.estimator.setStatisticsMaxHeavyHitters(maxHeavyHitters) return self - + def setConfigProperty(self, propertyName, propertyValue): """ Set configuration property, such as setConfigProperty("sysml.localtmpdir", "/tmp/systemml"). @@ -146,7 +154,7 @@ class BaseSystemMLEstimator(Estimator): """ self.estimator.setConfigProperty(propertyName, propertyValue) return self - + def _fit_df(self): global default_jvm_stdout, default_jvm_stdout_parallel_flush try: @@ -157,28 +165,30 @@ class BaseSystemMLEstimator(Estimator): self.model = self.estimator.fit(self.X._jdf) except Py4JError: traceback.print_exc() - + def fit_df(self, X): self.X = X self._fit_df() self.X = None return self - + def _fit_numpy(self): global default_jvm_stdout, default_jvm_stdout_parallel_flush try: - if type(self.y) == np.ndarray and len(self.y.shape) == 1: + if isinstance(self.y, np.ndarray) and len(self.y.shape) == 1: # Since we know that mllearn always needs a column vector self.y = np.matrix(self.y).T y_mb = convertToMatrixBlock(self.sc, self.y) if default_jvm_stdout: with jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush): - self.model = self.estimator.fit(convertToMatrixBlock(self.sc, self.X), y_mb) + self.model = self.estimator.fit( + convertToMatrixBlock(self.sc, self.X), y_mb) else: - self.model = self.estimator.fit(convertToMatrixBlock(self.sc, self.X), y_mb) + self.model = self.estimator.fit( + convertToMatrixBlock(self.sc, self.X), y_mb) except Py4JError: traceback.print_exc() - + def fit_numpy(self, X, y): self.X = X self.y = y @@ -198,7 +208,7 @@ class BaseSystemMLEstimator(Estimator): except Py4JError: traceback.print_exc() return self - + # Returns a model after calling fit(df) on Estimator object on JVM def _fit(self, X): """ @@ -208,10 +218,12 @@ class BaseSystemMLEstimator(Estimator): ---------- X: PySpark DataFrame that contain the columns features_col (default: 'features') and label_col (default: 'label') """ - if hasattr(X, '_jdf') and self.features_col in X.columns and self.label_col in X.columns: + if hasattr( + X, '_jdf') and self.features_col in X.columns and self.label_col in X.columns: return self.fit_df(X) else: - raise Exception('Incorrect usage: Expected dataframe as input with features/label as columns') + raise Exception( + 'Incorrect usage: Expected dataframe as input with features/label as columns') def fit(self, X, y=None, params=None): """ @@ -228,7 +240,8 @@ class BaseSystemMLEstimator(Estimator): return self.fit_file(X, y) elif isinstance(X, SUPPORTED_TYPES) and isinstance(y, SUPPORTED_TYPES): # Donot encode if y is a numpy matrix => useful for segmentation - skipEncodingY = len(y.shape) == 2 and y.shape[0] != 1 and y.shape[1] != 1 + skipEncodingY = len( + y.shape) == 2 and y.shape[0] != 1 and y.shape[1] != 1 y = y if skipEncodingY else self.encode(y) if self.transferUsingDF: pdfX = convertToPandasDF(X) @@ -239,7 +252,13 @@ class BaseSystemMLEstimator(Estimator): raise 
Exception('Number of rows of X and y should match') colNames = pdfX.columns pdfX[self.label_col] = pdfY[pdfY.columns[0]] - df = assemble(self.sparkSession, pdfX, colNames, self.features_col).select(self.features_col, self.label_col) + df = assemble( + self.sparkSession, + pdfX, + colNames, + self.features_col).select( + self.features_col, + self.label_col) self.fit_df(df) else: numColsy = getNumCols(y) @@ -254,7 +273,7 @@ class BaseSystemMLEstimator(Estimator): def transform(self, X): return self.predict(X) - + def _convertPythonXToJavaObject(self, X): """ Converts the input python object X to a java-side object (either MatrixBlock or Java DataFrame) @@ -265,7 +284,12 @@ class BaseSystemMLEstimator(Estimator): """ if isinstance(X, SUPPORTED_TYPES) and self.transferUsingDF: pdfX = convertToPandasDF(X) - df = assemble(self.sparkSession, pdfX, pdfX.columns, self.features_col).select(self.features_col) + df = assemble( + self.sparkSession, + pdfX, + pdfX.columns, + self.features_col).select( + self.features_col) return df._jdf elif isinstance(X, SUPPORTED_TYPES): return convertToMatrixBlock(self.sc, X) @@ -273,12 +297,13 @@ class BaseSystemMLEstimator(Estimator): # No need to assemble as input DF is likely coming via MLPipeline return X._jdf elif hasattr(X, '_jdf'): - assembler = VectorAssembler(inputCols=X.columns, outputCol=self.features_col) + assembler = VectorAssembler( + inputCols=X.columns, outputCol=self.features_col) df = assembler.transform(X) return df._jdf else: raise Exception('Unsupported input type') - + def _convertJavaOutputToPythonObject(self, X, output): """ Converts the a java-side object output (either MatrixBlock or Java DataFrame) to a python object (based on the type of X). @@ -300,7 +325,7 @@ class BaseSystemMLEstimator(Estimator): return retDF.sort('__INDEX') else: raise Exception('Unsupported input type') - + def predict_proba(self, X): """ Invokes the transform_probability method on Estimator object on JVM if X and y are on of the supported data types @@ -314,7 +339,8 @@ class BaseSystemMLEstimator(Estimator): if hasattr(X, '_jdf'): return self.predict(X) elif self.transferUsingDF: - raise ValueError('The parameter transferUsingDF is not valid for the method predict_proba') + raise ValueError( + 'The parameter transferUsingDF is not valid for the method predict_proba') try: if self.estimator is not None and self.model is not None: self.estimator.copyProperties(self.model) @@ -326,13 +352,16 @@ class BaseSystemMLEstimator(Estimator): jX = self._convertPythonXToJavaObject(X) if default_jvm_stdout: with jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush): - return self._convertJavaOutputToPythonObject(X, self.model.transform_probability(jX)) + return self._convertJavaOutputToPythonObject( + X, self.model.transform_probability(jX)) else: - return self._convertJavaOutputToPythonObject(X, self.model.transform_probability(jX)) + return self._convertJavaOutputToPythonObject( + X, self.model.transform_probability(jX)) except Py4JError: traceback.print_exc() - - # Returns either a DataFrame or MatrixBlock after calling transform(X:MatrixBlock, y:MatrixBlock) on Model object on JVM + + # Returns either a DataFrame or MatrixBlock after calling + # transform(X:MatrixBlock, y:MatrixBlock) on Model object on JVM def predict(self, X): """ Invokes the transform method on Estimator object on JVM if X and y are on of the supported data types @@ -353,20 +382,23 @@ class BaseSystemMLEstimator(Estimator): jX = self._convertPythonXToJavaObject(X) if default_jvm_stdout: with 
jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush): - ret = self._convertJavaOutputToPythonObject(X, self.model.transform(jX)) + ret = self._convertJavaOutputToPythonObject( + X, self.model.transform(jX)) else: - ret = self._convertJavaOutputToPythonObject(X, self.model.transform(jX)) + ret = self._convertJavaOutputToPythonObject( + X, self.model.transform(jX)) return self.decode(ret) if isinstance(X, SUPPORTED_TYPES) else ret except Py4JError: traceback.print_exc() + class BaseSystemMLClassifier(BaseSystemMLEstimator): def encode(self, y): self.le = LabelEncoder() self.le.fit(y) return self.le.transform(y) + 1 - + def decode(self, y): if not hasattr(self, 'le'): self.le = None @@ -375,14 +407,14 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator): if self.le is not None: return self.le.inverse_transform(np.asarray(y - 1, dtype=int)) elif self.labelMap is not None: - return [ self.labelMap[int(i)] for i in y ] + return [self.labelMap[int(i)] for i in y] else: return y - + def predict(self, X): predictions = super(BaseSystemMLClassifier, self).predict(X) from pyspark.sql.dataframe import DataFrame as df - if type(predictions) == df: + if isinstance(predictions, df): return predictions else: try: @@ -390,7 +422,7 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator): except ValueError: print(type(predictions)) return np.asarray(predictions, dtype='str') - + def score(self, X, y): """ Scores the predicted value with ground truth 'y' @@ -404,8 +436,9 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator): if np.issubdtype(predictions.dtype.type, np.number): return accuracy_score(y, predictions) else: - return accuracy_score(np.asarray(y, dtype='str'), np.asarray(predictions, dtype='str')) - + return accuracy_score(np.asarray( + y, dtype='str'), np.asarray(predictions, dtype='str')) + def loadLabels(self, file_path): createJavaObject(self.sc, 'dummy') utilObj = self.sc._jvm.org.apache.sysml.api.ml.Utils() @@ -417,10 +450,10 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator): for i in range(len(keys)): self.labelMap[int(keys[i])] = values[i] # self.encode(classes) # Giving incorrect results - + def load(self, weights, sep='/', eager=False): """ - Load a pretrained model. + Load a pretrained model. Parameters ---------- @@ -436,11 +469,11 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator): else: self.model.load(self.sc._jsc, weights, sep, eager) self.loadLabels(weights + '/labels.txt') - + def save(self, outputDir, format='binary', sep='/'): """ Save a trained model. 
-
+
         Parameters
         ----------
         outputDir: Directory to save the model to
@@ -448,7 +481,7 @@
             sep: separator to use (default: '/')
         """
         global default_jvm_stdout, default_jvm_stdout_parallel_flush
-        if self.model != None:
+        if self.model is not None:
             if default_jvm_stdout:
                 with jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush):
                     self.model.save(self.sc._jsc, outputDir, format, sep)
@@ -462,21 +495,26 @@
             labelMapping = self.labelMap

             if labelMapping is not None:
-                lStr = [ [ int(k), str(labelMapping[k]) ] for k in labelMapping ]
+                lStr = [[int(k), str(labelMapping[k])] for k in labelMapping]
                 df = self.sparkSession.createDataFrame(lStr)
-                df.write.csv(outputDir + sep + 'labels.txt', mode='overwrite', header=False)
+                df.write.csv(
+                    outputDir + sep + 'labels.txt',
+                    mode='overwrite',
+                    header=False)
         else:
-            raise Exception('Cannot save as you need to train the model first using fit')
+            raise Exception(
+                'Cannot save as you need to train the model first using fit')
         return self

+
 class BaseSystemMLRegressor(BaseSystemMLEstimator):

     def encode(self, y):
         return y
-
+
     def decode(self, y):
         return y
-
+
     def score(self, X, y):
         """
         Scores the predicted value with ground truth 'y'
@@ -487,10 +525,10 @@
         y: NumPy ndarray, Pandas DataFrame, scipy sparse matrix
         """
         return r2_score(y, self.predict(X), multioutput='variance_weighted')
-
+
     def load(self, weights=None, sep='/', eager=False):
         """
-        Load a pretrained model. 
+        Load a pretrained model.

         Parameters
         ----------
@@ -509,7 +547,7 @@
     def save(self, outputDir, format='binary', sep='/'):
         """
         Save a trained model.
- + Parameters ---------- outputDir: Directory to save the model to @@ -517,14 +555,15 @@ class BaseSystemMLRegressor(BaseSystemMLEstimator): sep: seperator to use (default: '/') """ global default_jvm_stdout, default_jvm_stdout_parallel_flush - if self.model != None: + if self.model is not None: if default_jvm_stdout: with jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush): self.model.save(outputDir, format, sep) else: self.model.save(outputDir, format, sep) else: - raise Exception('Cannot save as you need to train the model first using fit') + raise Exception( + 'Cannot save as you need to train the model first using fit') return self @@ -534,9 +573,9 @@ class LogisticRegression(BaseSystemMLClassifier): Examples -------- - + Scikit-learn way - + >>> from sklearn import datasets, neighbors >>> from systemml.mllearn import LogisticRegression >>> from pyspark.sql import SparkSession @@ -551,9 +590,9 @@ class LogisticRegression(BaseSystemMLClassifier): >>> y_test = y_digits[.9 * n_samples:] >>> logistic = LogisticRegression(sparkSession) >>> print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test)) - + MLPipeline way - + >>> from pyspark.ml import Pipeline >>> from systemml.mllearn import LogisticRegression >>> from pyspark.ml.feature import HashingTF, Tokenizer @@ -585,13 +624,14 @@ class LogisticRegression(BaseSystemMLClassifier): >>> (15L, "apache hadoop")], ["id", "text"]) >>> prediction = model.transform(test) >>> prediction.show() - + """ - - def __init__(self, sparkSession, penalty='l2', fit_intercept=True, normalize=False, max_iter=100, max_inner_iter=0, tol=0.000001, C=1.0, solver='newton-cg', transferUsingDF=False): + + def __init__(self, sparkSession, penalty='l2', fit_intercept=True, normalize=False, max_iter=100, + max_inner_iter=0, tol=0.000001, C=1.0, solver='newton-cg', transferUsingDF=False): """ Performs both binomial and multinomial logistic regression. - + Parameters ---------- sparkSession: PySpark SparkSession @@ -608,30 +648,33 @@ class LogisticRegression(BaseSystemMLClassifier): self.sc = sparkSession._sc createJavaObject(self.sc, 'dummy') self.uid = "logReg" - self.estimator = self.sc._jvm.org.apache.sysml.api.ml.LogisticRegression(self.uid, self.sc._jsc.sc()) + self.estimator = self.sc._jvm.org.apache.sysml.api.ml.LogisticRegression( + self.uid, self.sc._jsc.sc()) self.estimator.setMaxOuterIter(max_iter) self.estimator.setMaxInnerIter(max_inner_iter) reg = 0.0 if C == float("inf") else 1.0 / C - icpt = 2 if fit_intercept == True and normalize == True else int(fit_intercept) + icpt = 2 if fit_intercept == True and normalize == True else int( + fit_intercept) self.estimator.setRegParam(reg) self.estimator.setTol(tol) self.estimator.setIcpt(icpt) self.transferUsingDF = transferUsingDF self.setOutputRawPredictionsToFalse = True - self.model = self.sc._jvm.org.apache.sysml.api.ml.LogisticRegressionModel(self.estimator) + self.model = self.sc._jvm.org.apache.sysml.api.ml.LogisticRegressionModel( + self.estimator) if penalty != 'l2': raise Exception('Only l2 penalty is supported') if solver != 'newton-cg': raise Exception('Only newton-cg solver supported') - + class LinearRegression(BaseSystemMLRegressor): """ Performs linear regression to model the relationship between one numerical response variable and one or more explanatory (feature) variables. 
- + Examples -------- - + >>> import numpy as np >>> from sklearn import datasets >>> from systemml.mllearn import LinearRegression @@ -652,11 +695,11 @@ class LinearRegression(BaseSystemMLRegressor): >>> regr.fit(diabetes_X_train, diabetes_y_train) >>> # The mean square error >>> print("Residual sum of squares: %.2f" % np.mean((regr.predict(diabetes_X_test) - diabetes_y_test) ** 2)) - + """ - - - def __init__(self, sparkSession, fit_intercept=True, normalize=False, max_iter=100, tol=0.000001, C=float("inf"), solver='newton-cg', transferUsingDF=False): + + def __init__(self, sparkSession, fit_intercept=True, normalize=False, max_iter=100, + tol=0.000001, C=float("inf"), solver='newton-cg', transferUsingDF=False): """ Performs linear regression to model the relationship between one numerical response variable and one or more explanatory (feature) variables. @@ -678,18 +721,21 @@ class LinearRegression(BaseSystemMLRegressor): createJavaObject(self.sc, 'dummy') self.uid = "lr" if solver == 'newton-cg' or solver == 'direct-solve': - self.estimator = self.sc._jvm.org.apache.sysml.api.ml.LinearRegression(self.uid, self.sc._jsc.sc(), solver) + self.estimator = self.sc._jvm.org.apache.sysml.api.ml.LinearRegression( + self.uid, self.sc._jsc.sc(), solver) else: raise Exception('Only newton-cg solver supported') self.estimator.setMaxIter(max_iter) reg = 0.0 if C == float("inf") else 1.0 / C - icpt = 2 if fit_intercept == True and normalize == True else int(fit_intercept) + icpt = 2 if fit_intercept == True and normalize == True else int( + fit_intercept) self.estimator.setRegParam(reg) self.estimator.setTol(tol) self.estimator.setIcpt(icpt) self.transferUsingDF = transferUsingDF self.setOutputRawPredictionsToFalse = False - self.model = self.sc._jvm.org.apache.sysml.api.ml.LinearRegressionModel(self.estimator) + self.model = self.sc._jvm.org.apache.sysml.api.ml.LinearRegressionModel( + self.estimator) class SVM(BaseSystemMLClassifier): @@ -698,14 +744,14 @@ class SVM(BaseSystemMLClassifier): Examples -------- - + >>> from sklearn import datasets, neighbors >>> from systemml.mllearn import SVM >>> from pyspark.sql import SparkSession >>> sparkSession = SparkSession.builder.getOrCreate() >>> digits = datasets.load_digits() >>> X_digits = digits.data - >>> y_digits = digits.target + >>> y_digits = digits.target >>> n_samples = len(X_digits) >>> X_train = X_digits[:.9 * n_samples] >>> y_train = y_digits[:.9 * n_samples] @@ -713,11 +759,11 @@ class SVM(BaseSystemMLClassifier): >>> y_test = y_digits[.9 * n_samples:] >>> svm = SVM(sparkSession, is_multi_class=True) >>> print('LogisticRegression score: %f' % svm.fit(X_train, y_train).score(X_test, y_test)) - - """ + """ - def __init__(self, sparkSession, fit_intercept=True, normalize=False, max_iter=100, tol=0.000001, C=1.0, is_multi_class=False, transferUsingDF=False): + def __init__(self, sparkSession, fit_intercept=True, normalize=False, max_iter=100, + tol=0.000001, C=1.0, is_multi_class=False, transferUsingDF=False): """ Performs both binary-class and multiclass SVM (Support Vector Machines). 
@@ -736,18 +782,22 @@ class SVM(BaseSystemMLClassifier): self.uid = "svm" createJavaObject(self.sc, 'dummy') self.is_multi_class = is_multi_class - self.estimator = self.sc._jvm.org.apache.sysml.api.ml.SVM(self.uid, self.sc._jsc.sc(), is_multi_class) + self.estimator = self.sc._jvm.org.apache.sysml.api.ml.SVM( + self.uid, self.sc._jsc.sc(), is_multi_class) self.estimator.setMaxIter(max_iter) if C <= 0: raise Exception('C has to be positive') reg = 0.0 if C == float("inf") else 1.0 / C - icpt = 2 if fit_intercept == True and normalize == True else int(fit_intercept) + icpt = 2 if fit_intercept == True and normalize == True else int( + fit_intercept) self.estimator.setRegParam(reg) self.estimator.setTol(tol) self.estimator.setIcpt(icpt) self.transferUsingDF = transferUsingDF self.setOutputRawPredictionsToFalse = False - self.model = self.sc._jvm.org.apache.sysml.api.ml.SVMModel(self.estimator, self.is_multi_class) + self.model = self.sc._jvm.org.apache.sysml.api.ml.SVMModel( + self.estimator, self.is_multi_class) + class NaiveBayes(BaseSystemMLClassifier): """ @@ -755,7 +805,7 @@ class NaiveBayes(BaseSystemMLClassifier): Examples -------- - + >>> from sklearn.datasets import fetch_20newsgroups >>> from sklearn.feature_extraction.text import TfidfVectorizer >>> from systemml.mllearn import NaiveBayes @@ -775,7 +825,7 @@ class NaiveBayes(BaseSystemMLClassifier): >>> metrics.f1_score(newsgroups_test.target, pred, average='weighted') """ - + def __init__(self, sparkSession, laplace=1.0, transferUsingDF=False): """ Performs Naive Bayes. @@ -789,19 +839,22 @@ class NaiveBayes(BaseSystemMLClassifier): self.sc = sparkSession._sc self.uid = "nb" createJavaObject(self.sc, 'dummy') - self.estimator = self.sc._jvm.org.apache.sysml.api.ml.NaiveBayes(self.uid, self.sc._jsc.sc()) + self.estimator = self.sc._jvm.org.apache.sysml.api.ml.NaiveBayes( + self.uid, self.sc._jsc.sc()) self.estimator.setLaplace(laplace) self.transferUsingDF = transferUsingDF self.setOutputRawPredictionsToFalse = False - self.model = self.sc._jvm.org.apache.sysml.api.ml.NaiveBayesModel(self.estimator) + self.model = self.sc._jvm.org.apache.sysml.api.ml.NaiveBayesModel( + self.estimator) + class Caffe2DML(BaseSystemMLClassifier): """ Performs training/prediction for a given caffe network. - + Examples -------- - + >>> from systemml.mllearn import Caffe2DML >>> from mlxtend.data import mnist_data >>> import numpy as np @@ -815,9 +868,11 @@ class Caffe2DML(BaseSystemMLClassifier): >>> caffe2DML = Caffe2DML(spark, 'lenet_solver.proto').set(max_iter=500) >>> caffe2DML.fit(X, y) """ - def __init__(self, sparkSession, solver, input_shape, transferUsingDF=False): + + def __init__(self, sparkSession, solver, input_shape, + transferUsingDF=False): """ - Performs training/prediction for a given caffe network. + Performs training/prediction for a given caffe network. 
Parameters ---------- @@ -833,14 +888,19 @@ class Caffe2DML(BaseSystemMLClassifier): self.model = None if len(input_shape) != 3: raise ValueError('Expected input_shape as list of 3 element') - solver = self.sc._jvm.org.apache.sysml.api.dl.Utils.readCaffeSolver(solver) - self.estimator = self.sc._jvm.org.apache.sysml.api.dl.Caffe2DML(self.sc._jsc.sc(), solver, str(input_shape[0]), str(input_shape[1]), str(input_shape[2])) + solver = self.sc._jvm.org.apache.sysml.api.dl.Utils.readCaffeSolver( + solver) + self.estimator = self.sc._jvm.org.apache.sysml.api.dl.Caffe2DML( + self.sc._jsc.sc(), solver, str( + input_shape[0]), str( + input_shape[1]), str( + input_shape[2])) self.transferUsingDF = transferUsingDF self.setOutputRawPredictionsToFalse = False def load(self, weights=None, sep='/', ignore_weights=None, eager=False): """ - Load a pretrained model. + Load a pretrained model. Parameters ---------- @@ -852,7 +912,8 @@ class Caffe2DML(BaseSystemMLClassifier): global default_jvm_stdout, default_jvm_stdout_parallel_flush self.weights = weights self.estimator.setInput("$weights", str(weights)) - self.model = self.sc._jvm.org.apache.sysml.api.dl.Caffe2DMLModel(self.estimator) + self.model = self.sc._jvm.org.apache.sysml.api.dl.Caffe2DMLModel( + self.estimator) if default_jvm_stdout: with jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush): self.model.load(self.sc._jsc, weights, sep, eager) @@ -861,11 +922,12 @@ class Caffe2DML(BaseSystemMLClassifier): self.loadLabels(weights + '/labels.txt') if ignore_weights is not None: self.estimator.setWeightsToIgnore(ignore_weights) - - def set(self, debug=None, train_algo=None, test_algo=None, parallel_batches=None, output_activations=None, perform_one_hot_encoding=None, parfor_parameters=None): + + def set(self, debug=None, train_algo=None, test_algo=None, parallel_batches=None, + output_activations=None, perform_one_hot_encoding=None, parfor_parameters=None): """ Set input to Caffe2DML - + Parameters ---------- debug: to add debugging DML code such as classification report, print DML script, etc (default: False) @@ -876,37 +938,55 @@ class Caffe2DML(BaseSystemMLClassifier): perform_one_hot_encoding: should perform one-hot encoding in DML using table function (default: False) parfor_parameters: dictionary for parfor parameters when using allreduce-style algorithms (default: "") """ - if debug is not None: self.estimator.setInput("$debug", str(debug).upper()) - if train_algo is not None: self.estimator.setInput("$train_algo", str(train_algo).lower()) - if test_algo is not None: self.estimator.setInput("$test_algo", str(test_algo).lower()) - if parallel_batches is not None: self.estimator.setInput("$parallel_batches", str(parallel_batches)) - if output_activations is not None: self.estimator.setInput("$output_activations", str(output_activations)) - if perform_one_hot_encoding is not None: self.estimator.setInput("$perform_one_hot_encoding", str(perform_one_hot_encoding).lower()) + if debug is not None: + self.estimator.setInput("$debug", str(debug).upper()) + if train_algo is not None: + self.estimator.setInput("$train_algo", str(train_algo).lower()) + if test_algo is not None: + self.estimator.setInput("$test_algo", str(test_algo).lower()) + if parallel_batches is not None: + self.estimator.setInput("$parallel_batches", str(parallel_batches)) + if output_activations is not None: + self.estimator.setInput( + "$output_activations", + str(output_activations)) + if perform_one_hot_encoding is not None: + self.estimator.setInput( + 
"$perform_one_hot_encoding", + str(perform_one_hot_encoding).lower()) if parfor_parameters is not None: if isinstance(parfor_parameters, dict): # Convert dictionary to comma-separated list - parfor_parameters = ''.join([ ', ' + str(k) + '=' + str(v) for k, v in parfor_parameters.items()]) if len(parfor_parameters) > 0 else '' - self.estimator.setInput("$parfor_parameters", parfor_parameters) + parfor_parameters = ''.join( + [ + ', ' + + str(k) + + '=' + + str(v) for k, + v in parfor_parameters.items()]) if len(parfor_parameters) > 0 else '' + self.estimator.setInput( + "$parfor_parameters", parfor_parameters) else: - raise TypeError("parfor_parameters should be a dictionary") + raise TypeError("parfor_parameters should be a dictionary") return self - + def summary(self): """ Print the summary of the network """ import pyspark global default_jvm_stdout, default_jvm_stdout_parallel_flush - if type(self.sparkSession) == pyspark.sql.session.SparkSession: + if isinstance(self.sparkSession, pyspark.sql.session.SparkSession): if default_jvm_stdout: with jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush): self.estimator.summary(self.sparkSession._jsparkSession) else: self.estimator.summary(self.sparkSession._jsparkSession) else: - raise TypeError('Please use spark session of type pyspark.sql.session.SparkSession in the constructor') - - + raise TypeError( + 'Please use spark session of type pyspark.sql.session.SparkSession in the constructor') + + class Keras2DML(Caffe2DML): """ Peforms training/prediction for a given keras model. @@ -914,7 +994,8 @@ class Keras2DML(Caffe2DML): """ - def __init__(self, sparkSession, keras_model, input_shape, transferUsingDF=False, load_keras_weights=True, weights=None, labels=None, batch_size=64, max_iter=2000, test_iter=10, test_interval=500, display=100, lr_policy="step", weight_decay=5e-4, regularization_type="L2"): + def __init__(self, sparkSession, keras_model, input_shape, transferUsingDF=False, load_keras_weights=True, weights=None, labels=None, + batch_size=64, max_iter=2000, test_iter=10, test_interval=500, display=100, lr_policy="step", weight_decay=5e-4, regularization_type="L2"): """ Performs training/prediction for a given keras model. 
@@ -936,9 +1017,9 @@ class Keras2DML(Caffe2DML): weight_decay: regularation strength (default: 5e-4) regularization_type: regularization type (default: "L2") """ - from .keras2caffe import * + from .keras2caffe import convertKerasToCaffeNetwork, convertKerasToCaffeSolver import tempfile - if type(keras_model) == keras.models.Sequential: + if isinstance(keras_model, keras.models.Sequential): # Convert the sequential model to functional model if keras_model.model is None: keras_model.build() @@ -946,19 +1027,47 @@ class Keras2DML(Caffe2DML): self.name = keras_model.name createJavaObject(sparkSession._sc, 'dummy') if not hasattr(keras_model, 'optimizer'): - keras_model.compile(loss='categorical_crossentropy', optimizer=keras.optimizers.SGD(lr=0.01, momentum=0.95, decay=5e-4, nesterov=True)) - convertKerasToCaffeNetwork(keras_model, self.name + ".proto", int(batch_size)) - convertKerasToCaffeSolver(keras_model, self.name + ".proto", self.name + "_solver.proto", int(max_iter), int(test_iter), int(test_interval), int(display), lr_policy, weight_decay, regularization_type) + keras_model.compile( + loss='categorical_crossentropy', + optimizer=keras.optimizers.SGD( + lr=0.01, + momentum=0.95, + decay=5e-4, + nesterov=True)) + convertKerasToCaffeNetwork( + keras_model, + self.name + ".proto", + int(batch_size)) + convertKerasToCaffeSolver( + keras_model, + self.name + ".proto", + self.name + "_solver.proto", + int(max_iter), + int(test_iter), + int(test_interval), + int(display), + lr_policy, + weight_decay, + regularization_type) self.weights = tempfile.mkdtemp() if weights is None else weights if load_keras_weights: - convertKerasToSystemMLModel(sparkSession, keras_model, self.weights) - if labels is not None and (labels.startswith('https:') or labels.startswith('http:')): + convertKerasToSystemMLModel( + sparkSession, keras_model, self.weights) + if labels is not None and (labels.startswith( + 'https:') or labels.startswith('http:')): import urllib urllib.urlretrieve(labels, os.path.join(weights, 'labels.txt')) elif labels is not None: from shutil import copyfile copyfile(labels, os.path.join(weights, 'labels.txt')) - super(Keras2DML,self).__init__(sparkSession, self.name + "_solver.proto", input_shape, transferUsingDF) + super( + Keras2DML, + self).__init__( + sparkSession, + self.name + + "_solver.proto", + input_shape, + transferUsingDF) if load_keras_weights: self.load(self.weights)
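The estimators.py changes are the same style-only cleanup, plus lint-driven idiom fixes such as `type(...) ==` → isinstance and `!= None` → `is not None`. Caffe2DML's docstring above carries a usage example, but Keras2DML's does not; here is a hedged sketch of that entry point, assuming the standalone keras package this module imports and a SystemML-enabled PySpark session — the model, shapes, and weights directory are illustrative, not from this commit:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from pyspark.sql import SparkSession
from systemml.mllearn import Keras2DML

spark = SparkSession.builder.getOrCreate()

# An uncompiled model is fine: Keras2DML compiles it with SGD defaults.
# input_shape is (channels, height, width) of a single sample, so
# (1, 28, 28) matches the flattened 784-feature input used below.
model = Sequential()
model.add(Dense(128, activation='relu', input_shape=(784,)))
model.add(Dense(10, activation='softmax'))

sysml_model = Keras2DML(spark, model, input_shape=(1, 28, 28),
                        weights='tmp_weights')  # hypothetical directory
X = np.random.rand(100, 784)
y = np.random.randint(10, size=100)
sysml_model.fit(X, y)        # fit/predict are inherited from BaseSystemMLEstimator
preds = sysml_model.predict(X)

Under the hood this writes <name>.proto and <name>_solver.proto via the keras2caffe converters and then defers to the Caffe2DML constructor, as the wrapped super().__init__ call above shows.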