Github user jkbradley commented on a diff in the pull request: https://github.com/apache/spark/pull/18742#discussion_r131516889 --- Diff: python/pyspark/ml/util.py --- @@ -283,3 +333,204 @@ def numFeatures(self): Returns the number of features the model was trained on. If unknown, returns -1 """ return self._call_java("numFeatures") + + +@inherit_doc +class DefaultParamsWritable(MLWritable): + """ + .. note:: DeveloperApi + + Helper trait for making simple `Params` types writable. If a `Params` class stores + all data as [[pyspark.ml.param.Param]] values, then extending this trait will provide + a default implementation of writing saved instances of the class. + This only handles simple [[pyspark.ml.param.Param]] types; e.g., it will not handle + [[pyspark.sql.Dataset]]. + + @see `DefaultParamsReadable`, the counterpart to this trait + + .. versionadded:: 2.3.0 + """ + + def write(self): + """Returns a DefaultParamsWriter instance for this class.""" + from pyspark.ml.param import Params + + if isinstance(self, Params): + return DefaultParamsWriter(self) + else: + raise TypeError("Cannot use DefautParamsWritable with type %s because it does not " + + " extend Params.", type(self)) + + +@inherit_doc +class DefaultParamsWriter(MLWriter): + """ + .. note:: DeveloperApi + + Specialization of :py:class:`MLWriter` for :py:class:`Params` types + + Class for writing Estimators and Transformers whose parameters are JSON-serializable. + + .. 
versionadded:: 2.3.0 + """ + + def __init__(self, instance): + super(DefaultParamsWriter, self).__init__() + self.instance = instance + + def saveImpl(self, path): + DefaultParamsWriter.saveMetadata(self.instance, path, self.sc) + + @staticmethod + def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None): + """ + Saves metadata + Params to: path + "/metadata" + - class + - timestamp + - sparkVersion + - uid + - paramMap + - (optionally, extra metadata) + @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc. + @param paramMap If given, this is saved in the "paramMap" field. + """ + metadataPath = os.path.join(path, "metadata") + metadataJson = DefaultParamsWriter._get_metadata_to_save(instance, + sc, + extraMetadata, + paramMap) + sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath) + + @staticmethod + def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None): + """ + Helper for [[saveMetadata()]] which extracts the JSON to save. + This is useful for ensemble models which need to save metadata for many sub-models. + + @see [[saveMetadata()]] for details on what this includes. + """ + uid = instance.uid + cls = instance.__module__ + '.' + instance.__class__.__name__ + params = instance.extractParamMap() + jsonParams = {} + if paramMap is not None: + jsonParams = paramMap + else: + for p in params: + jsonParams[p.name] = params[p] + basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)), + "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams} + if extraMetadata is not None: + basicMetadata.update(extraMetadata) + return json.dumps(basicMetadata, separators=[',', ':']) + + +@inherit_doc +class DefaultParamsReadable(MLReadable): + """ + .. note:: DeveloperApi + + Helper trait for making simple `Params` types readable. 
If a `Params` class stores + all data as [[pyspark.ml.param.Param]] values, then extending this trait will provide + a default implementation of reading saved instances of the class. + This only handles simple [[pyspark.ml.param.Param]] types; e.g., it will not handle + [[pyspark.sql.Dataset]]. + + @see `DefaultParamsWritable`, the counterpart to this trait + + .. versionadded:: 2.3.0 + """ + + @classmethod + def read(cls): + """Returns a DefaultParamsReader instance for this class.""" + return DefaultParamsReader(cls) + + +@inherit_doc +class DefaultParamsReader(MLReader): + """ + .. note:: DeveloperApi + + Specialization of :py:class:`MLReader` for :py:class:`Params` types + + Default `MLReader` implementation for transformers and estimators that + contain basic (json-serializable) params and no data. This will not handle + more complex params or types with data (e.g., models with coefficients). + + .. versionadded:: 2.3.0 + """ + + def __init__(self, cls): + super(DefaultParamsReader, self).__init__() + self.cls = cls + + @staticmethod + def __get_class(clazz): + """ + Loads Python class from its name. + """ + parts = clazz.split('.') + module = ".".join(parts[:-1]) + m = __import__(module) + for comp in parts[1:]: + m = getattr(m, comp) + return m + + def load(self, path): + metadata = DefaultParamsReader.loadMetadata(path, self.sc) + py_type = DefaultParamsReader.__get_class(metadata['class']) + instance = py_type() + instance._resetUid(metadata['uid']) + DefaultParamsReader.getAndSetParams(instance, metadata) + return instance + + @staticmethod + def loadMetadata(path, sc, expectedClassName=""): + """ + Load metadata saved using [[DefaultParamsWriter.saveMetadata()]] + + @param expectedClassName If non empty, this is checked against the loaded metadata. --- End diff -- I missed this before: the Scala-style ```@param``` and ```@see``` do not work in Python. Look in other files, like pipeline.py, for examples of ```:param```, and maybe use ```.. note::``` to replace ```@see```
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes to do so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org