This is an automated email from the ASF dual-hosted git repository. srowen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5bb9647 [SPARK-26754][PYTHON] Add hasTrainingSummary to replace duplicate code in PySpark 5bb9647 is described below commit 5bb9647e1019ea7eb17af7d2057fdacb7f4c560b Author: Huaxin Gao <huax...@us.ibm.com> AuthorDate: Fri Feb 1 17:29:58 2019 -0600 [SPARK-26754][PYTHON] Add hasTrainingSummary to replace duplicate code in PySpark ## What changes were proposed in this pull request? Python version of https://github.com/apache/spark/pull/17654 ## How was this patch tested? Existing Python unit test Closes #23676 from huaxingao/spark26754. Authored-by: Huaxin Gao <huax...@us.ibm.com> Signed-off-by: Sean Owen <sean.o...@databricks.com> --- python/pyspark/ml/classification.py | 19 ++++++------------- python/pyspark/ml/clustering.py | 37 ++++++------------------------------- python/pyspark/ml/regression.py | 30 ++++++------------------------ python/pyspark/ml/util.py | 26 ++++++++++++++++++++++++++ 4 files changed, 44 insertions(+), 68 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 89b9278..134b9e0 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -483,7 +483,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti return self.getOrDefault(self.upperBoundsOnIntercepts) -class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): +class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable, + HasTrainingSummary): """ Model fitted by LogisticRegression. @@ -532,24 +533,16 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable trained on the training set. An exception is thrown if `trainingSummary is None`. """ if self.hasSummary: - java_lrt_summary = self._call_java("summary") if self.numClasses <= 2: - return BinaryLogisticRegressionTrainingSummary(java_lrt_summary) + return BinaryLogisticRegressionTrainingSummary(super(LogisticRegressionModel, + self).summary) else: - return LogisticRegressionTrainingSummary(java_lrt_summary) + return LogisticRegressionTrainingSummary(super(LogisticRegressionModel, + self).summary) else: raise RuntimeError("No training summary available for this %s" % self.__class__.__name__) - @property - @since("2.0.0") - def hasSummary(self): - """ - Indicates whether a training summary exists for this model - instance. - """ - return self._call_java("hasSummary") - @since("2.0.0") def evaluate(self, dataset): """ diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index b9c6bdf..864e2a3 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -97,7 +97,7 @@ class ClusteringSummary(JavaWrapper): return self._call_java("numIter") -class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): +class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable, HasTrainingSummary): """ Model fitted by GaussianMixture. @@ -126,22 +126,13 @@ class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): @property @since("2.1.0") - def hasSummary(self): - """ - Indicates whether a training summary exists for this model - instance. - """ - return self._call_java("hasSummary") - - @property - @since("2.1.0") def summary(self): """ Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set. An exception is thrown if no summary exists. """ if self.hasSummary: - return GaussianMixtureSummary(self._call_java("summary")) + return GaussianMixtureSummary(super(GaussianMixtureModel, self).summary) else: raise RuntimeError("No training summary available for this %s" % self.__class__.__name__) @@ -323,7 +314,7 @@ class KMeansSummary(ClusteringSummary): return self._call_java("trainingCost") -class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable): +class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable, HasTrainingSummary): """ Model fitted by KMeans. @@ -337,21 +328,13 @@ class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable): @property @since("2.1.0") - def hasSummary(self): - """ - Indicates whether a training summary exists for this model instance. - """ - return self._call_java("hasSummary") - - @property - @since("2.1.0") def summary(self): """ Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set. An exception is thrown if no summary exists. """ if self.hasSummary: - return KMeansSummary(self._call_java("summary")) + return KMeansSummary(super(KMeansModel, self).summary) else: raise RuntimeError("No training summary available for this %s" % self.__class__.__name__) @@ -507,7 +490,7 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol return self.getOrDefault(self.distanceMeasure) -class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): +class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable, HasTrainingSummary): """ Model fitted by BisectingKMeans. @@ -536,21 +519,13 @@ class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): @property @since("2.1.0") - def hasSummary(self): - """ - Indicates whether a training summary exists for this model instance. - """ - return self._call_java("hasSummary") - - @property - @since("2.1.0") def summary(self): """ Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the training set. An exception is thrown if no summary exists. """ if self.hasSummary: - return BisectingKMeansSummary(self._call_java("summary")) + return BisectingKMeansSummary(super(BisectingKMeansModel, self).summary) else: raise RuntimeError("No training summary available for this %s" % self.__class__.__name__) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 9e1f8f8..7841de9 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -161,7 +161,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction return self.getOrDefault(self.epsilon) -class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable): +class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable, + HasTrainingSummary): """ Model fitted by :class:`LinearRegression`. @@ -201,21 +202,11 @@ class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritabl `trainingSummary is None`. """ if self.hasSummary: - java_lrt_summary = self._call_java("summary") - return LinearRegressionTrainingSummary(java_lrt_summary) + return LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary) else: raise RuntimeError("No training summary available for this %s" % self.__class__.__name__) - @property - @since("2.0.0") - def hasSummary(self): - """ - Indicates whether a training summary exists for this model - instance. - """ - return self._call_java("hasSummary") - @since("2.0.0") def evaluate(self, dataset): """ @@ -1648,7 +1639,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, - JavaMLReadable): + JavaMLReadable, HasTrainingSummary): """ .. note:: Experimental @@ -1682,21 +1673,12 @@ class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWri `trainingSummary is None`. """ if self.hasSummary: - java_glrt_summary = self._call_java("summary") - return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary) + return GeneralizedLinearRegressionTrainingSummary( + super(GeneralizedLinearRegressionModel, self).summary) else: raise RuntimeError("No training summary available for this %s" % self.__class__.__name__) - @property - @since("2.0.0") - def hasSummary(self): - """ - Indicates whether a training summary exists for this model - instance. - """ - return self._call_java("hasSummary") - @since("2.0.0") def evaluate(self, dataset): """ diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index e846834..e184e1a 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -611,3 +611,29 @@ class DefaultParamsReader(MLReader): py_type = DefaultParamsReader.__get_class(pythonClassName) instance = py_type.load(path) return instance + + +@inherit_doc +class HasTrainingSummary(object): + """ + Base class for models that provides Training summary. + .. versionadded:: 3.0.0 + """ + + @property + @since("2.1.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @property + @since("2.1.0") + def summary(self): + """ + Gets summary of the model trained on the training set. An exception is thrown if + no summary exists. + """ + return (self._call_java("summary")) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org