Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20695#discussion_r181802586
--- Diff: python/pyspark/ml/stat.py ---
@@ -195,6 +197,195 @@ def test(dataset, sampleCol, distName, *params):
_jvm().PythonUtils.toSeq(params)))
class Summarizer(object):
    """
    .. note:: Experimental

    Tools for vectorized statistics on MLlib Vectors.

    The methods in this package provide various statistics for Vectors
    contained inside DataFrames.

    This class lets users pick the statistics they would like to extract
    for a given column.

    >>> from pyspark.ml.stat import Summarizer
    >>> from pyspark.sql import Row
    >>> from pyspark.ml.linalg import Vectors
    >>> summarizer = Summarizer.metrics("mean", "count")
    >>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
    ...                      Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
    >>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False)
    +-----------------------------------+
    |aggregate_metrics(features, weight)|
    +-----------------------------------+
    |[[1.0,1.0,1.0], 1]                 |
    +-----------------------------------+
    <BLANKLINE>
    >>> df.select(summarizer.summary(df.features)).show(truncate=False)
    +--------------------------------+
    |aggregate_metrics(features, 1.0)|
    +--------------------------------+
    |[[1.0,1.5,2.0], 2]              |
    +--------------------------------+
    <BLANKLINE>
    >>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False)
    +--------------+
    |mean(features)|
    +--------------+
    |[1.0,1.0,1.0] |
    +--------------+
    <BLANKLINE>
    >>> df.select(Summarizer.mean(df.features)).show(truncate=False)
    +--------------+
    |mean(features)|
    +--------------+
    |[1.0,1.5,2.0] |
    +--------------+
    <BLANKLINE>

    .. versionadded:: 2.4.0

    """
    @staticmethod
    @since("2.4.0")
    def mean(col, weightCol=None):
        """
        return a column of mean summary
        """
        return Summarizer._get_single_metric(col, weightCol, "mean")

    @staticmethod
    @since("2.4.0")
    def variance(col, weightCol=None):
        """
        return a column of variance summary
        """
        return Summarizer._get_single_metric(col, weightCol, "variance")

    @staticmethod
    @since("2.4.0")
    def count(col, weightCol=None):
        """
        return a column of count summary
        """
        return Summarizer._get_single_metric(col, weightCol, "count")

    @staticmethod
    @since("2.4.0")
    def numNonZeros(col, weightCol=None):
        """
        return a column of numNonZero summary
        """
        return Summarizer._get_single_metric(col, weightCol, "numNonZeros")

    @staticmethod
    @since("2.4.0")
    def max(col, weightCol=None):
        """
        return a column of max summary
        """
        return Summarizer._get_single_metric(col, weightCol, "max")

    @staticmethod
    @since("2.4.0")
    def min(col, weightCol=None):
        """
        return a column of min summary
        """
        return Summarizer._get_single_metric(col, weightCol, "min")

    @staticmethod
    @since("2.4.0")
    def normL1(col, weightCol=None):
        """
        return a column of normL1 summary
        """
        return Summarizer._get_single_metric(col, weightCol, "normL1")

    @staticmethod
    @since("2.4.0")
    def normL2(col, weightCol=None):
        """
        return a column of normL2 summary
        """
        return Summarizer._get_single_metric(col, weightCol, "normL2")

    @staticmethod
    def _check_param(featuresCol, weightCol):
        """
        Validate the (featuresCol, weightCol) pair, defaulting a missing
        weight column to a literal 1.0 (i.e. unweighted statistics).
        """
        if weightCol is None:
            weightCol = lit(1.0)
        if not isinstance(featuresCol, Column) or not isinstance(weightCol, Column):
            # NOTE: message previously said "featureCol"; the parameter is featuresCol.
            raise TypeError("featuresCol and weightCol should be a Column")
        return featuresCol, weightCol

    @staticmethod
    def _get_single_metric(col, weightCol, metric):
        """
        Build a Column wrapping the JVM-side single-metric aggregator
        ``org.apache.spark.ml.stat.Summarizer.<metric>``.
        """
        col, weightCol = Summarizer._check_param(col, weightCol)
        return Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + metric,
                                                col._jc, weightCol._jc))

    @staticmethod
    @since("2.4.0")
    def metrics(*metrics):
        """
        Given a list of metrics, provides a builder that computes those
        metrics from a column.

        See the documentation of :py:class:`Summarizer` for an example.

        The following metrics are accepted (case sensitive):
         - mean: a vector that contains the coefficient-wise mean.
         - variance: a vector that contains the coefficient-wise variance.
         - count: the count of all vectors seen.
         - numNonZeros: a vector with the number of non-zeros for each coefficients
         - max: the maximum for each coefficient.
         - min: the minimum for each coefficient.
         - normL2: the Euclidean norm for each coefficient.
         - normL1: the L1 norm of each coefficient (sum of the absolute values).

        :param metrics:
         metrics that can be provided.
        :return:
         an object of :py:class:`pyspark.ml.stat.SummaryBuilder`

        Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD
        interface.
        """
        sc = SparkContext._active_spark_context
        js = JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics",
                                       _to_seq(sc, metrics))
        return SummaryBuilder(js)
+
+
class SummaryBuilder(JavaWrapper):
    """
    .. note:: Experimental

    A builder object that provides summary statistics about a given column.

    Users should not directly create such builders, but instead use one of
    the methods in :py:class:`pyspark.ml.stat.Summarizer`

    .. versionadded:: 2.4.0

    """
    def __init__(self, js):
        # Delegate to JavaWrapper.__init__ so the JVM builder is stored in
        # self._java_obj (which the wrapper machinery, including cleanup,
        # relies on) rather than in an ad-hoc attribute.
        super(SummaryBuilder, self).__init__(js)
--- End diff --
This should call the super's init method, and it should store js in
_java_obj (which is set in the JavaWrapper init).
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]