http://git-wip-us.apache.org/repos/asf/systemml/blob/9e7ee19a/src/main/python/systemml/mlcontext.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mlcontext.py b/src/main/python/systemml/mlcontext.py
index c945b2b..9956815 100644
--- a/src/main/python/systemml/mlcontext.py
+++ b/src/main/python/systemml/mlcontext.py
@@ -1,4 +1,4 @@
-#-------------------------------------------------------------
+# -------------------------------------------------------------
 #
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
@@ -17,19 +17,29 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-#-------------------------------------------------------------
+# -------------------------------------------------------------
 
 # Methods to create Script object
-script_factory_methods = [ 'dml', 'pydml', 'dmlFromResource', 'pydmlFromResource', 'dmlFromFile', 'pydmlFromFile', 'dmlFromUrl', 'pydmlFromUrl' ]
+script_factory_methods = [
+    'dml',
+    'pydml',
+    'dmlFromResource',
+    'pydmlFromResource',
+    'dmlFromFile',
+    'pydmlFromFile',
+    'dmlFromUrl',
+    'pydmlFromUrl']
 # Utility methods
-util_methods = [ '_java2py',  'getHopDAG' ]
-__all__ = ['MLResults', 'MLContext', 'Script', 'Matrix' ] + script_factory_methods + util_methods
+util_methods = ['_java2py', 'getHopDAG']
+__all__ = ['MLResults', 'MLContext', 'Script', 'Matrix'] + \
+    script_factory_methods + util_methods
 
 import os
 import numpy as np
 import pandas as pd
-import threading, time
-    
+import threading
+import time
+
 try:
     import py4j.java_gateway
     from py4j.java_gateway import JavaObject
@@ -38,12 +48,15 @@ try:
     import pyspark.mllib.common
     from pyspark.sql import SparkSession
 except ImportError:
-    raise ImportError('Unable to import `pyspark`. Hint: Make sure you are running with PySpark.')
+    raise ImportError(
+        'Unable to import `pyspark`. Hint: Make sure you are running with PySpark.')
 
 from .converters import *
 from .classloader import *
 
-def getHopDAG(ml, script, lines=None, conf=None, apply_rewrites=True, with_subgraph=False):
+
+def getHopDAG(ml, script, lines=None, conf=None,
+              apply_rewrites=True, with_subgraph=False):
     """
     Compile a DML / PyDML script.
 
@@ -51,39 +64,42 @@ def getHopDAG(ml, script, lines=None, conf=None, apply_rewrites=True, with_subgr
     ----------
     ml: MLContext instance
         MLContext instance.
-        
+
     script: Script instance
         Script instance defined with the appropriate input and output variables.
-    
+
     lines: list of integers
         Optional: only display the hops that have begin and end line number equals to the given integers.
-    
+
     conf: SparkConf instance
         Optional spark configuration
-        
+
     apply_rewrites: boolean
         If True, perform static rewrites, perform intra-/inter-procedural analysis to propagate size information into functions and apply dynamic rewrites
-    
+
     with_subgraph: boolean
-        If False, the dot graph will be created without subgraphs for statement blocks. 
-    
+        If False, the dot graph will be created without subgraphs for statement blocks.
+
     Returns
     -------
     hopDAG: string
-        hop DAG in dot format 
+        hop DAG in dot format
     """
     if not isinstance(script, Script):
         raise ValueError("Expected script to be an instance of Script")
     scriptString = script.scriptString
     script_java = script.script_java
-    lines = [ int(x) for x in lines ] if lines is not None else [int(-1)]
+    lines = [int(x) for x in lines] if lines is not None else [int(-1)]
     sc = get_spark_context()
     if conf is not None:
-        hopDAG = sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG(ml._ml, script_java, lines, conf._jconf, apply_rewrites, with_subgraph)
+        hopDAG = sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG(
+            ml._ml, script_java, lines, conf._jconf, apply_rewrites, with_subgraph)
     else:
-        hopDAG = sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG(ml._ml, script_java, lines, apply_rewrites, with_subgraph)
+        hopDAG = sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG(
+            ml._ml, script_java, lines, apply_rewrites, with_subgraph)
     return hopDAG
 
+
 def dml(scriptString):
     """
     Create a dml script object based on a string.
@@ -99,9 +115,12 @@ def dml(scriptString):
         Instance of a script object.
     """
     if not isinstance(scriptString, str):
-        raise ValueError("scriptString should be a string, got %s" % 
type(scriptString))
+        raise ValueError(
+            "scriptString should be a string, got %s" %
+            type(scriptString))
     return Script(scriptString, scriptType="dml")
 
+
 def dmlFromResource(resourcePath):
     """
     Create a dml script object based on a resource path.
@@ -117,7 +136,9 @@ def dmlFromResource(resourcePath):
         Instance of a script object.
     """
     if not isinstance(resourcePath, str):
-        raise ValueError("resourcePath should be a string, got %s" % 
type(resourcePath))
+        raise ValueError(
+            "resourcePath should be a string, got %s" %
+            type(resourcePath))
     return Script(resourcePath, scriptType="dml", isResource=True)
 
 
@@ -136,9 +157,12 @@ def pydml(scriptString):
         Instance of a script object.
     """
     if not isinstance(scriptString, str):
-        raise ValueError("scriptString should be a string, got %s" % 
type(scriptString))
+        raise ValueError(
+            "scriptString should be a string, got %s" %
+            type(scriptString))
     return Script(scriptString, scriptType="pydml")
 
+
 def pydmlFromResource(resourcePath):
     """
     Create a pydml script object based on a resource path.
@@ -154,9 +178,12 @@ def pydmlFromResource(resourcePath):
         Instance of a script object.
     """
     if not isinstance(resourcePath, str):
-        raise ValueError("resourcePath should be a string, got %s" % 
type(resourcePath))
+        raise ValueError(
+            "resourcePath should be a string, got %s" %
+            type(resourcePath))
     return Script(resourcePath, scriptType="pydml", isResource=True)
 
+
 def dmlFromFile(filePath):
     """
     Create a dml script object based on a file path.
@@ -172,9 +199,13 @@ def dmlFromFile(filePath):
         Instance of a script object.
     """
     if not isinstance(filePath, str):
-        raise ValueError("filePath should be a string, got %s" % 
type(filePath))
-    return Script(filePath, scriptType="dml", isResource=False, 
scriptFormat="file")
-    
+        raise ValueError(
+            "filePath should be a string, got %s" %
+            type(filePath))
+    return Script(filePath, scriptType="dml",
+                  isResource=False, scriptFormat="file")
+
+
 def pydmlFromFile(filePath):
     """
     Create a pydml script object based on a file path.
@@ -190,9 +221,12 @@ def pydmlFromFile(filePath):
         Instance of a script object.
     """
     if not isinstance(filePath, str):
-        raise ValueError("filePath should be a string, got %s" % 
type(filePath))
-    return Script(filePath, scriptType="pydml", isResource=False, 
scriptFormat="file")
-    
+        raise ValueError(
+            "filePath should be a string, got %s" %
+            type(filePath))
+    return Script(filePath, scriptType="pydml",
+                  isResource=False, scriptFormat="file")
+
 
 def dmlFromUrl(url):
     """
@@ -212,6 +246,7 @@ def dmlFromUrl(url):
         raise ValueError("url should be a string, got %s" % type(url))
     return Script(url, scriptType="dml", isResource=False, scriptFormat="url")
 
+
 def pydmlFromUrl(url):
     """
     Create a pydml script object based on a url.
@@ -228,7 +263,9 @@ def pydmlFromUrl(url):
     """
     if not isinstance(url, str):
         raise ValueError("url should be a string, got %s" % type(url))
-    return Script(url, scriptType="pydml", isResource=False, scriptFormat="url")
+    return Script(url, scriptType="pydml",
+                  isResource=False, scriptFormat="url")
+
 
 def _java2py(sc, obj):
     """ Convert Java object to Python. """
@@ -265,6 +302,7 @@ class Matrix(object):
     sc: SparkContext
         SparkContext
     """
+
     def __init__(self, javaMatrix, sc):
         self._java_matrix = javaMatrix
         self._sc = sc
@@ -297,7 +335,8 @@ class Matrix(object):
         NumPy Array
             A NumPy Array representing the Matrix object.
         """
-        np_array = convertToNumPyArr(self._sc, self._java_matrix.toMatrixBlock())
+        np_array = convertToNumPyArr(
+            self._sc, self._java_matrix.toMatrixBlock())
         return np_array
 
 
@@ -313,6 +352,7 @@ class MLResults(object):
     sc: SparkContext
         SparkContext
     """
+
     def __init__(self, results, sc):
         self._java_results = results
         self._sc = sc
@@ -327,7 +367,8 @@ class MLResults(object):
         outputs: string, list of strings
             Output variables as defined inside the DML script.
         """
-        outs = [_java2py(self._sc, self._java_results.get(out)) for out in outputs]
+        outs = [_java2py(self._sc, self._java_results.get(out))
+                for out in outputs]
         if len(outs) == 1:
             return outs[0]
         return outs
@@ -347,70 +388,87 @@ class Script(object):
 
     isResource: boolean
         If true, scriptString is a path to a resource on the classpath
-    
+
     scriptFormat: string
         Optional script format, either "auto" or "url" or "file" or "resource" or "string"
     """
-    def __init__(self, scriptString, scriptType="dml", isResource=False, scriptFormat="auto"):
+
+    def __init__(self, scriptString, scriptType="dml",
+                 isResource=False, scriptFormat="auto"):
         self.sc = get_spark_context()
         self.scriptString = scriptString
         self.scriptType = scriptType
         self.isResource = isResource
         if scriptFormat != "auto":
             if scriptFormat == "url" and self.scriptType == "dml":
-                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromUrl(scriptString)
+                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromUrl(
+                    scriptString)
             elif scriptFormat == "url" and self.scriptType == "pydml":
-                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromUrl(scriptString)
+                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromUrl(
+                    scriptString)
             elif scriptFormat == "file" and self.scriptType == "dml":
-                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile(scriptString)
+                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile(
+                    scriptString)
             elif scriptFormat == "file" and self.scriptType == "pydml":
-                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile(scriptString)
+                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile(
+                    scriptString)
             elif isResource and self.scriptType == "dml":
-                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromResource(scriptString)
+                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromResource(
+                    scriptString)
             elif isResource and self.scriptType == "pydml":
-                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromResource(scriptString)
+                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromResource(
+                    scriptString)
             elif scriptFormat == "string" and self.scriptType == "dml":
-                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dml(scriptString)
+                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dml(
+                    scriptString)
             elif scriptFormat == "string" and self.scriptType == "pydml":
-                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydml(scriptString)
+                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydml(
+                    scriptString)
             else:
                 raise ValueError('Unsupported script format' + scriptFormat)
         elif self.scriptType == "dml":
             if scriptString.endswith(".dml"):
                 if scriptString.startswith("http"):
-                    self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromUrl(scriptString)
+                    self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromUrl(
+                        scriptString)
                 elif os.path.exists(scriptString):
-                    self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile(scriptString)
+                    self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile(
+                        scriptString)
                 elif self.isResource == True:
-                    self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromResource(scriptString)
+                    self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromResource(
+                        scriptString)
                 else:
                     raise ValueError("path: %s does not exist" % scriptString)
             else:
-                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dml(scriptString)
+                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dml(
+                    scriptString)
         elif self.scriptType == "pydml":
             if scriptString.endswith(".pydml"):
                 if scriptString.startswith("http"):
-                    self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromUrl(scriptString)
+                    self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromUrl(
+                        scriptString)
                 elif os.path.exists(scriptString):
-                    self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile(scriptString)
+                    self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile(
+                        scriptString)
                 elif self.isResource == True:
-                    self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromResource(scriptString)
+                    self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromResource(
+                        scriptString)
                 else:
                     raise ValueError("path: %s does not exist" % scriptString)
             else:
-                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydml(scriptString)
+                self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydml(
+                    scriptString)
 
-    
     def getScriptString(self):
         """
         Obtain the script string (in unicode).
         """
         return self.script_java.getScriptString()
-    
+
     def setScriptString(self, scriptString):
         """
         Set the script string.
-        
+
         Parameters
         ----------
         scriptString: string
@@ -438,80 +496,80 @@ class Script(object):
         """
         self.script_java.clearIOS()
         return self
-    
+
     def clearIO(self):
         """
         Clear the inputs and outputs, but not the symbol table.
         """
         self.script_java.clearIO()
         return self
-    
+
     def clearAll(self):
         """
         Clear the script string, inputs, outputs, and symbol table.
         """
         self.script_java.clearAll()
         return self
-    
+
     def clearInputs(self):
         """
         Clear the inputs.
         """
         self.script_java.clearInputs()
         return self
-    
+
     def clearOutputs(self):
         """
         Clear the outputs.
         """
         self.script_java.clearOutputs()
         return self
-    
+
     def clearSymbolTable(self):
         """
         Clear the symbol table.
         """
         self.script_java.clearSymbolTable()
         return self
-        
+
     def results(self):
         """
         Obtain the results of the script execution.
         """
         return MLResults(self.script_java.results(), self.sc)
-    
+
     def getResults(self):
         """
         Obtain the results of the script execution.
         """
         return MLResults(self.script_java.getResults(), self.sc)
-        
+
     def setResults(self, results):
         """
         Set the results of the script execution.
         """
         self.script_java.setResults(results._java_results)
         return self
-        
+
     def isDML(self):
         """
         Is the script type DML?
         """
         return self.script_java.isDML()
-    
+
     def isPYDML(self):
         """
         Is the script type DML?
         """
         return self.script_java.isPYDML()
-    
+
     def getScriptExecutionString(self):
         """
         Generate the script execution string, which adds read/load/write/save
         statements to the beginning and end of the script to execute.
         """
-        return self.script_java.getScriptExecutionString()    
-    
+        return self.script_java.getScriptExecutionString()
+
     def __repr__(self):
         return "Script"
 
@@ -528,56 +586,56 @@ class Script(object):
         Display the script inputs.
         """
         return self.script_java.displayInputs()
-    
+
     def displayOutputs(self):
         """
         Display the script outputs.
         """
         return self.script_java.displayOutputs()
-        
+
     def displayInputParameters(self):
         """
         Display the script input parameters.
         """
         return self.script_java.displayInputParameters()
-    
+
     def displayInputVariables(self):
         """
         Display the script input variables.
         """
         return self.script_java.displayInputVariables()
-        
+
     def displayOutputVariables(self):
         """
         Display the script output variables.
         """
         return self.script_java.displayOutputVariables()
-        
+
     def displaySymbolTable(self):
         """
         Display the script symbol table.
         """
         return self.script_java.displaySymbolTable()
-        
+
     def getName(self):
         """
         Obtain the script name.
         """
         return self.script_java.getName()
-        
+
     def setName(self, name):
         """
         Set the script name.
         """
         self.script_java.setName(name)
         return self
-        
+
     def getScriptType(self):
         """
         Obtain the script type.
         """
         return self.scriptType
-        
+
     def input(self, *args, **kwargs):
         """
         Parameters
@@ -606,9 +664,11 @@ class Script(object):
         if isinstance(val, py4j.java_gateway.JavaObject):
             py4j.java_gateway.get_method(self.script_java, "in")(key, val)
         else:
-            py4j.java_gateway.get_method(self.script_java, "in")(key, _py2java(self.sc, val))
-    
-    
+            py4j.java_gateway.get_method(
+                self.script_java, "in")(
+                key, _py2java(
+                    self.sc, val))
+
     def output(self, *names):
         """
         Parameters
@@ -630,17 +690,20 @@ class MLContext(object):
     sc: SparkContext or SparkSession
         An instance of pyspark.SparkContext or pyspark.sql.SparkSession.
     """
+
     def __init__(self, sc):
         if isinstance(sc, pyspark.sql.session.SparkSession):
             sc = sc._sc
         elif not isinstance(sc, SparkContext):
-            raise ValueError("Expected sc to be a SparkContext or 
SparkSession, got " % str(type(sc)))
+            raise ValueError(
+                "Expected sc to be a SparkContext or SparkSession, got " % str(
+                    type(sc)))
         self._sc = sc
         self._ml = createJavaObject(sc, 'mlcontext')
 
     def __repr__(self):
         return "MLContext"
-        
+
     def execute(self, script):
         """
         Execute a DML / PyDML script.
@@ -688,7 +751,7 @@ class MLContext(object):
         """
         self._ml.setGPU(bool(enable))
         return self
-    
+
     def setForceGPU(self, enable):
         """
         Whether or not to force the usage of GPU operators.
@@ -699,7 +762,7 @@ class MLContext(object):
         """
         self._ml.setForceGPU(bool(enable))
         return self
-        
+
     def setStatisticsMaxHeavyHitters(self, maxHeavyHitters):
         """
         The maximum number of heavy hitters that are printed as part of the statistics.
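
For reference, a minimal usage sketch of the mlcontext API refactored above (not part of the patch; the Spark session setup, the inline DML string, and the variable names are illustrative):

    from pyspark.sql import SparkSession
    from systemml.mlcontext import MLContext, dml, getHopDAG

    # Assumes PySpark is running with SystemML available on the classpath.
    spark = SparkSession.builder.getOrCreate()
    ml = MLContext(spark)

    # Build a Script object from an inline DML string and register an output.
    script = dml("s = 'Hello World'").output("s")
    results = ml.execute(script)
    print(results.get("s"))

    # Obtain the compiled HOP DAG in dot format for the same script.
    print(getHopDAG(ml, script))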

http://git-wip-us.apache.org/repos/asf/systemml/blob/9e7ee19a/src/main/python/systemml/mllearn/__init__.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/__init__.py b/src/main/python/systemml/mllearn/__init__.py
index 907ffb2..eda2197 100644
--- a/src/main/python/systemml/mllearn/__init__.py
+++ b/src/main/python/systemml/mllearn/__init__.py
@@ -1,4 +1,4 @@
-#-------------------------------------------------------------
+# -------------------------------------------------------------
 #
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
@@ -17,7 +17,7 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-#-------------------------------------------------------------
+# -------------------------------------------------------------
 
 """
 ===================

http://git-wip-us.apache.org/repos/asf/systemml/blob/9e7ee19a/src/main/python/systemml/mllearn/estimators.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py
index 52fe9cc..0b3de41 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -1,4 +1,4 @@
-#-------------------------------------------------------------
+# -------------------------------------------------------------
 #
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
@@ -17,9 +17,15 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-#-------------------------------------------------------------
+# -------------------------------------------------------------
 
-__all__ = ['LinearRegression', 'LogisticRegression', 'SVM', 'NaiveBayes', 'Caffe2DML', 'Keras2DML']
+__all__ = [
+    'LinearRegression',
+    'LogisticRegression',
+    'SVM',
+    'NaiveBayes',
+    'Caffe2DML',
+    'Keras2DML']
 
 import numpy as np
 from pyspark.ml import Estimator
@@ -37,15 +43,17 @@ import math
 from ..converters import *
 from ..classloader import *
 
+
 def assemble(sparkSession, pdf, inputCols, outputCol):
     tmpDF = sparkSession.createDataFrame(pdf, list(pdf.columns))
     assembler = VectorAssembler(inputCols=list(inputCols), outputCol=outputCol)
     return assembler.transform(tmpDF)
 
+
 class BaseSystemMLEstimator(Estimator):
     features_col = 'features'
     label_col = 'label'
-    
+
     def set_features_col(self, colName):
         """
         Sets the default column name for features of PySpark DataFrame.
@@ -76,7 +84,7 @@ class BaseSystemMLEstimator(Estimator):
         """
         self.estimator.setGPU(enable)
         return self
-    
+
     def setForceGPU(self, enable):
         """
         Whether or not to force the usage of GPU operators.
@@ -87,7 +95,7 @@ class BaseSystemMLEstimator(Estimator):
         """
         self.estimator.setForceGPU(enable)
         return self
-        
+
     def setExplain(self, explain):
         """
         Explanation about the program. Mainly intended for developers.
@@ -98,11 +106,11 @@ class BaseSystemMLEstimator(Estimator):
         """
         self.estimator.setExplain(explain)
         return self
-    
+
     def setExplainLevel(self, explainLevel):
         """
         Set explain level. Mainly intended for developers.
-        
+
         Parameters
         ----------
         explainLevel: string
@@ -111,7 +119,7 @@ class BaseSystemMLEstimator(Estimator):
         """
         self.estimator.setExplainLevel(explainLevel)
         return self
-            
+
     def setStatistics(self, statistics):
         """
         Whether or not to output statistics (such as execution time, elapsed time)
@@ -123,7 +131,7 @@ class BaseSystemMLEstimator(Estimator):
         """
         self.estimator.setStatistics(statistics)
         return self
-    
+
     def setStatisticsMaxHeavyHitters(self, maxHeavyHitters):
         """
         The maximum number of heavy hitters that are printed as part of the statistics.
@@ -134,7 +142,7 @@ class BaseSystemMLEstimator(Estimator):
         """
         self.estimator.setStatisticsMaxHeavyHitters(maxHeavyHitters)
         return self
-        
+
     def setConfigProperty(self, propertyName, propertyValue):
         """
         Set configuration property, such as setConfigProperty("sysml.localtmpdir", "/tmp/systemml").
@@ -146,7 +154,7 @@ class BaseSystemMLEstimator(Estimator):
         """
         self.estimator.setConfigProperty(propertyName, propertyValue)
         return self
-    
+
     def _fit_df(self):
         global default_jvm_stdout, default_jvm_stdout_parallel_flush
         try:
@@ -157,28 +165,30 @@ class BaseSystemMLEstimator(Estimator):
                 self.model = self.estimator.fit(self.X._jdf)
         except Py4JError:
             traceback.print_exc()
-    
+
     def fit_df(self, X):
         self.X = X
         self._fit_df()
         self.X = None
         return self
-    
+
     def _fit_numpy(self):
         global default_jvm_stdout, default_jvm_stdout_parallel_flush
         try:
-            if type(self.y) == np.ndarray and len(self.y.shape) == 1:
+            if isinstance(self.y, np.ndarray) and len(self.y.shape) == 1:
                 # Since we know that mllearn always needs a column vector
                 self.y = np.matrix(self.y).T
             y_mb = convertToMatrixBlock(self.sc, self.y)
             if default_jvm_stdout:
                 with jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush):
-                    self.model = self.estimator.fit(convertToMatrixBlock(self.sc, self.X), y_mb)
+                    self.model = self.estimator.fit(
+                        convertToMatrixBlock(self.sc, self.X), y_mb)
             else:
-                self.model = self.estimator.fit(convertToMatrixBlock(self.sc, self.X), y_mb)
+                self.model = self.estimator.fit(
+                    convertToMatrixBlock(self.sc, self.X), y_mb)
         except Py4JError:
             traceback.print_exc()
-                    
+
     def fit_numpy(self, X, y):
         self.X = X
         self.y = y
@@ -198,7 +208,7 @@ class BaseSystemMLEstimator(Estimator):
         except Py4JError:
             traceback.print_exc()
         return self
-                
+
     # Returns a model after calling fit(df) on Estimator object on JVM
     def _fit(self, X):
         """
@@ -208,10 +218,12 @@ class BaseSystemMLEstimator(Estimator):
         ----------
         X: PySpark DataFrame that contain the columns features_col (default: 'features') and label_col (default: 'label')
         """
-        if hasattr(X, '_jdf') and self.features_col in X.columns and self.label_col in X.columns:
+        if hasattr(
+                X, '_jdf') and self.features_col in X.columns and self.label_col in X.columns:
             return self.fit_df(X)
         else:
-            raise Exception('Incorrect usage: Expected dataframe as input with features/label as columns')
+            raise Exception(
+                'Incorrect usage: Expected dataframe as input with features/label as columns')
 
     def fit(self, X, y=None, params=None):
         """
@@ -228,7 +240,8 @@ class BaseSystemMLEstimator(Estimator):
             return self.fit_file(X, y)
         elif isinstance(X, SUPPORTED_TYPES) and isinstance(y, SUPPORTED_TYPES):
             # Donot encode if y is a numpy matrix => useful for segmentation
-            skipEncodingY = len(y.shape) == 2 and y.shape[0] != 1 and y.shape[1] != 1
+            skipEncodingY = len(
+                y.shape) == 2 and y.shape[0] != 1 and y.shape[1] != 1
             y = y if skipEncodingY else self.encode(y)
             if self.transferUsingDF:
                 pdfX = convertToPandasDF(X)
@@ -239,7 +252,13 @@ class BaseSystemMLEstimator(Estimator):
                     raise Exception('Number of rows of X and y should match')
                 colNames = pdfX.columns
                 pdfX[self.label_col] = pdfY[pdfY.columns[0]]
-                df = assemble(self.sparkSession, pdfX, colNames, self.features_col).select(self.features_col, self.label_col)
+                df = assemble(
+                    self.sparkSession,
+                    pdfX,
+                    colNames,
+                    self.features_col).select(
+                    self.features_col,
+                    self.label_col)
                 self.fit_df(df)
             else:
                 numColsy = getNumCols(y)
@@ -254,7 +273,7 @@ class BaseSystemMLEstimator(Estimator):
 
     def transform(self, X):
         return self.predict(X)
-    
+
     def _convertPythonXToJavaObject(self, X):
         """
         Converts the input python object X to a java-side object (either MatrixBlock or Java DataFrame)
@@ -265,7 +284,12 @@ class BaseSystemMLEstimator(Estimator):
         """
         if isinstance(X, SUPPORTED_TYPES) and self.transferUsingDF:
             pdfX = convertToPandasDF(X)
-            df = assemble(self.sparkSession, pdfX, pdfX.columns, self.features_col).select(self.features_col)
+            df = assemble(
+                self.sparkSession,
+                pdfX,
+                pdfX.columns,
+                self.features_col).select(
+                self.features_col)
             return df._jdf
         elif isinstance(X, SUPPORTED_TYPES):
             return convertToMatrixBlock(self.sc, X)
@@ -273,12 +297,13 @@ class BaseSystemMLEstimator(Estimator):
             # No need to assemble as input DF is likely coming via MLPipeline
             return X._jdf
         elif hasattr(X, '_jdf'):
-            assembler = VectorAssembler(inputCols=X.columns, outputCol=self.features_col)
+            assembler = VectorAssembler(
+                inputCols=X.columns, outputCol=self.features_col)
             df = assembler.transform(X)
             return df._jdf
         else:
             raise Exception('Unsupported input type')
-        
+
     def _convertJavaOutputToPythonObject(self, X, output):
         """
         Converts the a java-side object output (either MatrixBlock or Java DataFrame) to a python object (based on the type of X).
@@ -300,7 +325,7 @@ class BaseSystemMLEstimator(Estimator):
             return retDF.sort('__INDEX')
         else:
             raise Exception('Unsupported input type')
-        
+
     def predict_proba(self, X):
         """
         Invokes the transform_probability method on Estimator object on JVM if X and y are on of the supported data types
@@ -314,7 +339,8 @@ class BaseSystemMLEstimator(Estimator):
         if hasattr(X, '_jdf'):
             return self.predict(X)
         elif self.transferUsingDF:
-            raise ValueError('The parameter transferUsingDF is not valid for the method predict_proba')
+            raise ValueError(
+                'The parameter transferUsingDF is not valid for the method predict_proba')
         try:
             if self.estimator is not None and self.model is not None:
                 self.estimator.copyProperties(self.model)
@@ -326,13 +352,16 @@ class BaseSystemMLEstimator(Estimator):
             jX = self._convertPythonXToJavaObject(X)
             if default_jvm_stdout:
                 with jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush):
-                    return self._convertJavaOutputToPythonObject(X, self.model.transform_probability(jX))
+                    return self._convertJavaOutputToPythonObject(
+                        X, self.model.transform_probability(jX))
             else:
-                return self._convertJavaOutputToPythonObject(X, self.model.transform_probability(jX))
+                return self._convertJavaOutputToPythonObject(
+                    X, self.model.transform_probability(jX))
         except Py4JError:
             traceback.print_exc()
-    
-    # Returns either a DataFrame or MatrixBlock after calling transform(X:MatrixBlock, y:MatrixBlock) on Model object on JVM
+
+    # Returns either a DataFrame or MatrixBlock after calling
+    # transform(X:MatrixBlock, y:MatrixBlock) on Model object on JVM
     def predict(self, X):
         """
         Invokes the transform method on Estimator object on JVM if X and y are on of the supported data types
@@ -353,20 +382,23 @@ class BaseSystemMLEstimator(Estimator):
             jX = self._convertPythonXToJavaObject(X)
             if default_jvm_stdout:
                 with jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush):
-                    ret = self._convertJavaOutputToPythonObject(X, self.model.transform(jX))
+                    ret = self._convertJavaOutputToPythonObject(
+                        X, self.model.transform(jX))
             else:
-                ret = self._convertJavaOutputToPythonObject(X, self.model.transform(jX))
+                ret = self._convertJavaOutputToPythonObject(
+                    X, self.model.transform(jX))
             return self.decode(ret) if isinstance(X, SUPPORTED_TYPES) else ret
         except Py4JError:
             traceback.print_exc()
 
+
 class BaseSystemMLClassifier(BaseSystemMLEstimator):
 
     def encode(self, y):
         self.le = LabelEncoder()
         self.le.fit(y)
         return self.le.transform(y) + 1
-        
+
     def decode(self, y):
         if not hasattr(self, 'le'):
             self.le = None
@@ -375,14 +407,14 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator):
         if self.le is not None:
             return self.le.inverse_transform(np.asarray(y - 1, dtype=int))
         elif self.labelMap is not None:
-            return [ self.labelMap[int(i)] for i in y ]
+            return [self.labelMap[int(i)] for i in y]
         else:
             return y
-        
+
     def predict(self, X):
         predictions = super(BaseSystemMLClassifier, self).predict(X)
         from pyspark.sql.dataframe import DataFrame as df
-        if type(predictions) == df:
+        if isinstance(predictions, df):
             return predictions
         else:
             try:
@@ -390,7 +422,7 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator):
             except ValueError:
                 print(type(predictions))
                 return np.asarray(predictions, dtype='str')
-            
+
     def score(self, X, y):
         """
         Scores the predicted value with ground truth 'y'
@@ -404,8 +436,9 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator):
         if np.issubdtype(predictions.dtype.type, np.number):
             return accuracy_score(y, predictions)
         else:
-            return accuracy_score(np.asarray(y, dtype='str'), np.asarray(predictions, dtype='str'))
-            
+            return accuracy_score(np.asarray(
+                y, dtype='str'), np.asarray(predictions, dtype='str'))
+
     def loadLabels(self, file_path):
         createJavaObject(self.sc, 'dummy')
         utilObj = self.sc._jvm.org.apache.sysml.api.ml.Utils()
@@ -417,10 +450,10 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator):
             for i in range(len(keys)):
                 self.labelMap[int(keys[i])] = values[i]
             # self.encode(classes) # Giving incorrect results
-        
+
     def load(self, weights, sep='/', eager=False):
         """
-        Load a pretrained model. 
+        Load a pretrained model.
 
         Parameters
         ----------
@@ -436,11 +469,11 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator):
         else:
             self.model.load(self.sc._jsc, weights, sep, eager)
         self.loadLabels(weights + '/labels.txt')
-        
+
     def save(self, outputDir, format='binary', sep='/'):
         """
         Save a trained model.
-        
+
         Parameters
         ----------
         outputDir: Directory to save the model to
@@ -448,7 +481,7 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator):
         sep: seperator to use (default: '/')
         """
         global default_jvm_stdout, default_jvm_stdout_parallel_flush
-        if self.model != None:
+        if self.model is not None:
             if default_jvm_stdout:
                 with jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush):
                     self.model.save(self.sc._jsc, outputDir, format, sep)
@@ -462,21 +495,26 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator):
                 labelMapping = self.labelMap
 
             if labelMapping is not None:
-                lStr = [ [ int(k), str(labelMapping[k]) ] for k in labelMapping ]
+                lStr = [[int(k), str(labelMapping[k])] for k in labelMapping]
                 df = self.sparkSession.createDataFrame(lStr)
-                df.write.csv(outputDir + sep + 'labels.txt', mode='overwrite', header=False)
+                df.write.csv(
+                    outputDir + sep + 'labels.txt',
+                    mode='overwrite',
+                    header=False)
         else:
-            raise Exception('Cannot save as you need to train the model first using fit')
+            raise Exception(
+                'Cannot save as you need to train the model first using fit')
         return self
 
+
 class BaseSystemMLRegressor(BaseSystemMLEstimator):
 
     def encode(self, y):
         return y
-        
+
     def decode(self, y):
         return y
-    
+
     def score(self, X, y):
         """
         Scores the predicted value with ground truth 'y'
@@ -487,10 +525,10 @@ class BaseSystemMLRegressor(BaseSystemMLEstimator):
         y: NumPy ndarray, Pandas DataFrame, scipy sparse matrix
         """
         return r2_score(y, self.predict(X), multioutput='variance_weighted')
-        
+
     def load(self, weights=None, sep='/', eager=False):
         """
-        Load a pretrained model. 
+        Load a pretrained model.
 
         Parameters
         ----------
@@ -509,7 +547,7 @@ class BaseSystemMLRegressor(BaseSystemMLEstimator):
     def save(self, outputDir, format='binary', sep='/'):
         """
         Save a trained model.
-        
+
         Parameters
         ----------
         outputDir: Directory to save the model to
@@ -517,14 +555,15 @@ class BaseSystemMLRegressor(BaseSystemMLEstimator):
         sep: seperator to use (default: '/')
         """
         global default_jvm_stdout, default_jvm_stdout_parallel_flush
-        if self.model != None:
+        if self.model is not None:
             if default_jvm_stdout:
                 with jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush):
                     self.model.save(outputDir, format, sep)
             else:
                 self.model.save(outputDir, format, sep)
         else:
-            raise Exception('Cannot save as you need to train the model first using fit')
+            raise Exception(
+                'Cannot save as you need to train the model first using fit')
         return self
 
 
@@ -534,9 +573,9 @@ class LogisticRegression(BaseSystemMLClassifier):
 
     Examples
     --------
-    
+
     Scikit-learn way
-    
+
     >>> from sklearn import datasets, neighbors
     >>> from systemml.mllearn import LogisticRegression
     >>> from pyspark.sql import SparkSession
@@ -551,9 +590,9 @@ class LogisticRegression(BaseSystemMLClassifier):
     >>> y_test = y_digits[.9 * n_samples:]
     >>> logistic = LogisticRegression(sparkSession)
     >>> print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test))
-    
+
     MLPipeline way
-    
+
     >>> from pyspark.ml import Pipeline
     >>> from systemml.mllearn import LogisticRegression
     >>> from pyspark.ml.feature import HashingTF, Tokenizer
@@ -585,13 +624,14 @@ class LogisticRegression(BaseSystemMLClassifier):
     >>>     (15L, "apache hadoop")], ["id", "text"])
     >>> prediction = model.transform(test)
     >>> prediction.show()
-    
+
     """
-    
-    def __init__(self, sparkSession, penalty='l2', fit_intercept=True, normalize=False,  max_iter=100, max_inner_iter=0, tol=0.000001, C=1.0, solver='newton-cg', transferUsingDF=False):
+
+    def __init__(self, sparkSession, penalty='l2', fit_intercept=True, normalize=False, max_iter=100,
+                 max_inner_iter=0, tol=0.000001, C=1.0, solver='newton-cg', transferUsingDF=False):
         """
         Performs both binomial and multinomial logistic regression.
-        
+
         Parameters
         ----------
         sparkSession: PySpark SparkSession
@@ -608,30 +648,33 @@ class LogisticRegression(BaseSystemMLClassifier):
         self.sc = sparkSession._sc
         createJavaObject(self.sc, 'dummy')
         self.uid = "logReg"
-        self.estimator = self.sc._jvm.org.apache.sysml.api.ml.LogisticRegression(self.uid, self.sc._jsc.sc())
+        self.estimator = self.sc._jvm.org.apache.sysml.api.ml.LogisticRegression(
+            self.uid, self.sc._jsc.sc())
         self.estimator.setMaxOuterIter(max_iter)
         self.estimator.setMaxInnerIter(max_inner_iter)
         reg = 0.0 if C == float("inf") else 1.0 / C
-        icpt = 2 if fit_intercept == True and normalize == True else int(fit_intercept)
+        icpt = 2 if fit_intercept == True and normalize == True else int(
+            fit_intercept)
         self.estimator.setRegParam(reg)
         self.estimator.setTol(tol)
         self.estimator.setIcpt(icpt)
         self.transferUsingDF = transferUsingDF
         self.setOutputRawPredictionsToFalse = True
-        self.model = self.sc._jvm.org.apache.sysml.api.ml.LogisticRegressionModel(self.estimator)
+        self.model = self.sc._jvm.org.apache.sysml.api.ml.LogisticRegressionModel(
+            self.estimator)
         if penalty != 'l2':
             raise Exception('Only l2 penalty is supported')
         if solver != 'newton-cg':
             raise Exception('Only newton-cg solver supported')
-        
+
 
 class LinearRegression(BaseSystemMLRegressor):
     """
     Performs linear regression to model the relationship between one numerical response variable and one or more explanatory (feature) variables.
-    
+
     Examples
     --------
-    
+
     >>> import numpy as np
     >>> from sklearn import datasets
     >>> from systemml.mllearn import LinearRegression
@@ -652,11 +695,11 @@ class LinearRegression(BaseSystemMLRegressor):
     >>> regr.fit(diabetes_X_train, diabetes_y_train)
     >>> # The mean square error
     >>> print("Residual sum of squares: %.2f" % 
np.mean((regr.predict(diabetes_X_test) - diabetes_y_test) ** 2))
-    
+
     """
-    
-    
-    def __init__(self, sparkSession, fit_intercept=True, normalize=False, max_iter=100, tol=0.000001, C=float("inf"), solver='newton-cg', transferUsingDF=False):
+
+    def __init__(self, sparkSession, fit_intercept=True, normalize=False, max_iter=100,
+                 tol=0.000001, C=float("inf"), solver='newton-cg', transferUsingDF=False):
         """
         Performs linear regression to model the relationship between one numerical response variable and one or more explanatory (feature) variables.
 
@@ -678,18 +721,21 @@ class LinearRegression(BaseSystemMLRegressor):
         createJavaObject(self.sc, 'dummy')
         self.uid = "lr"
         if solver == 'newton-cg' or solver == 'direct-solve':
-            self.estimator = self.sc._jvm.org.apache.sysml.api.ml.LinearRegression(self.uid, self.sc._jsc.sc(), solver)
+            self.estimator = self.sc._jvm.org.apache.sysml.api.ml.LinearRegression(
+                self.uid, self.sc._jsc.sc(), solver)
         else:
             raise Exception('Only newton-cg solver supported')
         self.estimator.setMaxIter(max_iter)
         reg = 0.0 if C == float("inf") else 1.0 / C
-        icpt = 2 if fit_intercept == True and normalize == True else int(fit_intercept)
+        icpt = 2 if fit_intercept == True and normalize == True else int(
+            fit_intercept)
         self.estimator.setRegParam(reg)
         self.estimator.setTol(tol)
         self.estimator.setIcpt(icpt)
         self.transferUsingDF = transferUsingDF
         self.setOutputRawPredictionsToFalse = False
-        self.model = self.sc._jvm.org.apache.sysml.api.ml.LinearRegressionModel(self.estimator)
+        self.model = self.sc._jvm.org.apache.sysml.api.ml.LinearRegressionModel(
+            self.estimator)
 
 
 class SVM(BaseSystemMLClassifier):
@@ -698,14 +744,14 @@ class SVM(BaseSystemMLClassifier):
 
     Examples
     --------
-    
+
     >>> from sklearn import datasets, neighbors
     >>> from systemml.mllearn import SVM
     >>> from pyspark.sql import SparkSession
     >>> sparkSession = SparkSession.builder.getOrCreate()
     >>> digits = datasets.load_digits()
     >>> X_digits = digits.data
-    >>> y_digits = digits.target 
+    >>> y_digits = digits.target
     >>> n_samples = len(X_digits)
     >>> X_train = X_digits[:.9 * n_samples]
     >>> y_train = y_digits[:.9 * n_samples]
@@ -713,11 +759,11 @@ class SVM(BaseSystemMLClassifier):
     >>> y_test = y_digits[.9 * n_samples:]
     >>> svm = SVM(sparkSession, is_multi_class=True)
     >>> print('LogisticRegression score: %f' % svm.fit(X_train, y_train).score(X_test, y_test))
-     
-    """
 
+    """
 
-    def __init__(self, sparkSession, fit_intercept=True, normalize=False, max_iter=100, tol=0.000001, C=1.0, is_multi_class=False, transferUsingDF=False):
+    def __init__(self, sparkSession, fit_intercept=True, normalize=False, max_iter=100,
+                 tol=0.000001, C=1.0, is_multi_class=False, transferUsingDF=False):
         """
         Performs both binary-class and multiclass SVM (Support Vector Machines).
 
@@ -736,18 +782,22 @@ class SVM(BaseSystemMLClassifier):
         self.uid = "svm"
         createJavaObject(self.sc, 'dummy')
         self.is_multi_class = is_multi_class
-        self.estimator = self.sc._jvm.org.apache.sysml.api.ml.SVM(self.uid, self.sc._jsc.sc(), is_multi_class)
+        self.estimator = self.sc._jvm.org.apache.sysml.api.ml.SVM(
+            self.uid, self.sc._jsc.sc(), is_multi_class)
         self.estimator.setMaxIter(max_iter)
         if C <= 0:
             raise Exception('C has to be positive')
         reg = 0.0 if C == float("inf") else 1.0 / C
-        icpt = 2 if fit_intercept == True and normalize == True else int(fit_intercept)
+        icpt = 2 if fit_intercept == True and normalize == True else int(
+            fit_intercept)
         self.estimator.setRegParam(reg)
         self.estimator.setTol(tol)
         self.estimator.setIcpt(icpt)
         self.transferUsingDF = transferUsingDF
         self.setOutputRawPredictionsToFalse = False
-        self.model = self.sc._jvm.org.apache.sysml.api.ml.SVMModel(self.estimator, self.is_multi_class)
+        self.model = self.sc._jvm.org.apache.sysml.api.ml.SVMModel(
+            self.estimator, self.is_multi_class)
+
 
 class NaiveBayes(BaseSystemMLClassifier):
     """
@@ -755,7 +805,7 @@ class NaiveBayes(BaseSystemMLClassifier):
 
     Examples
     --------
-    
+
     >>> from sklearn.datasets import fetch_20newsgroups
     >>> from sklearn.feature_extraction.text import TfidfVectorizer
     >>> from systemml.mllearn import NaiveBayes
@@ -775,7 +825,7 @@ class NaiveBayes(BaseSystemMLClassifier):
     >>> metrics.f1_score(newsgroups_test.target, pred, average='weighted')
 
     """
-    
+
     def __init__(self, sparkSession, laplace=1.0, transferUsingDF=False):
         """
         Performs Naive Bayes.
@@ -789,19 +839,22 @@ class NaiveBayes(BaseSystemMLClassifier):
         self.sc = sparkSession._sc
         self.uid = "nb"
         createJavaObject(self.sc, 'dummy')
-        self.estimator = self.sc._jvm.org.apache.sysml.api.ml.NaiveBayes(self.uid, self.sc._jsc.sc())
+        self.estimator = self.sc._jvm.org.apache.sysml.api.ml.NaiveBayes(
+            self.uid, self.sc._jsc.sc())
         self.estimator.setLaplace(laplace)
         self.transferUsingDF = transferUsingDF
         self.setOutputRawPredictionsToFalse = False
-        self.model = self.sc._jvm.org.apache.sysml.api.ml.NaiveBayesModel(self.estimator)
+        self.model = self.sc._jvm.org.apache.sysml.api.ml.NaiveBayesModel(
+            self.estimator)
+
 
 class Caffe2DML(BaseSystemMLClassifier):
     """
     Performs training/prediction for a given caffe network.
-    
+
     Examples
     --------
-    
+
     >>> from systemml.mllearn import Caffe2DML
     >>> from mlxtend.data import mnist_data
     >>> import numpy as np
@@ -815,9 +868,11 @@ class Caffe2DML(BaseSystemMLClassifier):
     >>> caffe2DML = Caffe2DML(spark, 'lenet_solver.proto').set(max_iter=500)
     >>> caffe2DML.fit(X, y)
     """
-    def __init__(self, sparkSession, solver, input_shape, transferUsingDF=False):
+
+    def __init__(self, sparkSession, solver, input_shape,
+                 transferUsingDF=False):
         """
-        Performs training/prediction for a given caffe network. 
+        Performs training/prediction for a given caffe network.
 
         Parameters
         ----------
@@ -833,14 +888,19 @@ class Caffe2DML(BaseSystemMLClassifier):
         self.model = None
         if len(input_shape) != 3:
             raise ValueError('Expected input_shape as list of 3 element')
-        solver = self.sc._jvm.org.apache.sysml.api.dl.Utils.readCaffeSolver(solver)
-        self.estimator = self.sc._jvm.org.apache.sysml.api.dl.Caffe2DML(self.sc._jsc.sc(), solver, str(input_shape[0]), str(input_shape[1]), str(input_shape[2]))
+        solver = self.sc._jvm.org.apache.sysml.api.dl.Utils.readCaffeSolver(
+            solver)
+        self.estimator = self.sc._jvm.org.apache.sysml.api.dl.Caffe2DML(
+            self.sc._jsc.sc(), solver, str(
+                input_shape[0]), str(
+                input_shape[1]), str(
+                input_shape[2]))
         self.transferUsingDF = transferUsingDF
         self.setOutputRawPredictionsToFalse = False
 
     def load(self, weights=None, sep='/', ignore_weights=None, eager=False):
         """
-        Load a pretrained model. 
+        Load a pretrained model.
 
         Parameters
         ----------
@@ -852,7 +912,8 @@ class Caffe2DML(BaseSystemMLClassifier):
         global default_jvm_stdout, default_jvm_stdout_parallel_flush
         self.weights = weights
         self.estimator.setInput("$weights", str(weights))
-        self.model = self.sc._jvm.org.apache.sysml.api.dl.Caffe2DMLModel(self.estimator)
+        self.model = self.sc._jvm.org.apache.sysml.api.dl.Caffe2DMLModel(
+            self.estimator)
         if default_jvm_stdout:
             with jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush):
                 self.model.load(self.sc._jsc, weights, sep, eager)
@@ -861,11 +922,12 @@ class Caffe2DML(BaseSystemMLClassifier):
         self.loadLabels(weights + '/labels.txt')
         if ignore_weights is not None:
             self.estimator.setWeightsToIgnore(ignore_weights)
-            
-    def set(self, debug=None, train_algo=None, test_algo=None, parallel_batches=None, output_activations=None, perform_one_hot_encoding=None, parfor_parameters=None):
+
+    def set(self, debug=None, train_algo=None, test_algo=None, parallel_batches=None,
+            output_activations=None, perform_one_hot_encoding=None, parfor_parameters=None):
         """
         Set input to Caffe2DML
-        
+
         Parameters
         ----------
         debug: to add debugging DML code such as classification report, print DML script, etc (default: False)
@@ -876,37 +938,55 @@ class Caffe2DML(BaseSystemMLClassifier):
         perform_one_hot_encoding: should perform one-hot encoding in DML using table function (default: False)
         parfor_parameters: dictionary for parfor parameters when using allreduce-style algorithms (default: "")
         """
-        if debug is not None: self.estimator.setInput("$debug", str(debug).upper())
-        if train_algo is not None: self.estimator.setInput("$train_algo", str(train_algo).lower())
-        if test_algo is not None: self.estimator.setInput("$test_algo", str(test_algo).lower())
-        if parallel_batches is not None: self.estimator.setInput("$parallel_batches", str(parallel_batches))
-        if output_activations is not None: self.estimator.setInput("$output_activations", str(output_activations))
-        if perform_one_hot_encoding is not None: self.estimator.setInput("$perform_one_hot_encoding", str(perform_one_hot_encoding).lower())
+        if debug is not None:
+            self.estimator.setInput("$debug", str(debug).upper())
+        if train_algo is not None:
+            self.estimator.setInput("$train_algo", str(train_algo).lower())
+        if test_algo is not None:
+            self.estimator.setInput("$test_algo", str(test_algo).lower())
+        if parallel_batches is not None:
+            self.estimator.setInput("$parallel_batches", str(parallel_batches))
+        if output_activations is not None:
+            self.estimator.setInput(
+                "$output_activations",
+                str(output_activations))
+        if perform_one_hot_encoding is not None:
+            self.estimator.setInput(
+                "$perform_one_hot_encoding",
+                str(perform_one_hot_encoding).lower())
         if parfor_parameters is not None:
             if isinstance(parfor_parameters, dict):
                 # Convert dictionary to comma-separated list
-                parfor_parameters = ''.join([ ', ' + str(k) + '=' + str(v) for k, v in parfor_parameters.items()]) if len(parfor_parameters) > 0 else ''
-                self.estimator.setInput("$parfor_parameters", parfor_parameters)
+                parfor_parameters = ''.join(
+                    [
+                        ', ' +
+                        str(k) +
+                        '=' +
+                        str(v) for k,
+                        v in parfor_parameters.items()]) if len(parfor_parameters) > 0 else ''
+                self.estimator.setInput(
+                    "$parfor_parameters", parfor_parameters)
             else:
-                raise TypeError("parfor_parameters should be a dictionary") 
+                raise TypeError("parfor_parameters should be a dictionary")
         return self
-    
+
     def summary(self):
         """
         Print the summary of the network
         """
         import pyspark
         global default_jvm_stdout, default_jvm_stdout_parallel_flush
-        if type(self.sparkSession) == pyspark.sql.session.SparkSession:
+        if isinstance(self.sparkSession, pyspark.sql.session.SparkSession):
             if default_jvm_stdout:
                 with jvm_stdout(parallel_flush=default_jvm_stdout_parallel_flush):
                     self.estimator.summary(self.sparkSession._jsparkSession)
             else:
                 self.estimator.summary(self.sparkSession._jsparkSession)
         else:
-            raise TypeError('Please use spark session of type pyspark.sql.session.SparkSession in the constructor')
-    
-    
+            raise TypeError(
+                'Please use spark session of type pyspark.sql.session.SparkSession in the constructor')
+
+
 class Keras2DML(Caffe2DML):
     """
     Peforms training/prediction for a given keras model.
@@ -914,7 +994,8 @@ class Keras2DML(Caffe2DML):
 
     """
 
-    def __init__(self, sparkSession, keras_model, input_shape, transferUsingDF=False, load_keras_weights=True, weights=None, labels=None, batch_size=64, max_iter=2000, test_iter=10, test_interval=500, display=100, lr_policy="step", weight_decay=5e-4, regularization_type="L2"):
+    def __init__(self, sparkSession, keras_model, input_shape, transferUsingDF=False, load_keras_weights=True, weights=None, labels=None,
+                 batch_size=64, max_iter=2000, test_iter=10, test_interval=500, display=100, lr_policy="step", weight_decay=5e-4, regularization_type="L2"):
         """
         Performs training/prediction for a given keras model.
 
@@ -936,9 +1017,9 @@ class Keras2DML(Caffe2DML):
         weight_decay: regularation strength (default: 5e-4)
         regularization_type: regularization type (default: "L2")
         """
-        from .keras2caffe import *
+        from .keras2caffe import convertKerasToCaffeNetwork, convertKerasToCaffeSolver
         import tempfile
-        if type(keras_model) == keras.models.Sequential:
+        if isinstance(keras_model, keras.models.Sequential):
             # Convert the sequential model to functional model
             if keras_model.model is None:
                 keras_model.build()
@@ -946,19 +1027,47 @@ class Keras2DML(Caffe2DML):
         self.name = keras_model.name
         createJavaObject(sparkSession._sc, 'dummy')
         if not hasattr(keras_model, 'optimizer'):
-            keras_model.compile(loss='categorical_crossentropy', optimizer=keras.optimizers.SGD(lr=0.01, momentum=0.95, decay=5e-4, nesterov=True))
-        convertKerasToCaffeNetwork(keras_model, self.name + ".proto", int(batch_size))
-        convertKerasToCaffeSolver(keras_model, self.name + ".proto", self.name + "_solver.proto", int(max_iter), int(test_iter), int(test_interval), int(display), lr_policy, weight_decay, regularization_type)
+            keras_model.compile(
+                loss='categorical_crossentropy',
+                optimizer=keras.optimizers.SGD(
+                    lr=0.01,
+                    momentum=0.95,
+                    decay=5e-4,
+                    nesterov=True))
+        convertKerasToCaffeNetwork(
+            keras_model,
+            self.name + ".proto",
+            int(batch_size))
+        convertKerasToCaffeSolver(
+            keras_model,
+            self.name + ".proto",
+            self.name + "_solver.proto",
+            int(max_iter),
+            int(test_iter),
+            int(test_interval),
+            int(display),
+            lr_policy,
+            weight_decay,
+            regularization_type)
         self.weights = tempfile.mkdtemp() if weights is None else weights
         if load_keras_weights:
-            convertKerasToSystemMLModel(sparkSession, keras_model, self.weights)
-        if labels is not None and (labels.startswith('https:') or labels.startswith('http:')):
+            convertKerasToSystemMLModel(
+                sparkSession, keras_model, self.weights)
+        if labels is not None and (labels.startswith(
+                'https:') or labels.startswith('http:')):
             import urllib
             urllib.urlretrieve(labels, os.path.join(weights, 'labels.txt'))
         elif labels is not None:
             from shutil import copyfile
             copyfile(labels, os.path.join(weights, 'labels.txt'))
-        super(Keras2DML,self).__init__(sparkSession, self.name + "_solver.proto", input_shape, transferUsingDF)
+        super(
+            Keras2DML,
+            self).__init__(
+            sparkSession,
+            self.name +
+            "_solver.proto",
+            input_shape,
+            transferUsingDF)
         if load_keras_weights:
             self.load(self.weights)
 

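For reference, a hedged sketch of driving the refactored Keras2DML wrapper (not part of the patch; the toy model, the input shape, and the random data are illustrative and assume a Keras version compatible with this code):

    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense, Flatten
    from pyspark.sql import SparkSession
    from systemml.mllearn import Keras2DML

    spark = SparkSession.builder.getOrCreate()

    # Toy classifier over 28x28 single-channel inputs.
    model = Sequential()
    model.add(Flatten(input_shape=(1, 28, 28)))
    model.add(Dense(10, activation='softmax'))

    # input_shape must have exactly three elements (see the check in Caffe2DML.__init__).
    sysml_model = Keras2DML(spark, model, input_shape=[1, 28, 28], max_iter=100)
    X = np.random.rand(64, 28 * 28)
    y = np.random.randint(10, size=64)
    sysml_model.fit(X, y)
    print(sysml_model.predict(X)[:5])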