Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20695#discussion_r181263309
  
    --- Diff: python/pyspark/ml/stat.py ---
    @@ -195,6 +197,185 @@ def test(dataset, sampleCol, distName, *params):
                                                  
_jvm().PythonUtils.toSeq(params)))
     
     
    +class Summarizer(object):
    +    """
    +    .. note:: Experimental
    +
    +    Tools for vectorized statistics on MLlib Vectors.
    +    The methods in this package provide various statistics for Vectors 
contained inside DataFrames.
    +    This class lets users pick the statistics they would like to extract 
for a given column.
    +
    +    >>> from pyspark.ml.stat import Summarizer
    +    >>> from pyspark.sql import Row
    +    >>> from pyspark.ml.linalg import Vectors
    +    >>> summarizer = Summarizer.metrics("mean", "count")
    +    >>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 
1.0, 1.0)),
    +    ...                      Row(weight=0.0, features=Vectors.dense(1.0, 
2.0, 3.0))]).toDF()
    +    >>> df.select(summarizer.summary(df.features, 
df.weight)).show(truncate=False)
    +    +-----------------------------------+
    +    |aggregate_metrics(features, weight)|
    +    +-----------------------------------+
    +    |[[1.0,1.0,1.0], 1]                 |
    +    +-----------------------------------+
    +    <BLANKLINE>
    +    >>> df.select(summarizer.summary(df.features)).show(truncate=False)
    +    +--------------------------------+
    +    |aggregate_metrics(features, 1.0)|
    +    +--------------------------------+
    +    |[[1.0,1.5,2.0], 2]              |
    +    +--------------------------------+
    +    <BLANKLINE>
    +    >>> df.select(Summarizer.mean(df.features, 
df.weight)).show(truncate=False)
    +    +--------------+
    +    |mean(features)|
    +    +--------------+
    +    |[1.0,1.0,1.0] |
    +    +--------------+
    +    <BLANKLINE>
    +    >>> df.select(Summarizer.mean(df.features)).show(truncate=False)
    +    +--------------+
    +    |mean(features)|
    +    +--------------+
    +    |[1.0,1.5,2.0] |
    +    +--------------+
    +    <BLANKLINE>
    +
    +    .. versionadded:: 2.4.0
    +
    +    """
    +    @staticmethod
    +    @since("2.4.0")
    +    def mean(col, weightCol=None):
    +        """
    +        return a column of mean summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "mean")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def variance(col, weightCol=None):
    +        """
    +        return a column of variance summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "variance")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def count(col, weightCol=None):
    +        """
    +        return a column of count summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "count")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def numNonZeros(col, weightCol=None):
    +        """
    +        return a column of numNonZeros summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "numNonZeros")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def max(col, weightCol=None):
    +        """
    +        return a column of max summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "max")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def min(col, weightCol=None):
    +        """
    +        return a column of min summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "min")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def normL1(col, weightCol=None):
    +        """
    +        return a column of normL1 summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "normL1")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def normL2(col, weightCol=None):
    +        """
    +        return a column of normL2 summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "normL2")
    +
    +    @staticmethod
    +    def _check_param(featureCol, weightCol):
    +        if weightCol is None:
    +            weightCol = lit(1.0)
    +        if not isinstance(featureCol, Column) or not isinstance(weightCol, 
Column):
    +            raise TypeError("featureCol and weightCol should be a Column")
    +        return featureCol, weightCol
    +
    +    @staticmethod
    +    def _get_single_metric(col, weightCol, metric):
    +        col, weightCol = Summarizer._check_param(col, weightCol)
    +        return 
Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + 
metric,
    +                                                col._jc, weightCol._jc))
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def metrics(*metrics):
    +        """
    +        Given a list of metrics, provides a builder that in turn computes 
metrics from a column.
    +
    +        See the documentation of [[Summarizer]] for an example.
    +
    +        The following metrics are accepted (case sensitive):
    +         - mean: a vector that contains the coefficient-wise mean.
    +         - variance: a vector that contains the coefficient-wise variance.
    +         - count: the count of all vectors seen.
    +         - numNonZeros: a vector with the number of non-zeros for each 
coefficient.
    +         - max: the maximum for each coefficient.
    +         - min: the minimum for each coefficient.
    +         - normL2: the Euclidean norm for each coefficient.
    +         - normL1: the L1 norm of each coefficient (sum of the absolute 
values).
    +
    +        :param metrics: metrics that can be provided.
    +        :return: a SummarizerBuilder
    +
    +        Note: Currently, the performance of this interface is about 2x~3x 
slower than using the RDD
    +        interface.
    +        """
    +        sc = SparkContext._active_spark_context
    +        js = 
JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics",
    +                                       _to_seq(sc, metrics))
    +        return SummarizerBuilder(js)
    +
    +
    +class SummarizerBuilder(object):
    --- End diff --
    
    Also, shouldn't we use JavaWrapper for this?  That will clean up when this 
object is destroyed.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to