Github user gatorsmile commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20295#discussion_r214795846
  
    --- Diff: python/pyspark/sql/tests.py ---
    @@ -4588,6 +4613,80 @@ def test_timestamp_dst(self):
             result = df.groupby('time').apply(foo_udf).sort('time')
             self.assertPandasEqual(df.toPandas(), result.toPandas())
     
    +    def test_udf_with_key(self):
    +        from pyspark.sql.functions import pandas_udf, col, PandasUDFType
    +        df = self.data
    +        pdf = df.toPandas()
    +
    +        def foo1(key, pdf):
    +            import numpy as np
    +            assert type(key) == tuple
    +            assert type(key[0]) == np.int64
    +
    +            return pdf.assign(v1=key[0],
    +                              v2=pdf.v * key[0],
    +                              v3=pdf.v * pdf.id,
    +                              v4=pdf.v * pdf.id.mean())
    +
    +        def foo2(key, pdf):
    +            import numpy as np
    +            assert type(key) == tuple
    +            assert type(key[0]) == np.int64
    +            assert type(key[1]) == np.int32
    +
    +            return pdf.assign(v1=key[0],
    +                              v2=key[1],
    +                              v3=pdf.v * key[0],
    +                              v4=pdf.v + key[1])
    +
    +        def foo3(key, pdf):
    +            assert type(key) == tuple
    +            assert len(key) == 0
    +            return pdf.assign(v1=pdf.v * pdf.id)
    +
    +        # v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
    +        # v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
    +        udf1 = pandas_udf(
    +            foo1,
    +            'id long, v int, v1 long, v2 int, v3 long, v4 double',
    +            PandasUDFType.GROUPED_MAP)
    +
    +        udf2 = pandas_udf(
    +            foo2,
    +            'id long, v int, v1 long, v2 int, v3 int, v4 int',
    +            PandasUDFType.GROUPED_MAP)
    +
    +        udf3 = pandas_udf(
    +            foo3,
    +            'id long, v int, v1 long',
    +            PandasUDFType.GROUPED_MAP)
    +
    +        # Test groupby column
    +        result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
    +        expected1 = pdf.groupby('id')\
    +            .apply(lambda x: udf1.func((x.id.iloc[0],), x))\
    +            .sort_values(['id', 'v']).reset_index(drop=True)
    +        self.assertPandasEqual(expected1, result1)
    +
    +        # Test groupby expression
    +        result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
    +        expected2 = pdf.groupby(pdf.id % 2)\
    +            .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
    +            .sort_values(['id', 'v']).reset_index(drop=True)
    +        self.assertPandasEqual(expected2, result2)
    +
    +        # Test complex groupby
    +        result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
    --- End diff --
    
    Is there any negative test case for when the number of columns specified in the groupby differs from what the udf definition (foo2) expects?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to