zero323 commented on a change in pull request #30471:
URL: https://github.com/apache/spark/pull/30471#discussion_r533570707
##########
File path: python/pyspark/ml/classification.py
##########
@@ -2991,8 +2994,59 @@ def _to_java(self):
_java_obj.setRawPredictionCol(self.getRawPredictionCol())
return _java_obj
+ @classmethod
+ def read(cls):
+ return OneVsRestReader(cls)
+
+ def write(self):
+ if isinstance(self.getClassifier(), JavaMLWritable):
+ return JavaMLWriter(self)
+ else:
+ return OneVsRestWriter(self)
+
+
+class OneVsRestSharedReadWrite:
+ @staticmethod
+ def saveImpl(instance, sc, path, extraMetadata=None):
+ skipParams = ['classifier']
+ jsonParams = DefaultParamsWriter.extractJsonParams(instance, skipParams)
+ DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams,
+ extraMetadata=extraMetadata)
+ classifierPath = os.path.join(path, 'classifier')
+ instance.getClassifier().save(classifierPath)
+
+ @staticmethod
+ def loadClassifier(path, sc):
+ classifierPath = os.path.join(path, 'classifier')
+ return DefaultParamsReader.loadParamsInstance(classifierPath, sc)
+
+
+class OneVsRestReader(MLReader):
+ def __init__(self, cls):
+ super(OneVsRestReader, self).__init__()
+ self.cls = cls
+
+ def load(self, path):
+ metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+ if not DefaultParamsReader.isPythonParamsInstance(metadata):
+ return JavaMLReader(self.cls).load(path)
+ else:
+ classifier = OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
+ ova = OneVsRest(classifier=classifier)._resetUid(metadata['uid'])
+ DefaultParamsReader.getAndSetParams(ova, metadata, skipParams=['classifier'])
+ return ova
+
+
+class OneVsRestWriter(MLWriter):
+ def __init__(self, instance):
+ super(OneVsRestWriter, self).__init__()
+ self.instance = instance
+
+ def saveImpl(self, path):
+ OneVsRestSharedReadWrite.saveImpl(self.instance, self.sc, path)
Review comment:
These seem to be intended for internal use only. If that's the case,
shall we mark them as such (`OneVsRestSharedReadWrite` ->
`_OneVsRestSharedReadWrite`, etc.) and skip annotations for them?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]