[ https://issues.apache.org/jira/browse/SPARK-21542?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16681891#comment-16681891 ]
John Bauer commented on SPARK-21542:
------------------------------------

{code}
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, randn
from pyspark import keyword_only
from pyspark.ml import Estimator, Model
#from pyspark.ml.feature import SQLTransformer
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasInputCol, HasOutputCol

spark = SparkSession\
    .builder\
    .appName("ImputeNormal")\
    .getOrCreate()


# Estimator that learns the mean and stddev of the input column.
# The DefaultParams mixins provide Python-only save()/load() persistence.
class ImputeNormal(Estimator, HasInputCol, HasOutputCol,
                   DefaultParamsReadable, DefaultParamsWritable,
                   ):

    @keyword_only
    def __init__(self, inputCol="inputCol", outputCol="outputCol"):
        super(ImputeNormal, self).__init__()
        self._setDefault(inputCol="inputCol", outputCol="outputCol")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol="inputCol", outputCol="outputCol"):
        """
        setParams(self, inputCol="inputCol", outputCol="outputCol")
        """
        kwargs = self._input_kwargs
        self._set(**kwargs)
        return self

    def _fit(self, data):
        inputCol = self.getInputCol()
        outputCol = self.getOutputCol()

        # describe() returns the summary statistics as strings, hence the float() casts below.
        stats = data.select(inputCol).describe()
        mean = stats.where(col("summary") == "mean").take(1)[0][inputCol]
        stddev = stats.where(col("summary") == "stddev").take(1)[0][inputCol]

        return ImputeNormalModel(mean=float(mean),
                                 stddev=float(stddev),
                                 inputCol=inputCol,
                                 outputCol=outputCol,
                                 )

        # FOR A TRULY MINIMAL BUT LESS DIDACTICALLY EFFECTIVE DEMO, DO INSTEAD:
        # sql_text = "SELECT *, IF({inputCol} IS NULL, {stddev} * randn() + {mean}, {inputCol}) AS {outputCol} FROM __THIS__"
        #
        # return SQLTransformer(statement=sql_text.format(stddev=stddev, mean=mean, inputCol=inputCol, outputCol=outputCol))


# Model produced by ImputeNormal: fills nulls with draws from N(mean, stddev).
class ImputeNormalModel(Model, HasInputCol, HasOutputCol,
                        DefaultParamsReadable, DefaultParamsWritable,
                        ):

    mean = Param(Params._dummy(), "mean",
                 "Mean value of imputations. Calculated by fit method.",
                 typeConverter=TypeConverters.toFloat)
    stddev = Param(Params._dummy(), "stddev",
                   "Standard deviation of imputations. Calculated by fit method.",
                   typeConverter=TypeConverters.toFloat)

    @keyword_only
    def __init__(self, mean=0.0, stddev=1.0, inputCol="inputCol", outputCol="outputCol"):
        super(ImputeNormalModel, self).__init__()
        self._setDefault(mean=0.0, stddev=1.0, inputCol="inputCol", outputCol="outputCol")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, mean=0.0, stddev=1.0, inputCol="inputCol", outputCol="outputCol"):
        """
        setParams(self, mean=0.0, stddev=1.0, inputCol="inputCol", outputCol="outputCol")
        """
        kwargs = self._input_kwargs
        self._set(**kwargs)
        return self

    def getMean(self):
        return self.getOrDefault(self.mean)

    def setMean(self, mean):
        self._set(mean=mean)

    def getStddev(self):
        return self.getOrDefault(self.stddev)

    def setStddev(self, stddev):
        self._set(stddev=stddev)

    def _transform(self, data):
        mean = self.getMean()
        stddev = self.getStddev()
        inputCol = self.getInputCol()
        outputCol = self.getOutputCol()

        df = data.withColumn(outputCol,
                             when(col(inputCol).isNull(), stddev * randn() + mean)
                             .otherwise(col(inputCol)))
        return df


if __name__ == "__main__":
    train = spark.createDataFrame([[0], [1], [2]] + [[None]] * 100, ['input'])

    impute = ImputeNormal(inputCol='input', outputCol='output')
    impute_model = impute.fit(train)
    print("Input column: {}".format(impute_model.getInputCol()))
    print("Output column: {}".format(impute_model.getOutputCol()))
    print("Mean: {}".format(impute_model.getMean()))
    print("Standard Deviation: {}".format(impute_model.getStddev()))
    test = impute_model.transform(train)
    test.show(10)
    test.describe().show()
    print("mean and stddev for outputCol should be close to those of inputCol")
{code}
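For completeness, a minimal sketch of the save/load round trip that the DefaultParamsReadable/DefaultParamsWritable mixins enable for the example above (the target paths are arbitrary assumptions, not part of the original comment; any writable location works):

{code}
# Persist the fitted model; DefaultParamsWriter stores its params as JSON metadata.
model_path = "/tmp/impute_normal_model"  # example path, not from the original comment
impute_model.write().overwrite().save(model_path)

# Reload it without any Scala/Java counterpart; params are restored from the metadata.
reloaded = ImputeNormalModel.load(model_path)
print(reloaded.getMean(), reloaded.getStddev())

# The unfitted Estimator can be persisted the same way.
impute.write().overwrite().save("/tmp/impute_normal_estimator")
reloaded_estimator = ImputeNormal.load("/tmp/impute_normal_estimator")
{code}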
> Helper functions for custom Python Persistence
> ----------------------------------------------
>
>                 Key: SPARK-21542
>                 URL: https://issues.apache.org/jira/browse/SPARK-21542
>             Project: Spark
>          Issue Type: New Feature
>          Components: ML, PySpark
>    Affects Versions: 2.2.0
>            Reporter: Ajay Saini
>            Assignee: Ajay Saini
>            Priority: Major
>             Fix For: 2.3.0
>
>
> Currently, there is no way to easily persist JSON-serializable parameters in
> Python only. All parameters in Python are persisted by converting them to
> Java objects and using the Java persistence implementation. In order to
> facilitate the creation of custom Python-only pipeline stages, it would be
> good to have a Python-only persistence framework so that these stages do not
> need to be implemented in Scala for persistence.
> This task involves:
> - Adding implementations for DefaultParamsReadable, DefaultParamsWritable,
>   DefaultParamsReader, and DefaultParamsWriter in pyspark.