zero323 commented on a change in pull request #30471:
URL: https://github.com/apache/spark/pull/30471#discussion_r534487321
##########
File path: python/pyspark/ml/classification.py
##########
@@ -2991,8 +2994,59 @@ def _to_java(self):
_java_obj.setRawPredictionCol(self.getRawPredictionCol())
return _java_obj
+ @classmethod
+ def read(cls):
+ return OneVsRestReader(cls)
+
+ def write(self):
+ if isinstance(self.getClassifier(), JavaMLWritable):
+ return JavaMLWriter(self)
+ else:
+ return OneVsRestWriter(self)
+
+
+class OneVsRestSharedReadWrite:
+ @staticmethod
+ def saveImpl(instance, sc, path, extraMetadata=None):
+ skipParams = ['classifier']
+ jsonParams = DefaultParamsWriter.extractJsonParams(instance,
skipParams)
+ DefaultParamsWriter.saveMetadata(instance, path, sc,
paramMap=jsonParams,
+ extraMetadata=extraMetadata)
+ classifierPath = os.path.join(path, 'classifier')
+ instance.getClassifier().save(classifierPath)
+
+ @staticmethod
+ def loadClassifier(path, sc):
+ classifierPath = os.path.join(path, 'classifier')
+ return DefaultParamsReader.loadParamsInstance(classifierPath, sc)
+
+
+class OneVsRestReader(MLReader):
+ def __init__(self, cls):
+ super(OneVsRestReader, self).__init__()
+ self.cls = cls
+
+ def load(self, path):
+ metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+ if not DefaultParamsReader.isPythonParamsInstance(metadata):
+ return JavaMLReader(self.cls).load(path)
+ else:
+ classifier = OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
+ ova = OneVsRest(classifier=classifier)._resetUid(metadata['uid'])
+ DefaultParamsReader.getAndSetParams(ova, metadata,
skipParams=['classifier'])
+ return ova
+
+
+class OneVsRestWriter(MLWriter):
+ def __init__(self, instance):
+ super(OneVsRestWriter, self).__init__()
+ self.instance = instance
+
+ def saveImpl(self, path):
+ OneVsRestSharedReadWrite.saveImpl(self.instance, self.sc, path)
Review comment:
> Except OneVsRestSharedReadWrite, other Reader/Writer should be public,
similar to Pipeline Reader/Writer
My point was that the particular implementation of `Reader` / `Writer` is
completely irrelevant to the end user. In terms of types we still expect
`MLReader` / `MLWriter`, and the `OneVsRest` variants are unlikely to be useful
for anyone but devs.
I believe that for the same reason we mark `PipelineReader` and
`PipelineModelWriter` as private in comments:
https://github.com/apache/spark/blob/92bfbcb2e372e8fecfe65bc582c779d9df4036bb/python/pyspark/ml/pipeline.py#L197
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]