GitHub user icexelloss commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19872#discussion_r162635572
  
    --- Diff: python/pyspark/sql/tests.py ---
    @@ -4279,6 +4273,425 @@ def test_unsupported_types(self):
                     df.groupby('id').apply(f).collect()
     
     
    +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not 
installed")
    +class GroupbyAggPandasUDFTests(ReusedSQLTestCase):
    +
    +    @property
    +    def data(self):
    +        from pyspark.sql.functions import array, explode, col, lit
    +        return self.spark.range(10).toDF('id') \
    +            .withColumn("vs", array([lit(i * 1.0) + col('id') for i in 
range(20, 30)])) \
    +            .withColumn("v", explode(col('vs'))) \
    +            .drop('vs') \
    +            .withColumn('w', lit(1.0))
    +
    +    @property
    +    def python_plus_one(self):
    +        from pyspark.sql.functions import udf
    +
    +        @udf('double')
    +        def plus_one(v):
    +            assert isinstance(v, (int, float))
    +            return v + 1
    +        return plus_one
    +
    +    @property
    +    def pandas_scalar_plus_two(self):
    +        import pandas as pd
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +
    +        @pandas_udf('double', PandasUDFType.SCALAR)
    +        def plus_two(v):
    +            assert isinstance(v, pd.Series)
    +            return v + 2
    +        return plus_two
    +
    +    @property
    +    def pandas_agg_mean_udf(self):
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +
    +        @pandas_udf('double', PandasUDFType.GROUP_AGG)
    +        def avg(v):
    +            return v.mean()
    +        return avg
    +
    +    @property
    +    def pandas_agg_sum_udf(self):
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +
    +        @pandas_udf('double', PandasUDFType.GROUP_AGG)
    +        def sum(v):
    +            return v.sum()
    +        return sum
    +
    +    @property
    +    def pandas_agg_weighted_mean_udf(self):
    +        import numpy as np
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +
    +        @pandas_udf('double', PandasUDFType.GROUP_AGG)
    +        def weighted_mean(v, w):
    +            return np.average(v, weights=w)
    +        return weighted_mean
    +
    +    def test_basic(self):
    +        from pyspark.sql.functions import col, lit, sum, mean
    +
    +        df = self.data
    +        weighted_mean_udf = self.pandas_agg_weighted_mean_udf
    +
    +        # Groupby one column and aggregate one UDF with literal
    +        result1 = df.groupby('id').agg(weighted_mean_udf(df.v, 
lit(1.0))).sort('id')
    +        expected1 = 
df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id')
    --- End diff --
    
    Ah, no worries. Thanks for the clarification.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to