Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/18742#discussion_r131516889
  
    --- Diff: python/pyspark/ml/util.py ---
    @@ -283,3 +333,204 @@ def numFeatures(self):
             Returns the number of features the model was trained on. If unknown, returns -1
             """
             return self._call_java("numFeatures")
    +
    +
    +@inherit_doc
    +class DefaultParamsWritable(MLWritable):
    +    """
    +    .. note:: DeveloperApi
    +
    +    Helper trait for making simple `Params` types writable.  If a `Params` class stores
    +    all data as [[pyspark.ml.param.Param]] values, then extending this trait will provide
    +    a default implementation of writing saved instances of the class.
    +    This only handles simple [[pyspark.ml.param.Param]] types; e.g., it will not handle
    +    [[pyspark.sql.Dataset]].
    +
    +    @see `DefaultParamsReadable`, the counterpart to this trait
    +
    +    .. versionadded:: 2.3.0
    +    """
    +
    +    def write(self):
    +        """Returns a DefaultParamsWriter instance for this class."""
    +        from pyspark.ml.param import Params
    +
    +        if isinstance(self, Params):
    +            return DefaultParamsWriter(self)
    +        else:
    +            raise TypeError("Cannot use DefaultParamsWritable with type %s because it does not "
    +                            "extend Params." % type(self))
    +
    +
    +@inherit_doc
    +class DefaultParamsWriter(MLWriter):
    +    """
    +    .. note:: DeveloperApi
    +
    +    Specialization of :py:class:`MLWriter` for :py:class:`Params` types
    +
    +    Class for writing Estimators and Transformers whose parameters are JSON-serializable.
    +
    +    .. versionadded:: 2.3.0
    +    """
    +
    +    def __init__(self, instance):
    +        super(DefaultParamsWriter, self).__init__()
    +        self.instance = instance
    +
    +    def saveImpl(self, path):
    +        DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)
    +
    +    @staticmethod
    +    def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
    +        """
    +        Saves metadata + Params to: path + "/metadata"
    +        - class
    +        - timestamp
    +        - sparkVersion
    +        - uid
    +        - paramMap
    +        - (optionally, extra metadata)
    +        @param extraMetadata  Extra metadata to be saved at same level as uid, paramMap, etc.
    +        @param paramMap  If given, this is saved in the "paramMap" field.
    +        """
    +        metadataPath = os.path.join(path, "metadata")
    +        metadataJson = DefaultParamsWriter._get_metadata_to_save(instance,
    +                                                                 sc,
    +                                                                 extraMetadata,
    +                                                                 paramMap)
    +        sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath)
    +
    +    @staticmethod
    +    def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
    +        """
    +        Helper for [[saveMetadata()]] which extracts the JSON to save.
    +        This is useful for ensemble models which need to save metadata for many sub-models.
    +
    +        @see [[saveMetadata()]] for details on what this includes.
    +        """
    +        uid = instance.uid
    +        cls = instance.__module__ + '.' + instance.__class__.__name__
    +        params = instance.extractParamMap()
    +        jsonParams = {}
    +        if paramMap is not None:
    +            jsonParams = paramMap
    +        else:
    +            for p in params:
    +                jsonParams[p.name] = params[p]
    +        basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)),
    +                         "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams}
    +        if extraMetadata is not None:
    +            basicMetadata.update(extraMetadata)
    +        return json.dumps(basicMetadata, separators=[',',  ':'])
    +
    +
    +@inherit_doc
    +class DefaultParamsReadable(MLReadable):
    +    """
    +    .. note:: DeveloperApi
    +
    +    Helper trait for making simple `Params` types readable.  If a `Params` class stores
    +    all data as [[pyspark.ml.param.Param]] values, then extending this trait will provide
    +    a default implementation of reading saved instances of the class.
    +    This only handles simple [[pyspark.ml.param.Param]] types; e.g., it will not handle
    +    [[pyspark.sql.Dataset]].
    +
    +    @see `DefaultParamsWritable`, the counterpart to this trait
    +
    +    .. versionadded:: 2.3.0
    +    """
    +
    +    @classmethod
    +    def read(cls):
    +        """Returns a DefaultParamsReader instance for this class."""
    +        return DefaultParamsReader(cls)
    +
    +
    +@inherit_doc
    +class DefaultParamsReader(MLReader):
    +    """
    +    .. note:: DeveloperApi
    +
    +    Specialization of :py:class:`MLReader` for :py:class:`Params` types
    +
    +    Default `MLReader` implementation for transformers and estimators that
    +    contain basic (json-serializable) params and no data. This will not handle
    +    more complex params or types with data (e.g., models with coefficients).
    +
    +    .. versionadded:: 2.3.0
    +    """
    +
    +    def __init__(self, cls):
    +        super(DefaultParamsReader, self).__init__()
    +        self.cls = cls
    +
    +    @staticmethod
    +    def __get_class(clazz):
    +        """
    +        Loads Python class from its name.
    +        """
    +        parts = clazz.split('.')
    +        module = ".".join(parts[:-1])
    +        m = __import__(module)
    +        for comp in parts[1:]:
    +            m = getattr(m, comp)
    +        return m
    +
    +    def load(self, path):
    +        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
    +        py_type = DefaultParamsReader.__get_class(metadata['class'])
    +        instance = py_type()
    +        instance._resetUid(metadata['uid'])
    +        DefaultParamsReader.getAndSetParams(instance, metadata)
    +        return instance
    +
    +    @staticmethod
    +    def loadMetadata(path, sc, expectedClassName=""):
    +        """
    +        Load metadata saved using [[DefaultParamsWriter.saveMetadata()]]
    +
    +        @param expectedClassName  If non empty, this is checked against the loaded metadata.
    --- End diff --
    
    I missed this before: the Scala-style ```@param``` and ```@see``` do not
    work in Python.  Look in other files like pipeline.py for examples of
    ```:param```, and maybe use ```.. note::``` to replace ```@see```.
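
    For illustration, the ```saveMetadata``` docstring above could look roughly like this in the pipeline.py style (just a sketch of the markup, not suggested final wording):

    ```python
    @staticmethod
    def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
        """
        Saves metadata + Params to: path + "/metadata"
        - class
        - timestamp
        - sparkVersion
        - uid
        - paramMap
        - (optionally, extra metadata)

        :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc.
        :param paramMap: If given, this is saved in the "paramMap" field.

        .. note:: :py:meth:`DefaultParamsReader.loadMetadata` is the counterpart that reads this metadata back.
        """
        # body unchanged
    ```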
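
    For context on how these mixins end up being used, here is a rough, hypothetical sketch (not code from this PR: the ```PassThrough``` name, columns, and save path are made up, and it assumes an active SparkSession plus the later parts of this PR such as ```getAndSetParams```):

    ```python
    from pyspark.ml import Transformer
    from pyspark.ml.param.shared import HasInputCol, HasOutputCol
    from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable


    class PassThrough(Transformer, HasInputCol, HasOutputCol,
                      DefaultParamsReadable, DefaultParamsWritable):
        """Toy transformer whose state is only simple Params (no model data)."""

        def _transform(self, dataset):
            # Copy the input column to the output column; no fitted data involved.
            return dataset.withColumn(self.getOutputCol(), dataset[self.getInputCol()])


    # DefaultParamsWriter saves class/uid/paramMap metadata JSON under <path>/metadata;
    # DefaultParamsReader re-instantiates the class and restores the Params from it.
    t = PassThrough().setInputCol("x").setOutputCol("y")
    t.save("/tmp/pass-through")
    t2 = PassThrough.load("/tmp/pass-through")
    assert t2.getInputCol() == "x" and t2.getOutputCol() == "y"
    ```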

