Github user HyukjinKwon commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21082#discussion_r192140997
  
    --- Diff: python/pyspark/sql/tests.py ---
    @@ -5181,6 +5190,235 @@ def test_invalid_args(self):
                         'mixture.*aggregate function.*group aggregate pandas UDF'):
                     df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
     
    +
    +@unittest.skipIf(
    +    not _have_pandas or not _have_pyarrow,
    +    _pandas_requirement_message or _pyarrow_requirement_message)
    +class WindowPandasUDFTests(ReusedSQLTestCase):
    +    @property
    +    def data(self):
    +        from pyspark.sql.functions import array, explode, col, lit
    +        return self.spark.range(10).toDF('id') \
    +            .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
    +            .withColumn("v", explode(col('vs'))) \
    +            .drop('vs') \
    +            .withColumn('w', lit(1.0))
    +
    +    @property
    +    def python_plus_one(self):
    +        from pyspark.sql.functions import udf
    +        return udf(lambda v: v + 1, 'double')
    +
    +    @property
    +    def pandas_scalar_time_two(self):
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +        return pandas_udf(lambda v: v * 2, 'double')
    +
    +    @property
    +    def pandas_agg_mean_udf(self):
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +
    +        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
    +        def avg(v):
    +            return v.mean()
    +        return avg
    +
    +    @property
    +    def pandas_agg_max_udf(self):
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +
    +        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
    +        def max(v):
    +            return v.max()
    +        return max
    +
    +    @property
    +    def pandas_agg_min_udf(self):
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +
    +        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
    +        def min(v):
    +            return v.min()
    +        return min
    +
    +    @property
    +    def unbounded_window(self):
    +        return Window.partitionBy('id') \
    +            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    +
    +    @property
    +    def ordered_window(self):
    +        return Window.partitionBy('id').orderBy('v')
    +
    +    @property
    +    def unpartitioned_window(self):
    +        return Window.partitionBy()
    +
    +    def test_simple(self):
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max
    +
    +        df = self.data
    +        w = self.unbounded_window
    +
    +        mean_udf = self.pandas_agg_mean_udf
    +
    +        result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w))
    +        expected1 = df.withColumn('mean_v', mean(df['v']).over(w))
    +
    +        result2 = df.select(mean_udf(df['v']).over(w))
    +        expected2 = df.select(mean(df['v']).over(w))
    +
    +        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
    +        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
    +
    +    def test_multiple_udfs(self):
    +        from pyspark.sql.functions import max, min, mean
    +
    +        df = self.data
    +        w = self.unbounded_window
    +
    +        result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \
    +                    .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
    +                    .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) \
    --- End diff --
    
    Trailing `\`.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to