This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 393c69a515b [SPARK-40559][PYTHON] Add applyInArrow to groupBy and 
cogroup
393c69a515b is described below

commit 393c69a515b4acb6ea0659c0a7f09ae487801c40
Author: Enrico Minack <git...@enrico.minack.dev>
AuthorDate: Mon Dec 4 10:21:26 2023 +0900

    [SPARK-40559][PYTHON] Add applyInArrow to groupBy and cogroup
    
    ### What changes were proposed in this pull request?
    Add `applyInArrow` method to PySpark `groupBy` and `groupBy.cogroup` to 
allow for user functions that work on Arrow. Similar to existing `mapInArrow`.
    
    ### Why are the changes needed?
    PySpark allows transforming a `DataFrame` via the Pandas and Arrow APIs:
    ```
    df.mapInArrow(map_arrow, schema="...")
    df.mapInPandas(map_pandas, schema="...")
    ```
    
    For `df.groupBy(...)` and `df.groupBy(...).cogroup(...)`, there is only a 
Pandas interface, no Arrow interface:
    ```
    df.groupBy("id").applyInPandas(apply_pandas, schema="...")
    ```
    
    Providing a pure Arrow interface allows user code to use **any** 
Arrow-based data framework, not only Pandas, e.g. Polars:
    ```
    def apply_polars(df: polars.DataFrame) -> polars.DataFrame:
      return df
    
    def apply_arrow(table: pyarrow.Table) -> pyarrow.Table:
      df = polars.from_arrow(table)
      return apply_polars(df).to_arrow()
    
    df.groupBy("id").applyInArrow(apply_arrow, schema="...")
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    This adds the method `applyInArrow` to PySpark `groupBy` and `groupBy.cogroup`.
    
    ### How was this patch tested?
    Tested with unit tests.
    
    Closes #38624 from EnricoMi/branch-pyspark-grouped-apply-in-arrow.
    
    Authored-by: Enrico Minack <git...@enrico.minack.dev>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/api/python/PythonRunner.scala |   4 +
 dev/sparktestsupport/modules.py                    |   2 +
 python/pyspark/errors/error_classes.py             |  15 ++
 python/pyspark/rdd.py                              |   4 +
 python/pyspark/sql/pandas/_typing/__init__.pyi     |  11 +
 python/pyspark/sql/pandas/functions.py             |  22 ++
 python/pyspark/sql/pandas/group_ops.py             | 230 +++++++++++++++-
 python/pyspark/sql/pandas/serializers.py           |  81 +++++-
 python/pyspark/sql/tests/arrow/__init__.py         |  16 ++
 .../sql/tests/arrow/test_arrow_cogrouped_map.py    | 300 +++++++++++++++++++++
 .../sql/tests/arrow/test_arrow_grouped_map.py      | 291 ++++++++++++++++++++
 python/pyspark/sql/udf.py                          |  40 +++
 python/pyspark/worker.py                           | 207 +++++++++++++-
 .../plans/logical/pythonLogicalOperators.scala     |  48 ++++
 .../spark/sql/RelationalGroupedDataset.scala       |  82 +++++-
 .../spark/sql/execution/SparkStrategies.scala      |   6 +
 .../python/FlatMapCoGroupsInArrowExec.scala        |  58 ++++
 .../python/FlatMapCoGroupsInPandasExec.scala       |  62 +----
 ...xec.scala => FlatMapCoGroupsInPythonExec.scala} |  43 +--
 .../python/FlatMapGroupsInArrowExec.scala          |  64 +++++
 .../python/FlatMapGroupsInPandasExec.scala         |  60 +----
 ...sExec.scala => FlatMapGroupsInPythonExec.scala} |  54 ++--
 22 files changed, 1513 insertions(+), 187 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index e6d5a750ea3..148f80540d9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -57,6 +57,8 @@ private[spark] object PythonEvalType {
   val SQL_COGROUPED_MAP_PANDAS_UDF = 206
   val SQL_MAP_ARROW_ITER_UDF = 207
   val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208
+  val SQL_GROUPED_MAP_ARROW_UDF = 209
+  val SQL_COGROUPED_MAP_ARROW_UDF = 210
 
   val SQL_TABLE_UDF = 300
   val SQL_ARROW_TABLE_UDF = 301
@@ -74,6 +76,8 @@ private[spark] object PythonEvalType {
     case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
     case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
     case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => 
"SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
+    case SQL_GROUPED_MAP_ARROW_UDF => "SQL_GROUPED_MAP_ARROW_UDF"
+    case SQL_COGROUPED_MAP_ARROW_UDF => "SQL_COGROUPED_MAP_ARROW_UDF"
     case SQL_TABLE_UDF => "SQL_TABLE_UDF"
     case SQL_ARROW_TABLE_UDF => "SQL_ARROW_TABLE_UDF"
   }
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index feb49062316..718a2509741 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -491,6 +491,8 @@ pyspark_sql = Module(
         "pyspark.sql.pandas.utils",
         "pyspark.sql.observation",
         # unittests
+        "pyspark.sql.tests.arrow.test_arrow_cogrouped_map",
+        "pyspark.sql.tests.arrow.test_arrow_grouped_map",
         "pyspark.sql.tests.test_arrow",
         "pyspark.sql.tests.test_arrow_python_udf",
         "pyspark.sql.tests.test_catalog",
diff --git a/python/pyspark/errors/error_classes.py 
b/python/pyspark/errors/error_classes.py
index e1a93aa6be1..c7199ac938b 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -297,6 +297,11 @@ ERROR_CLASSES_JSON = """
       "`<arg_name>` should be one the values from PandasUDFType, got 
<arg_type>"
     ]
   },
+  "INVALID_RETURN_TYPE_FOR_ARROW_UDF": {
+    "message": [
+      "Grouped and Cogrouped map Arrow UDF should return StructType for 
<eval_type>, got <return_type>."
+    ]
+  },
   "INVALID_RETURN_TYPE_FOR_PANDAS_UDF": {
     "message": [
       "Pandas UDF should return StructType for <eval_type>, got <return_type>."
@@ -713,6 +718,11 @@ ERROR_CLASSES_JSON = """
       "transformation. For more information, see SPARK-5063."
     ]
   },
+  "RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDF" : {
+    "message" : [
+      "Column names of the returned pyarrow.Table do not match specified 
schema.<missing><extra>"
+    ]
+  },
   "RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF" : {
     "message" : [
       "Column names of the returned pandas.DataFrame do not match specified 
schema.<missing><extra>"
@@ -728,6 +738,11 @@ ERROR_CLASSES_JSON = """
       "The length of output in Scalar iterator pandas UDF should be the same 
with the input's; however, the length of output was <output_length> and the 
length of input was <input_length>."
     ]
   },
+  "RESULT_TYPE_MISMATCH_FOR_ARROW_UDF" : {
+    "message" : [
+      "Columns do not match in their data type: <mismatch>."
+    ]
+  },
   "SCHEMA_MISMATCH_FOR_PANDAS_UDF" : {
     "message" : [
       "Result vector from pandas_udf was not the required length: expected 
<expected>, got <actual>."
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d2a8bc4b111..1066830b537 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -110,6 +110,8 @@ if TYPE_CHECKING:
         PandasCogroupedMapUDFType,
         ArrowMapIterUDFType,
         PandasGroupedMapUDFWithStateType,
+        ArrowGroupedMapUDFType,
+        ArrowCogroupedMapUDFType,
     )
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.types import AtomicType, StructType
@@ -158,6 +160,8 @@ class PythonEvalType:
     SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206
     SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207
     SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" 
= 208
+    SQL_GROUPED_MAP_ARROW_UDF: "ArrowGroupedMapUDFType" = 209
+    SQL_COGROUPED_MAP_ARROW_UDF: "ArrowCogroupedMapUDFType" = 210
 
     SQL_TABLE_UDF: "SQLTableUDFType" = 300
     SQL_ARROW_TABLE_UDF: "SQLArrowTableUDFType" = 301
diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi 
b/python/pyspark/sql/pandas/_typing/__init__.pyi
index 69279727ca9..0838f446279 100644
--- a/python/pyspark/sql/pandas/_typing/__init__.pyi
+++ b/python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -53,6 +53,8 @@ PandasMapIterUDFType = Literal[205]
 PandasCogroupedMapUDFType = Literal[206]
 ArrowMapIterUDFType = Literal[207]
 PandasGroupedMapUDFWithStateType = Literal[208]
+ArrowGroupedMapUDFType = Literal[209]
+ArrowCogroupedMapUDFType = Literal[210]
 
 class PandasVariadicScalarToScalarFunction(Protocol):
     def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: 
...
@@ -341,4 +343,13 @@ PandasCogroupedMapFunction = Union[
     Callable[[Any, DataFrameLike, DataFrameLike], DataFrameLike],
 ]
 
+ArrowGroupedMapFunction = Union[
+    Callable[[pyarrow.Table], pyarrow.Table],
+    Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table], pyarrow.Table],
+]
+ArrowCogroupedMapFunction = Union[
+    Callable[[pyarrow.Table, pyarrow.Table], pyarrow.Table],
+    Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table, pyarrow.Table], 
pyarrow.Table],
+]
+
 GroupedMapPandasUserDefinedFunction = 
NewType("GroupedMapPandasUserDefinedFunction", FunctionType)
diff --git a/python/pyspark/sql/pandas/functions.py 
b/python/pyspark/sql/pandas/functions.py
index 64969a05163..dc6fb5a8976 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -413,6 +413,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
         PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+        PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+        PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
         None,
     ]:  # None means it should infer the type from type hints.
         raise PySparkTypeError(
@@ -450,6 +452,8 @@ def _create_pandas_udf(f, returnType, evalType):
         PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+        PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+        PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
     ]:
         # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is 
being triggered
@@ -497,6 +501,15 @@ def _create_pandas_udf(f, returnType, evalType):
             },
         )
 
+    if evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF and 
len(argspec.args) not in (1, 2):
+        raise PySparkValueError(
+            error_class="INVALID_PANDAS_UDF",
+            message_parameters={
+                "detail": "the function in groupby.applyInArrow must take 
either one argument "
+                "(data) or two arguments (key, data).",
+            },
+        )
+
     if evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF and 
len(argspec.args) not in (2, 3):
         raise PySparkValueError(
             error_class="INVALID_PANDAS_UDF",
@@ -506,6 +519,15 @@ def _create_pandas_udf(f, returnType, evalType):
             },
         )
 
+    if evalType == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF and 
len(argspec.args) not in (2, 3):
+        raise PySparkValueError(
+            error_class="INVALID_PANDAS_UDF",
+            message_parameters={
+                "detail": "the function in cogroup.applyInArrow must take 
either two arguments "
+                "(left, right) or three arguments (key, left, right).",
+            },
+        )
+
     if is_remote():
         from pyspark.sql.connect.udf import _create_udf as _create_connect_udf
 
diff --git a/python/pyspark/sql/pandas/group_ops.py 
b/python/pyspark/sql/pandas/group_ops.py
index dfe37672c03..fa48d615ef8 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -31,6 +31,8 @@ if TYPE_CHECKING:
         PandasGroupedMapFunction,
         PandasGroupedMapFunctionWithState,
         PandasCogroupedMapFunction,
+        ArrowGroupedMapFunction,
+        ArrowCogroupedMapFunction,
     )
     from pyspark.sql.group import GroupedData
 
@@ -152,7 +154,7 @@ class PandasGroupedOpsMixin:
         Examples
         --------
         >>> import pandas as pd  # doctest: +SKIP
-        >>> from pyspark.sql.functions import pandas_udf, ceil
+        >>> from pyspark.sql.functions import ceil
         >>> df = spark.createDataFrame(
         ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
         ...     ("id", "v"))  # doctest: +SKIP
@@ -358,6 +360,133 @@ class PandasGroupedOpsMixin:
         )
         return DataFrame(jdf, self.session)
 
+    def applyInArrow(
+        self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
+    ) -> "DataFrame":
+        """
+        Maps each group of the current :class:`DataFrame` using an Arrow udf 
and returns the result
+        as a `DataFrame`.
+
+        The function should take a `pyarrow.Table` and return another
+        `pyarrow.Table`. Alternatively, the user can pass a function that takes
+        a tuple of `pyarrow.Scalar` grouping key(s) and a `pyarrow.Table`.
+        For each group, all columns are passed together as a `pyarrow.Table`
+        to the user-function and the returned `pyarrow.Table`s are combined as a
+        :class:`DataFrame`.
+
+        The `schema` should be a :class:`StructType` describing the schema of 
the returned
+        `pyarrow.Table`. The column labels of the returned `pyarrow.Table` 
must either match
+        the field names in the defined schema if specified as strings, or 
match the
+        field data types by position if not strings, e.g. integer indices.
+        The length of the returned `pyarrow.Table` can be arbitrary.
+
+        .. versionadded:: 4.0.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function that takes a `pyarrow.Table` and outputs a
+            `pyarrow.Table`, or that takes one tuple (grouping keys) and a
+            `pyarrow.Table` and outputs a `pyarrow.Table`.
+        schema : :class:`pyspark.sql.types.DataType` or str
+            the return type of the `func` in PySpark. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type 
string.
+
+        Examples
+        --------
+        >>> from pyspark.sql.functions import ceil
+        >>> import pyarrow  # doctest: +SKIP
+        >>> import pyarrow.compute as pc  # doctest: +SKIP
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+        ...     ("id", "v"))  # doctest: +SKIP
+        >>> def normalize(table):
+        ...     v = table.column("v")
+        ...     norm = pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, 
ddof=1))
+        ...     return table.set_column(1, "v", norm)
+        >>> df.groupby("id").applyInArrow(
+        ...     normalize, schema="id long, v double").show()  # doctest: +SKIP
+        +---+-------------------+
+        +---+-------------------+
+        | id|                  v|
+        +---+-------------------+
+        |  1|-0.7071067811865475|
+        |  1| 0.7071067811865475|
+        |  2|-0.8320502943378437|
+        |  2|-0.2773500981126146|
+        |  2| 1.1094003924504583|
+        +---+-------------------+
+
+        Alternatively, the user can pass a function that takes two arguments.
+        In this case, the grouping key(s) will be passed as the first argument 
and the data will
+        be passed as the second argument. The grouping key(s) will be passed 
as a tuple of Arrow
+        scalars types, e.g., `pyarrow.Int32Scalar` and `pyarrow.FloatScalar`. 
The data will still
+        be passed in as a `pyarrow.Table` containing all columns from the 
original Spark DataFrame.
+        This is useful when the user does not want to hardcode grouping key(s) 
in the function.
+
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+        ...     ("id", "v"))  # doctest: +SKIP
+        >>> def mean_func(key, table):
+        ...     # key is a tuple of one pyarrow.Int64Scalar, which is the value
+        ...     # of 'id' for the current group
+        ...     mean = pc.mean(table.column("v"))
+        ...     return pyarrow.Table.from_pydict({"id": [key[0].as_py()], "v": 
[mean.as_py()]})
+        >>> df.groupby('id').applyInArrow(
+        ...     mean_func, schema="id long, v double").show()  # doctest: +SKIP
+        +---+---+
+        | id|  v|
+        +---+---+
+        |  1|1.5|
+        |  2|6.0|
+        +---+---+
+
+        >>> def sum_func(key, table):
+        ...     # key is a tuple of two pyarrow.Int64Scalars, which are the 
values
+        ...     # of 'id' and 'ceil(df.v / 2)' for the current group
+        ...     sum = pc.sum(table.column("v"))
+        ...     return pyarrow.Table.from_pydict({
+        ...         "id": [key[0].as_py()],
+        ...         "ceil(v / 2)": [key[1].as_py()],
+        ...         "v": [sum.as_py()]
+        ...     })
+        >>> df.groupby(df.id, ceil(df.v / 2)).applyInArrow(
+        ...     sum_func, schema="id long, `ceil(v / 2)` long, v 
double").show()  # doctest: +SKIP
+        +---+-----------+----+
+        | id|ceil(v / 2)|   v|
+        +---+-----------+----+
+        |  2|          5|10.0|
+        |  1|          1| 3.0|
+        |  2|          3| 5.0|
+        |  2|          2| 3.0|
+        +---+-----------+----+
+
+        Notes
+        -----
+        This function requires a full shuffle. All the data of a group will be 
loaded
+        into memory, so the user should be aware of the potential OOM risk if 
data is skewed
+        and certain groups are too large to fit in memory.
+
+        This API is experimental.
+
+        See Also
+        --------
+        pyspark.sql.functions.pandas_udf
+        """
+        from pyspark.sql import GroupedData
+        from pyspark.sql.functions import pandas_udf
+
+        assert isinstance(self, GroupedData)
+
+        # The usage of the pandas_udf is internal so type checking is disabled.
+        udf = pandas_udf(
+            func, returnType=schema, 
functionType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+        )  # type: ignore[call-overload]
+        df = self._df
+        udf_column = udf(*[df[col] for col in df.columns])
+        jdf = self._jgd.flatMapGroupsInArrow(udf_column._jc.expr())
+        return DataFrame(jdf, self.session)
+
     def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
         """
         Cogroups this group with another group so that we can run cogrouped 
operations.
@@ -432,7 +561,6 @@ class PandasCogroupedOps:
 
         Examples
         --------
-        >>> from pyspark.sql.functions import pandas_udf
         >>> df1 = spark.createDataFrame(
         ...     [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), 
(20000102, 2, 4.0)],
         ...     ("time", "id", "v1"))
@@ -499,6 +627,104 @@ class PandasCogroupedOps:
         jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, 
udf_column._jc.expr())
         return DataFrame(jdf, self._gd1.session)
 
+    def applyInArrow(
+        self, func: "ArrowCogroupedMapFunction", schema: Union[StructType, str]
+    ) -> "DataFrame":
+        """
+        Applies a function to each cogroup using Arrow and returns the result
+        as a `DataFrame`.
+
+        The function should take two `pyarrow.Table`s and return another
+        `pyarrow.Table`. Alternatively, the user can pass a function that takes
+        a tuple of `pyarrow.Scalar` grouping key(s) and the two 
`pyarrow.Table`s.
+        For each side of the cogroup, all columns are passed together as a
+        `pyarrow.Table` to the user-function and the returned `pyarrow.Table` 
are combined as
+        a :class:`DataFrame`.
+
+        The `schema` should be a :class:`StructType` describing the schema of 
the returned
+        `pyarrow.Table`. The column labels of the returned 
`pyarrow.Table` must either match
+        the field names in the defined schema if specified as strings, or 
match the
+        field data types by position if not strings, e.g. integer indices.
+        The length of the returned `pyarrow.Table` can be arbitrary.
+
+        .. versionadded:: 4.0.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function that takes two `pyarrow.Table`s, and
+            outputs a `pyarrow.Table`, or that takes one tuple (grouping keys) 
and two
+            ``pyarrow.Table``s, and outputs a ``pyarrow.Table``.
+        schema : :class:`pyspark.sql.types.DataType` or str
+            the return type of the `func` in PySpark. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type 
string.
+
+        Examples
+        --------
+        >>> import pyarrow  # doctest: +SKIP
+        >>> df1 = spark.createDataFrame([(1, 1.0), (2, 2.0), (1, 3.0), (2, 
4.0)], ("id", "v1"))
+        >>> df2 = spark.createDataFrame([(1, "x"), (2, "y")], ("id", "v2"))
+        >>> def summarize(l, r):
+        ...     return pyarrow.Table.from_pydict({
+        ...         "left": [l.num_rows],
+        ...         "right": [r.num_rows]
+        ...     })
+        >>> df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow(
+        ...     summarize, schema="left long, right long"
+        ... ).show()  # doctest: +SKIP
+        +----+-----+
+        |left|right|
+        +----+-----+
+        |   2|    1|
+        |   2|    1|
+        +----+-----+
+
+        Alternatively, the user can define a function that takes three 
arguments.  In this case,
+        the grouping key(s) will be passed as the first argument and the data 
will be passed as the
+        second and third arguments.  The grouping key(s) will be passed as a 
tuple of Arrow scalars
+        types, e.g., `pyarrow.Int32Scalar` and `pyarrow.FloatScalar`. The data 
will still be passed
+        in as two `pyarrow.Table`s containing all columns from the original 
Spark DataFrames.
+
+        >>> def summarize(key, l, r):
+        ...     return pyarrow.Table.from_pydict({
+        ...         "key": [key[0].as_py()],
+        ...         "left": [l.num_rows],
+        ...         "right": [r.num_rows]
+        ...     })
+        >>> df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow(
+        ...     summarize, schema="key long, left long, right long"
+        ... ).show()  # doctest: +SKIP
+        +---+----+-----+
+        |key|left|right|
+        +---+----+-----+
+        |  1|   2|    1|
+        |  2|   2|    1|
+        +---+----+-----+
+
+        Notes
+        -----
+        This function requires a full shuffle. All the data of a cogroup will 
be loaded
+        into memory, so the user should be aware of the potential OOM risk if 
data is skewed
+        and certain groups are too large to fit in memory.
+
+        This API is experimental.
+
+        See Also
+        --------
+        pyspark.sql.functions.pandas_udf
+        """
+        from pyspark.sql.pandas.functions import pandas_udf
+
+        # The usage of the pandas_udf is internal so type checking is disabled.
+        udf = pandas_udf(
+            func, returnType=schema, 
functionType=PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF
+        )  # type: ignore[call-overload]
+
+        all_cols = self._extract_cols(self._gd1) + 
self._extract_cols(self._gd2)
+        udf_column = udf(*all_cols)
+        jdf = self._gd1._jgd.flatMapCoGroupsInArrow(self._gd2._jgd, 
udf_column._jc.expr())
+        return DataFrame(jdf, self._gd1.session)
+
     @staticmethod
     def _extract_cols(gd: "GroupedData") -> List[Column]:
         df = gd._df
diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index 4c1d1c177d6..8ffb7407714 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -162,6 +162,49 @@ class ArrowStreamUDFSerializer(ArrowStreamSerializer):
         return super(ArrowStreamUDFSerializer, 
self).dump_stream(wrap_and_init_stream(), stream)
 
 
+class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
+    """
+    Serializes pyarrow.RecordBatch data with Arrow streaming format.
+
+    Loads Arrow record batches as ``[[pyarrow.RecordBatch]]`` (one 
``[pyarrow.RecordBatch]`` per
+    group) and serializes ``[([pyarrow.RecordBatch], arrow_type)]``.
+
+    Parameters
+    ----------
+    assign_cols_by_name : bool
+        If True, then DataFrames will get columns by name
+    """
+
+    def __init__(self, assign_cols_by_name):
+        super(ArrowStreamGroupUDFSerializer, self).__init__()
+        self._assign_cols_by_name = assign_cols_by_name
+
+    def dump_stream(self, iterator, stream):
+        import pyarrow as pa
+
+        # flatten inner list [([pa.RecordBatch], arrow_type)] into 
[(pa.RecordBatch, arrow_type)]
+        # so strip off inner iterator induced by 
ArrowStreamUDFSerializer.load_stream
+        batch_iter = (
+            (batch, arrow_type)
+            for batches, arrow_type in iterator  # tuple constructed in 
wrap_grouped_map_arrow_udf
+            for batch in batches
+        )
+
+        if self._assign_cols_by_name:
+            batch_iter = (
+                (
+                    pa.RecordBatch.from_arrays(
+                        [batch.column(field.name) for field in arrow_type],
+                        names=[field.name for field in arrow_type],
+                    ),
+                    arrow_type,
+                )
+                for batch, arrow_type in batch_iter
+            )
+
+        super(ArrowStreamGroupUDFSerializer, self).dump_stream(batch_iter, 
stream)
+
+
 class ArrowStreamPandasSerializer(ArrowStreamSerializer):
     """
     Serializes pandas.Series as Arrow data with Arrow streaming format.
@@ -633,7 +676,43 @@ class 
ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
         return "ArrowStreamPandasUDTFSerializer"
 
 
-class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):
+class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
+    """
+    Serializes pyarrow.RecordBatch data with Arrow streaming format.
+
+    Loads Arrow record batches as `[([pa.RecordBatch], [pa.RecordBatch])]` 
(one tuple per group)
+    and serializes `[([pa.RecordBatch], arrow_type)]`.
+
+    Parameters
+    ----------
+    assign_cols_by_name : bool
+        If True, then DataFrames will get columns by name
+    """
+
+    def __init__(self, assign_cols_by_name):
+        super(CogroupArrowUDFSerializer, self).__init__(assign_cols_by_name)
+
+    def load_stream(self, stream):
+        """
+        Deserialize Cogrouped ArrowRecordBatches and yield as two 
`pyarrow.RecordBatch`es.
+        """
+        dataframes_in_group = None
+
+        while dataframes_in_group is None or dataframes_in_group > 0:
+            dataframes_in_group = read_int(stream)
+
+            if dataframes_in_group == 2:
+                batches1 = [batch for batch in 
ArrowStreamSerializer.load_stream(self, stream)]
+                batches2 = [batch for batch in 
ArrowStreamSerializer.load_stream(self, stream)]
+                yield batches1, batches2
+
+            elif dataframes_in_group != 0:
+                raise ValueError(
+                    "Invalid number of dataframes in group 
{0}".format(dataframes_in_group)
+                )
+
+
+class CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
     def load_stream(self, stream):
         """
         Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables 
and yield as two
diff --git a/python/pyspark/sql/tests/arrow/__init__.py 
b/python/pyspark/sql/tests/arrow/__init__.py
new file mode 100644
index 00000000000..cce3acad34a
--- /dev/null
+++ b/python/pyspark/sql/tests/arrow/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py 
b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
new file mode 100644
index 00000000000..0206d4c2c6d
--- /dev/null
+++ b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
@@ -0,0 +1,300 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import time
+import unittest
+
+from pyspark.errors import PythonException
+from pyspark.sql import Row
+from pyspark.sql.functions import col
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+    have_pyarrow,
+    pyarrow_requirement_message,
+)
+from pyspark.testing.utils import QuietTest
+
+
+if have_pyarrow:
+    import pyarrow as pa
+    import pyarrow.compute as pc
+
+
+@unittest.skipIf(
+    not have_pyarrow,
+    pyarrow_requirement_message,  # type: ignore[arg-type]
+)
+class CogroupedMapInArrowTests(ReusedSQLTestCase):
+    """Tests for ``df.groupBy(...).cogroup(other.groupBy(...)).applyInArrow(...)``."""
+
+    @property
+    def left(self):
+        # ids 0, 2, 4, 6, 8 in three partitions, v = id * 10
+        return self.spark.range(0, 10, 2, 3).withColumn("v", col("id") * 10)
+
+    @property
+    def right(self):
+        # ids 0, 3, 6, 9 in three partitions, v = id * 10
+        return self.spark.range(0, 10, 3, 3).withColumn("v", col("id") * 10)
+
+    @property
+    def cogrouped(self):
+        # both sides are grouped by int(id / 4), i.e. into the keys 0, 1 and 2
+        grouped_left_df = self.left.groupBy((col("id") / 4).cast("int"))
+        grouped_right_df = self.right.groupBy((col("id") / 4).cast("int"))
+        return grouped_left_df.cogroup(grouped_right_df)
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedSQLTestCase.setUpClass()
+
+        # Synchronize default timezone between Python and Java
+        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
+        tz = "America/Los_Angeles"
+        os.environ["TZ"] = tz
+        time.tzset()
+
+        cls.sc.environment["TZ"] = tz
+        cls.spark.conf.set("spark.sql.session.timeZone", tz)
+
+    @classmethod
+    def tearDownClass(cls):
+        # restore the TZ environment variable that was in effect before setUpClass
+        del os.environ["TZ"]
+        if cls.tz_prev is not None:
+            os.environ["TZ"] = cls.tz_prev
+        time.tzset()
+        ReusedSQLTestCase.tearDownClass()
+
+    @staticmethod
+    def apply_in_arrow_func(left, right):
+        """Summarize the ids of both sides as rows of min / max / len / sum metrics."""
+        assert isinstance(left, pa.Table)
+        assert isinstance(right, pa.Table)
+        assert left.schema.names == ["id", "v"]
+        assert right.schema.names == ["id", "v"]
+
+        left_ids = left.to_pydict()["id"]
+        right_ids = right.to_pydict()["id"]
+        result = {
+            "metric": ["min", "max", "len", "sum"],
+            "left": [min(left_ids), max(left_ids), len(left_ids), sum(left_ids)],
+            "right": [min(right_ids), max(right_ids), len(right_ids), sum(right_ids)],
+        }
+        return pa.Table.from_pydict(result)
+
+    @staticmethod
+    def apply_in_arrow_with_key_func(key_column):
+        """
+        Return a ``(key, left, right)`` function that checks ``key`` against the values
+        of ``key_column`` (skipped when ``key_column`` is None) and then delegates to
+        ``apply_in_arrow_func``.
+        """
+
+        def func(key, left, right):
+            assert isinstance(key, tuple)
+            assert all(isinstance(scalar, pa.Scalar) for scalar in key)
+            if key_column:
+                assert all(
+                    (pc.divide(k, pa.scalar(4)).cast(pa.int32()),) == key
+                    for table in [left, right]
+                    for k in table.column(key_column)
+                )
+            return CogroupedMapInArrowTests.apply_in_arrow_func(left, right)
+
+        return func
+
+    @staticmethod
+    def apply_in_pandas_with_key_func(key_column):
+        """
+        Pandas counterpart of ``apply_in_arrow_with_key_func``, used to compute the
+        expected result via ``applyInPandas``.
+        """
+
+        def func(key, left, right):
+            return CogroupedMapInArrowTests.apply_in_arrow_with_key_func(key_column)(
+                tuple(pa.scalar(k) for k in key),
+                pa.Table.from_pandas(left),
+                pa.Table.from_pandas(right),
+            ).to_pandas()
+
+        return func
+
+    def do_test_apply_in_arrow(self, cogrouped_df, key_column="id"):
+        schema = "metric string, left long, right long"
+
+        # compare with result of applyInPandas
+        expected = cogrouped_df.applyInPandas(
+            CogroupedMapInArrowTests.apply_in_pandas_with_key_func(key_column), schema
+        )
+
+        # apply in arrow without key
+        actual = cogrouped_df.applyInArrow(
+            CogroupedMapInArrowTests.apply_in_arrow_func, schema
+        ).collect()
+        self.assertEqual(actual, expected.collect())
+
+        # apply in arrow with key
+        actual2 = cogrouped_df.applyInArrow(
+            CogroupedMapInArrowTests.apply_in_arrow_with_key_func(key_column), schema
+        ).collect()
+        self.assertEqual(actual2, expected.collect())
+
+    def test_apply_in_arrow(self):
+        self.do_test_apply_in_arrow(self.cogrouped)
+
+    def test_apply_in_arrow_empty_groupby(self):
+        grouped_left_df = self.left.groupBy()
+        grouped_right_df = self.right.groupBy()
+        cogrouped_df = grouped_left_df.cogroup(grouped_right_df)
+        self.do_test_apply_in_arrow(cogrouped_df, key_column=None)
+
+    def test_apply_in_arrow_not_returning_arrow_table(self):
+        def func(key, left, right):
+            return key
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(
+                PythonException,
+                "Return type of the user-defined function should be pyarrow.Table, but is tuple",
+            ):
+                self.cogrouped.applyInArrow(func, schema="id long").collect()
+
+    def test_apply_in_arrow_returning_wrong_types(self):
+        for schema, expected in [
+            ("id integer, v long", "column 'id' \\(expected int32, actual int64\\)"),
+            (
+                "id integer, v integer",
+                "column 'id' \\(expected int32, actual int64\\), "
+                "column 'v' \\(expected int32, actual int64\\)",
+            ),
+            ("id long, v integer", "column 'v' \\(expected int32, actual int64\\)"),
+            ("id long, v string", "column 'v' \\(expected string, actual int64\\)"),
+        ]:
+            with self.subTest(schema=schema):
+                with QuietTest(self.sc):
+                    with self.assertRaisesRegex(
+                        PythonException,
+                        f"Columns do not match in their data type: {expected}",
+                    ):
+                        self.cogrouped.applyInArrow(
+                            lambda left, right: left, schema=schema
+                        ).collect()
+
+    def test_apply_in_arrow_returning_wrong_types_positional_assignment(self):
+        for schema, expected in [
+            ("a integer, b long", "column 'a' \\(expected int32, actual int64\\)"),
+            (
+                "a integer, b integer",
+                "column 'a' \\(expected int32, actual int64\\), "
+                "column 'b' \\(expected int32, actual int64\\)",
+            ),
+            ("a long, b int", "column 'b' \\(expected int32, actual int64\\)"),
+            ("a long, b string", "column 'b' \\(expected string, actual int64\\)"),
+        ]:
+            with self.subTest(schema=schema):
+                with self.sql_conf(
+                    {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}
+                ):
+                    with QuietTest(self.sc):
+                        with self.assertRaisesRegex(
+                            PythonException,
+                            f"Columns do not match in their data type: {expected}",
+                        ):
+                            self.cogrouped.applyInArrow(
+                                lambda left, right: left, schema=schema
+                            ).collect()
+
+    def test_apply_in_arrow_returning_wrong_column_names(self):
+        def stats(key, left, right):
+            # returning three columns
+            return pa.Table.from_pydict(
+                {
+                    "id": [key[0].as_py()],
+                    "v": [pc.mean(left.column("v")).as_py()],
+                    "v2": [pc.stddev(right.column("v")).as_py()],
+                }
+            )
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(
+                PythonException,
+                "Column names of the returned pyarrow.Table do not match specified schema. "
+                "Missing: m. Unexpected: v, v2.\n",
+            ):
+                # stats returns three columns while here we set schema with two columns
+                self.cogrouped.applyInArrow(stats, schema="id long, m double").collect()
+
+    def test_apply_in_arrow_returning_empty_dataframe(self):
+        def means_unless_key_zero(key, left, right):
+            # key 0 yields an empty table without any columns, which is accepted
+            if key[0].as_py() == 0:
+                return pa.table([])
+            else:
+                return pa.Table.from_pydict(
+                    {
+                        "id": [key[0].as_py()],
+                        "m": [pc.mean(left.column("v")).as_py()],
+                        "n": [pc.mean(right.column("v")).as_py()],
+                    }
+                )
+
+        schema = "id long, m double, n double"
+        actual = (
+            self.cogrouped.applyInArrow(means_unless_key_zero, schema=schema)
+            .sort("id")
+            .collect()
+        )
+        expected = [Row(id=1, m=50.0, n=60.0), Row(id=2, m=80.0, n=90.0)]
+        self.assertEqual(expected, actual)
+
+    def test_apply_in_arrow_returning_empty_dataframe_and_wrong_column_names(self):
+        def odd_means(key, left, _):
+            if key[0].as_py() % 2 == 0:
+                return pa.table([[]], names=["id"])
+            else:
+                return pa.Table.from_pydict(
+                    {"id": [key[0].as_py()], "m": [pc.mean(left.column("v")).as_py()]}
+                )
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(
+                PythonException,
+                "Column names of the returned pyarrow.Table do not match specified schema. "
+                "Missing: m.\n",
+            ):
+                # odd_means returns only an 'id' column for even keys while here we set
+                # schema with two columns; empty tables with columns are still verified
+                self.cogrouped.applyInArrow(odd_means, schema="id long, m double").collect()
+
+    def test_apply_in_arrow_column_order(self):
+        df = self.left
+        expected = df.select(df.id, (df.v * 3).alias("u"), df.v).sort("id", "v").collect()
+
+        # Function returns a table with required column names but different order
+        def change_col_order(left, _):
+            return left.append_column("u", pc.multiply(left.column("v"), 3))
+
+        # The result should assign columns by name from the table
+        result = (
+            self.cogrouped.applyInArrow(change_col_order, "id long, u long, v long")
+            .sort("id", "v")
+            .select("id", "u", "v")
+            .collect()
+        )
+        self.assertEqual(expected, result)
+
+    def test_positional_assignment_conf(self):
+        with self.sql_conf(
+            {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}
+        ):
+
+            def foo(left, right):
+                # column names differ from the schema; they are assigned by position
+                return pa.Table.from_pydict({"x": ["hi"], "y": [1]})
+
+            result = self.cogrouped.applyInArrow(foo, "a string, b long").select("a", "b").collect()
+            for r in result:
+                self.assertEqual(r.a, "hi")
+                self.assertEqual(r.b, 1)
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.arrow.test_arrow_cogrouped_map import *  # noqa: F401
+
+    # Write JUnit-style XML reports when xmlrunner is installed; otherwise unittest
+    # falls back to its default text runner (testRunner=None).
+    testRunner = None
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        pass
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py 
b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
new file mode 100644
index 00000000000..fa43648d42d
--- /dev/null
+++ b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
@@ -0,0 +1,291 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import time
+import unittest
+
+from pyspark.errors import PythonException
+from pyspark.sql import Row
+from pyspark.sql.functions import array, col, explode, lit, mean, stddev
+from pyspark.sql.window import Window
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+    have_pyarrow,
+    pyarrow_requirement_message,
+)
+from pyspark.testing.utils import QuietTest
+
+
+if have_pyarrow:
+    import pyarrow as pa
+    import pyarrow.compute as pc
+
+
+@unittest.skipIf(
+    not have_pyarrow,
+    pyarrow_requirement_message,  # type: ignore[arg-type]
+)
+class GroupedMapInArrowTests(ReusedSQLTestCase):
+    """Tests for ``df.groupBy(...).applyInArrow(...)``."""
+
+    @property
+    def data(self):
+        # ids 0..9, each exploded into ten rows with v = 20..29
+        return (
+            self.spark.range(10)
+            .toDF("id")
+            .withColumn("vs", array([lit(i) for i in range(20, 30)]))
+            .withColumn("v", explode(col("vs")))
+            .drop("vs")
+        )
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedSQLTestCase.setUpClass()
+
+        # Synchronize default timezone between Python and Java
+        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
+        tz = "America/Los_Angeles"
+        os.environ["TZ"] = tz
+        time.tzset()
+
+        cls.sc.environment["TZ"] = tz
+        cls.spark.conf.set("spark.sql.session.timeZone", tz)
+
+    @classmethod
+    def tearDownClass(cls):
+        # restore the TZ environment variable that was in effect before setUpClass
+        del os.environ["TZ"]
+        if cls.tz_prev is not None:
+            os.environ["TZ"] = cls.tz_prev
+        time.tzset()
+        ReusedSQLTestCase.tearDownClass()
+
+    def test_apply_in_arrow(self):
+        def func(group):
+            assert isinstance(group, pa.Table)
+            assert group.schema.names == ["id", "value"]
+            return group
+
+        df = self.spark.range(10).withColumn("value", col("id") * 10)
+        grouped_df = df.groupBy((col("id") / 4).cast("int"))
+        expected = df.collect()
+
+        # sort explicitly instead of relying on the output order of the grouped operator
+        actual = grouped_df.applyInArrow(func, "id long, value long").sort("id").collect()
+        self.assertEqual(actual, expected)
+
+    def test_apply_in_arrow_with_key(self):
+        def func(key, group):
+            assert isinstance(key, tuple)
+            assert all(isinstance(scalar, pa.Scalar) for scalar in key)
+            assert isinstance(group, pa.Table)
+            assert group.schema.names == ["id", "value"]
+            assert all(
+                (pc.divide(k, pa.scalar(4)).cast(pa.int32()),) == key for k in group.column("id")
+            )
+            return group
+
+        df = self.spark.range(10).withColumn("value", col("id") * 10)
+        grouped_df = df.groupBy((col("id") / 4).cast("int"))
+        expected = df.collect()
+
+        # sort explicitly instead of relying on the output order of the grouped operator
+        actual2 = grouped_df.applyInArrow(func, "id long, value long").sort("id").collect()
+        self.assertEqual(actual2, expected)
+
+    def test_apply_in_arrow_empty_groupby(self):
+        df = self.data
+
+        def normalize(table):
+            # standardize v using the sample standard deviation (ddof=1), like Spark's stddev
+            v = table.column("v")
+            return table.set_column(
+                1, "v", pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1))
+            )
+
+        # casting doubles to floats to get rid of numerical precision issues
+        # when comparing Arrow and Spark values
+        actual = (
+            df.groupby()
+            .applyInArrow(normalize, "id long, v double")
+            .withColumn("v", col("v").cast("float"))
+            .sort("id", "v")
+        )
+        windowSpec = Window.partitionBy()
+        # both sides are sorted the same way so the comparison is order-independent
+        expected = df.withColumn(
+            "v",
+            ((df.v - mean(df.v).over(windowSpec)) / stddev(df.v).over(windowSpec)).cast("float"),
+        ).sort("id", "v")
+        self.assertEqual(actual.collect(), expected.collect())
+
+    def test_apply_in_arrow_not_returning_arrow_table(self):
+        df = self.data
+
+        def stats(key, _):
+            return key
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(
+                PythonException,
+                "Return type of the user-defined function should be pyarrow.Table, but is tuple",
+            ):
+                df.groupby("id").applyInArrow(stats, schema="id long, m double").collect()
+
+    def test_apply_in_arrow_returning_wrong_types(self):
+        df = self.data
+
+        for schema, expected in [
+            ("id integer, v integer", "column 'id' \\(expected int32, actual int64\\)"),
+            (
+                "id integer, v long",
+                "column 'id' \\(expected int32, actual int64\\), "
+                "column 'v' \\(expected int64, actual int32\\)",
+            ),
+            ("id long, v long", "column 'v' \\(expected int64, actual int32\\)"),
+            ("id long, v string", "column 'v' \\(expected string, actual int32\\)"),
+        ]:
+            with self.subTest(schema=schema):
+                with QuietTest(self.sc):
+                    with self.assertRaisesRegex(
+                        PythonException,
+                        f"Columns do not match in their data type: {expected}",
+                    ):
+                        df.groupby("id").applyInArrow(lambda table: table, schema=schema).collect()
+
+    def test_apply_in_arrow_returning_wrong_types_positional_assignment(self):
+        df = self.data
+
+        for schema, expected in [
+            ("a integer, b integer", "column 'a' \\(expected int32, actual int64\\)"),
+            (
+                "a integer, b long",
+                "column 'a' \\(expected int32, actual int64\\), "
+                "column 'b' \\(expected int64, actual int32\\)",
+            ),
+            ("a long, b long", "column 'b' \\(expected int64, actual int32\\)"),
+            ("a long, b string", "column 'b' \\(expected string, actual int32\\)"),
+        ]:
+            with self.subTest(schema=schema):
+                with self.sql_conf(
+                    {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}
+                ):
+                    with QuietTest(self.sc):
+                        with self.assertRaisesRegex(
+                            PythonException,
+                            f"Columns do not match in their data type: {expected}",
+                        ):
+                            df.groupby("id").applyInArrow(
+                                lambda table: table, schema=schema
+                            ).collect()
+
+    def test_apply_in_arrow_returning_wrong_column_names(self):
+        df = self.data
+
+        def stats(key, table):
+            # returning three columns
+            return pa.Table.from_pydict(
+                {
+                    "id": [key[0].as_py()],
+                    "v": [pc.mean(table.column("v")).as_py()],
+                    "v2": [pc.stddev(table.column("v")).as_py()],
+                }
+            )
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(
+                PythonException,
+                "Column names of the returned pyarrow.Table do not match specified schema. "
+                "Missing: m. Unexpected: v, v2.\n",
+            ):
+                # stats returns three columns while here we set schema with two columns
+                df.groupby("id").applyInArrow(stats, schema="id long, m double").collect()
+
+    def test_apply_in_arrow_returning_empty_dataframe(self):
+        df = self.data
+
+        def odd_means(key, table):
+            # even keys yield an empty table without any columns, which is accepted
+            if key[0].as_py() % 2 == 0:
+                return pa.table([])
+            else:
+                return pa.Table.from_pydict(
+                    {"id": [key[0].as_py()], "m": [pc.mean(table.column("v")).as_py()]}
+                )
+
+        schema = "id long, m double"
+        actual = df.groupby("id").applyInArrow(odd_means, schema=schema).sort("id").collect()
+        expected = [Row(id=id, m=24.5) for id in range(1, 10, 2)]
+        self.assertEqual(expected, actual)
+
+    def test_apply_in_arrow_returning_empty_dataframe_and_wrong_column_names(self):
+        df = self.data
+
+        def odd_means(key, table):
+            if key[0].as_py() % 2 == 0:
+                return pa.table([[]], names=["id"])
+            else:
+                return pa.Table.from_pydict(
+                    {"id": [key[0].as_py()], "m": [pc.mean(table.column("v")).as_py()]}
+                )
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(
+                PythonException,
+                "Column names of the returned pyarrow.Table do not match specified schema. "
+                "Missing: m.\n",
+            ):
+                # odd_means returns only an 'id' column for even keys while here we set
+                # schema with two columns; empty tables with columns are still verified
+                df.groupby("id").applyInArrow(odd_means, schema="id long, m double").collect()
+
+    def test_apply_in_arrow_column_order(self):
+        df = self.data
+        grouped_df = df.groupby("id")
+        expected = df.select(df.id, (df.v * 3).alias("u"), df.v).sort("id", "v").collect()
+
+        # Function returns a table with required column names but different order
+        def change_col_order(table):
+            return table.append_column("u", pc.multiply(table.column("v"), 3))
+
+        # The result should assign columns by name from the table
+        result = (
+            grouped_df.applyInArrow(change_col_order, "id long, u long, v int")
+            .sort("id", "v")
+            .select("id", "u", "v")
+            .collect()
+        )
+        self.assertEqual(expected, result)
+
+    def test_positional_assignment_conf(self):
+        with self.sql_conf(
+            {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}
+        ):
+
+            def foo(_):
+                # column names differ from the schema; they are assigned by position
+                return pa.Table.from_pydict({"x": ["hi"], "y": [1]})
+
+            df = self.data
+            result = (
+                df.groupBy("id").applyInArrow(foo, "a string, b long").select("a", "b").collect()
+            )
+            for r in result:
+                self.assertEqual(r.a, "hi")
+                self.assertEqual(r.b, 1)
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.arrow.test_arrow_grouped_map import *  # noqa: F401
+
+    # Write JUnit-style XML reports when xmlrunner is installed; otherwise unittest
+    # falls back to its default text runner (testRunner=None).
+    testRunner = None
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        pass
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 029293ab70f..9ffdbb21871 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -275,6 +275,26 @@ class UserDefinedFunction:
                         "return_type": str(self._returnType_placeholder),
                     },
                 )
+        elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
+            if isinstance(self._returnType_placeholder, StructType):
+                try:
+                    to_arrow_type(self._returnType_placeholder)
+                except TypeError:
+                    raise PySparkNotImplementedError(
+                        error_class="NOT_IMPLEMENTED",
+                        message_parameters={
+                            "feature": "Invalid return type with grouped map 
Arrow UDFs or "
+                            f"at groupby.applyInArrow: 
{self._returnType_placeholder}"
+                        },
+                    )
+            else:
+                raise PySparkTypeError(
+                    error_class="INVALID_RETURN_TYPE_FOR_ARROW_UDF",
+                    message_parameters={
+                        "eval_type": "SQL_GROUPED_MAP_ARROW_UDF",
+                        "return_type": str(self._returnType_placeholder),
+                    },
+                )
         elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
             if isinstance(self._returnType_placeholder, StructType):
                 try:
@@ -295,6 +315,26 @@ class UserDefinedFunction:
                         "return_type": str(self._returnType_placeholder),
                     },
                 )
+        elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
+            if isinstance(self._returnType_placeholder, StructType):
+                try:
+                    to_arrow_type(self._returnType_placeholder)
+                except TypeError:
+                    raise PySparkNotImplementedError(
+                        error_class="NOT_IMPLEMENTED",
+                        message_parameters={
+                            "feature": "Invalid return type in 
cogroup.applyInArrow: "
+                            f"{self._returnType_placeholder}"
+                        },
+                    )
+            else:
+                raise PySparkTypeError(
+                    error_class="INVALID_RETURN_TYPE_FOR_ARROW_UDF",
+                    message_parameters={
+                        "eval_type": "SQL_COGROUPED_MAP_ARROW_UDF",
+                        "return_type": str(self._returnType_placeholder),
+                    },
+                )
         elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
             try:
                 # StructType is not yet allowed as a return type, explicitly 
check here to fail fast
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 060594292ad..2534238b43c 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -47,8 +47,10 @@ from pyspark.sql.functions import 
SkipRestOfInputTableException
 from pyspark.sql.pandas.serializers import (
     ArrowStreamPandasUDFSerializer,
     ArrowStreamPandasUDTFSerializer,
-    CogroupUDFSerializer,
+    CogroupArrowUDFSerializer,
+    CogroupPandasUDFSerializer,
     ArrowStreamUDFSerializer,
+    ArrowStreamGroupUDFSerializer,
     ApplyInPandasWithStateSerializer,
 )
 from pyspark.sql.pandas.types import to_arrow_type
@@ -307,6 +309,33 @@ def wrap_arrow_batch_iter_udf(f, return_type):
     )
 
 
+def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf):
+    """
+    Wrap the user function of ``cogroup.applyInArrow`` for execution in the worker.
+
+    Parameters
+    ----------
+    f
+        User function taking ``(left, right)`` or ``(key, left, right)``, where
+        ``left`` and ``right`` are ``pyarrow.Table`` objects.
+    return_type : StructType
+        Declared result schema of the function.
+    argspec
+        Full argspec of the original (unwrapped) user function; its number of
+        positional arguments decides whether the key is passed.
+    runner_conf
+        Runner configuration, used to decide whether result columns are assigned
+        by name or by position.
+
+    Returns
+    -------
+    A function of ``(left_keys, left_values, right_keys, right_values)`` that returns
+    a tuple of the result record batches and the Arrow type of ``return_type``.
+    """
+    _assign_cols_by_name = assign_cols_by_name(runner_conf)
+
+    # expected columns are looked up by name (dict) or by position (list)
+    if _assign_cols_by_name:
+        expected_cols_and_types = {
+            col.name: to_arrow_type(col.dataType) for col in return_type.fields
+        }
+    else:
+        expected_cols_and_types = [
+            (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
+        ]
+
+    # the Arrow result type does not depend on the data, so compute it only once
+    arrow_return_type = to_arrow_type(return_type)
+
+    def wrapped(left_key_table, left_value_table, right_key_table, right_value_table):
+        num_args = len(argspec.args)
+        if num_args == 2:
+            result = f(left_value_table, right_value_table)
+        elif num_args == 3:
+            # one side may have no rows for this key, so take the key from the other side
+            key_table = left_key_table if left_key_table.num_rows > 0 else right_key_table
+            key = tuple(c[0] for c in key_table.columns)
+            result = f(key, left_value_table, right_value_table)
+        else:
+            # without this branch `result` would be unbound (UnboundLocalError)
+            raise ValueError(
+                "The function in cogroup.applyInArrow must take either two arguments "
+                "(left, right) or three arguments (key, left, right), but takes "
+                f"{num_args}"
+            )
+
+        verify_arrow_result(result, _assign_cols_by_name, expected_cols_and_types)
+
+        return result.to_batches()
+
+    return lambda kl, vl, kr, vr: (wrapped(kl, vl, kr, vr), arrow_return_type)
+
+
 def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
     _assign_cols_by_name = assign_cols_by_name(runner_conf)
 
@@ -331,6 +360,104 @@ def wrap_cogrouped_map_pandas_udf(f, return_type, 
argspec, runner_conf):
     return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), 
to_arrow_type(return_type))]
 
 
+def verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types):
+    """
+    Verify that the result of an Arrow grouped / cogrouped map UDF matches its schema.
+
+    Parameters
+    ----------
+    table
+        The value returned by the user-defined function.
+    assign_cols_by_name : bool
+        Whether result columns are matched to the schema by name (True) or by
+        position (False).
+    expected_cols_and_types : dict or list
+        ``{name: pyarrow.DataType}`` when assigning by name, otherwise a list of
+        ``(name, pyarrow.DataType)`` tuples in schema order.
+
+    Raises
+    ------
+    PySparkTypeError
+        If ``table`` is not a ``pyarrow.Table``.
+    PySparkRuntimeError
+        If the columns of ``table`` are missing, unexpected, or of a different type.
+    """
+    import pyarrow as pa
+
+    if not isinstance(table, pa.Table):
+        raise PySparkTypeError(
+            error_class="UDF_RETURN_TYPE",
+            message_parameters={
+                "expected": "pyarrow.Table",
+                "actual": type(table).__name__,
+            },
+        )
+
+    def columns_mismatch(missing, extra):
+        # builds the error for missing and/or unexpected result columns
+        return PySparkRuntimeError(
+            error_class="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDF",
+            message_parameters={
+                "missing": f" Missing: {', '.join(missing)}." if missing else "",
+                "extra": f" Unexpected: {', '.join(extra)}." if extra else "",
+            },
+        )
+
+    # the types of the fields have to be identical to return type
+    # an empty table can have no columns; if there are columns, they have to match
+    if table.num_columns != 0 or table.num_rows != 0:
+        actual_cols_and_types = list(zip(table.schema.names, table.schema.types))
+
+        # columns are either mapped by name or position
+        if assign_cols_by_name:
+            actual_types_by_name = dict(actual_cols_and_types)
+            missing = sorted(set(expected_cols_and_types).difference(actual_types_by_name))
+            extra = sorted(set(actual_types_by_name).difference(expected_cols_and_types))
+            if missing or extra:
+                raise columns_mismatch(missing, extra)
+
+            column_types = [
+                (name, expected_cols_and_types[name], actual_types_by_name[name])
+                for name in sorted(expected_cols_and_types)
+            ]
+        else:
+            # zip() would silently drop surplus columns on either side, so the number
+            # of columns has to be checked before the types are compared
+            if len(actual_cols_and_types) != len(expected_cols_and_types):
+                num_common = min(len(actual_cols_and_types), len(expected_cols_and_types))
+                raise columns_mismatch(
+                    [name for name, _ in expected_cols_and_types[num_common:]],
+                    [name for name, _ in actual_cols_and_types[num_common:]],
+                )
+
+            column_types = [
+                (expected_name, expected_type, actual_type)
+                for (expected_name, expected_type), (_, actual_type) in zip(
+                    expected_cols_and_types, actual_cols_and_types
+                )
+            ]
+
+        type_mismatch = [
+            (name, expected, actual)
+            for name, expected, actual in column_types
+            if actual != expected
+        ]
+
+        if type_mismatch:
+            raise PySparkRuntimeError(
+                error_class="RESULT_TYPE_MISMATCH_FOR_ARROW_UDF",
+                message_parameters={
+                    "mismatch": ", ".join(
+                        "column '{}' (expected {}, actual {})".format(name, expected, actual)
+                        for name, expected, actual in type_mismatch
+                    )
+                },
+            )
+
+
+def wrap_grouped_map_arrow_udf(f, return_type, argspec, runner_conf):
+    """
+    Wrap the user function of ``groupby.applyInArrow`` for execution in the worker.
+
+    Parameters
+    ----------
+    f
+        User function taking ``(table)`` or ``(key, table)``, where ``table`` is a
+        ``pyarrow.Table``.
+    return_type : StructType
+        Declared result schema of the function.
+    argspec
+        Full argspec of the original (unwrapped) user function; its number of
+        positional arguments decides whether the key is passed.
+    runner_conf
+        Runner configuration, used to decide whether result columns are assigned
+        by name or by position.
+
+    Returns
+    -------
+    A function of ``(key_table, value_table)`` that returns a tuple of the result
+    record batches and the Arrow type of ``return_type``.
+    """
+    _assign_cols_by_name = assign_cols_by_name(runner_conf)
+
+    # expected columns are looked up by name (dict) or by position (list)
+    if _assign_cols_by_name:
+        expected_cols_and_types = {
+            col.name: to_arrow_type(col.dataType) for col in return_type.fields
+        }
+    else:
+        expected_cols_and_types = [
+            (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
+        ]
+
+    # the Arrow result type does not depend on the data, so compute it only once
+    arrow_return_type = to_arrow_type(return_type)
+
+    def wrapped(key_table, value_table):
+        num_args = len(argspec.args)
+        if num_args == 1:
+            result = f(value_table)
+        elif num_args == 2:
+            # the key is the first value of every key column
+            key = tuple(c[0] for c in key_table.columns)
+            result = f(key, value_table)
+        else:
+            # without this branch `result` would be unbound (UnboundLocalError)
+            raise ValueError(
+                "The function in groupby.applyInArrow must take either one argument "
+                "(data) or two arguments (key, data), but takes "
+                f"{num_args}"
+            )
+
+        verify_arrow_result(result, _assign_cols_by_name, expected_cols_and_types)
+
+        return result.to_batches()
+
+    return lambda k, v: (wrapped(k, v), arrow_return_type)
+
+
 def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
     _assign_cols_by_name = assign_cols_by_name(runner_conf)
 
@@ -618,11 +745,17 @@ def read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index):
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = inspect.getfullargspec(chained_func)  # signature was lost 
when wrapping it
         return args_offsets, wrap_grouped_map_pandas_udf(func, return_type, 
argspec, runner_conf)
+    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
+        argspec = getfullargspec(chained_func)  # signature was lost when 
wrapping it
+        return args_offsets, wrap_grouped_map_arrow_udf(func, return_type, 
argspec, runner_conf)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
         return args_offsets, wrap_grouped_map_pandas_udf_with_state(func, 
return_type)
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
         argspec = inspect.getfullargspec(chained_func)  # signature was lost 
when wrapping it
         return args_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, 
argspec, runner_conf)
+    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
+        argspec = getfullargspec(chained_func)  # signature was lost when 
wrapping it
+        return args_offsets, wrap_cogrouped_map_arrow_udf(func, return_type, 
argspec, runner_conf)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
         return wrap_grouped_agg_pandas_udf(func, args_offsets, kwargs_offsets, 
return_type)
     elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
@@ -635,7 +768,10 @@ def read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index):
         raise ValueError("Unknown eval type: {}".format(eval_type))
 
 
-# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF and 
SQL_ARROW_BATCHED_UDF when
+# Used by SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_ARROW_UDF,
+# SQL_COGROUPED_MAP_PANDAS_UDF, SQL_COGROUPED_MAP_ARROW_UDF,
+# SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+# SQL_SCALAR_PANDAS_UDF and SQL_ARROW_BATCHED_UDF when
 # returning StructType
 def assign_cols_by_name(runner_conf):
     return (
@@ -1197,6 +1333,8 @@ def read_udfs(pickleSer, infile, eval_type):
         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
         PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+        PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+        PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
     ):
         # Load conf used for pandas_udf evaluation
         num_conf = read_int(infile)
@@ -1215,9 +1353,12 @@ def read_udfs(pickleSer, infile, eval_type):
             
runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", 
"false").lower()
             == "true"
         )
+        _assign_cols_by_name = assign_cols_by_name(runner_conf)
 
-        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
-            ser = CogroupUDFSerializer(timezone, safecheck, 
assign_cols_by_name(runner_conf))
+        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
+            ser = CogroupArrowUDFSerializer(_assign_cols_by_name)
+        elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
+            ser = CogroupPandasUDFSerializer(timezone, safecheck, 
_assign_cols_by_name)
         elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
             arrow_max_records_per_batch = runner_conf.get(
                 "spark.sql.execution.arrow.maxRecordsPerBatch", 10000
@@ -1227,12 +1368,14 @@ def read_udfs(pickleSer, infile, eval_type):
             ser = ApplyInPandasWithStateSerializer(
                 timezone,
                 safecheck,
-                assign_cols_by_name(runner_conf),
+                _assign_cols_by_name,
                 state_object_schema,
                 arrow_max_records_per_batch,
             )
         elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
             ser = ArrowStreamUDFSerializer()
+        elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
+            ser = ArrowStreamGroupUDFSerializer(_assign_cols_by_name)
         else:
             # Scalar Pandas UDF handles struct type arguments as pandas 
DataFrames instead of
             # pandas Series. See SPARK-27240.
@@ -1251,7 +1394,7 @@ def read_udfs(pickleSer, infile, eval_type):
             ser = ArrowStreamPandasUDFSerializer(
                 timezone,
                 safecheck,
-                assign_cols_by_name(runner_conf),
+                _assign_cols_by_name,
                 df_for_struct,
                 struct_in_pandas,
                 ndarray_as_list,
@@ -1374,6 +1517,32 @@ def read_udfs(pickleSer, infile, eval_type):
             vals = [a[o] for o in parsed_offsets[0][1]]
             return f(keys, vals)
 
+    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
+        import pyarrow as pa
+
+        # We assume there is only one UDF here because grouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+
+        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
+        # distinguish between grouping attributes and data attributes
+        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+
+        def batch_from_offset(batch, offsets):
+            return pa.RecordBatch.from_arrays(
+                arrays=[batch.columns[o] for o in offsets],
+                names=[batch.schema.names[o] for o in offsets],
+            )
+
+        def table_from_batches(batches, offsets):
+            return pa.Table.from_batches([batch_from_offset(batch, offsets) 
for batch in batches])
+
+        def mapper(a):
+            keys = table_from_batches(a, parsed_offsets[0][0])
+            vals = table_from_batches(a, parsed_offsets[0][1])
+            return f(keys, vals)
+
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
         # We assume there is only one UDF here because grouped map doesn't
         # support combining multiple UDFs.
@@ -1426,6 +1595,32 @@ def read_udfs(pickleSer, infile, eval_type):
             df2_vals = [a[1][o] for o in parsed_offsets[1][1]]
             return f(df1_keys, df1_vals, df2_keys, df2_vals)
 
+    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
+        import pyarrow as pa
+
+        # We assume there is only one UDF here because cogrouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
+
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+
+        def batch_from_offset(batch, offsets):
+            return pa.RecordBatch.from_arrays(
+                arrays=[batch.columns[o] for o in offsets],
+                names=[batch.schema.names[o] for o in offsets],
+            )
+
+        def table_from_batches(batches, offsets):
+            return pa.Table.from_batches([batch_from_offset(batch, offsets) 
for batch in batches])
+
+        def mapper(a):
+            df1_keys = table_from_batches(a[0], parsed_offsets[0][0])
+            df1_vals = table_from_batches(a[0], parsed_offsets[0][1])
+            df2_keys = table_from_batches(a[1], parsed_offsets[1][0])
+            df2_vals = table_from_batches(a[1], parsed_offsets[1][1])
+            return f(df1_keys, df1_vals, df2_keys, df2_vals)
+
     else:
         udfs = []
         for i in range(num_udfs):
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 26680e87fc2..d4ed673c351 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -48,6 +48,29 @@ case class FlatMapGroupsInPandas(
     copy(child = newChild)
 }
 
+/**
+ * FlatMap groups using a udf: pyarrow.Table -> pyarrow.Table.
+ * This is used by DataFrame.groupby().applyInArrow().
+ *
+ * On the Python side the rows of each group are assembled into a `pyarrow.Table`, which is
+ * passed to the user function (optionally together with the grouping key); the returned
+ * `pyarrow.Table` is sent back to the JVM as Arrow record batches.
+ *
+ * @param groupingAttributes attributes of `child` the data is grouped by
+ * @param functionExpr the Python UDF to apply to every group
+ *                     (expected to have eval type `SQL_GROUPED_MAP_ARROW_UDF`)
+ * @param output the attributes produced by the UDF
+ * @param child the input plan
+ */
+case class FlatMapGroupsInArrow(
+    groupingAttributes: Seq[Attribute],
+    functionExpr: Expression,
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+
+  /**
+   * This is needed because output attributes are considered `references` when
+   * passed through the constructor.
+   *
+   * Without this, catalyst will complain that output attributes are missing
+   * from the input.
+   */
+  override val producedAttributes: AttributeSet = AttributeSet(output)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsInArrow =
+    copy(child = newChild)
+}
+
 /**
  * Map partitions using a udf: iter(pandas.Dataframe) -> 
iter(pandas.DataFrame).
  * This is used by DataFrame.mapInPandas()
@@ -173,6 +196,31 @@ case class FlatMapGroupsInPandasWithState(
     newChild: LogicalPlan): FlatMapGroupsInPandasWithState = copy(child = 
newChild)
 }
 
+/**
+ * Flatmap cogroups using a udf that operates on `pyarrow.Table`s.
+ * This is used by DataFrame.groupby().cogroup().applyInArrow().
+ *
+ * The Python worker receives each side of a cogroup as `pyarrow.Table`s (grouping keys and
+ * values) and sends the UDF result back as Arrow record batches.
+ *
+ * @param leftGroupingLen number of leading columns of `left` that are its grouping attributes
+ * @param rightGroupingLen number of leading columns of `right` that are its grouping attributes
+ * @param functionExpr the Python UDF to apply to every cogroup
+ * @param output the attributes produced by the UDF
+ * @param left the left input plan
+ * @param right the right input plan
+ */
+case class FlatMapCoGroupsInArrow(
+    leftGroupingLen: Int,
+    rightGroupingLen: Int,
+    functionExpr: Expression,
+    output: Seq[Attribute],
+    left: LogicalPlan,
+    right: LogicalPlan) extends BinaryNode {
+
+  // Output attributes are produced by the UDF rather than read from the children, so they
+  // are excluded from `references`; otherwise catalyst reports them as missing inputs.
+  override val producedAttributes: AttributeSet = AttributeSet(output)
+  override lazy val references: AttributeSet =
+    AttributeSet(leftAttributes ++ rightAttributes ++ functionExpr.references) -- producedAttributes
+
+  /** Grouping attributes of the left side: the first `leftGroupingLen` columns of `left`. */
+  def leftAttributes: Seq[Attribute] = left.output.take(leftGroupingLen)
+
+  /** Grouping attributes of the right side: the first `rightGroupingLen` columns of `right`. */
+  def rightAttributes: Seq[Attribute] = right.output.take(rightGroupingLen)
+
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapCoGroupsInArrow =
+    copy(left = newLeft, right = newRight)
+}
+
 trait BaseEvalPython extends UnaryNode {
 
   def udfs: Seq[PythonUDF]
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 37223a33a2e..5ad96cdba21 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -557,7 +557,7 @@ class RelationalGroupedDataset protected[sql](
    */
   private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = {
     require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
-      "Must pass a grouped map udf")
+      "Must pass a grouped map pandas udf")
     require(expr.dataType.isInstanceOf[StructType],
       s"The returnType of the udf must be a ${StructType.simpleString}")
 
@@ -575,6 +575,38 @@ class RelationalGroupedDataset protected[sql](
     Dataset.ofRows(df.sparkSession, plan)
   }
 
+  /**
+   * Applies a grouped vectorized python user-defined function to each group of data.
+   * The user-defined function defines a transformation: `pyarrow.Table` -> `pyarrow.Table`.
+   * For each group, all elements in the group are passed as a `pyarrow.Table` and the results
+   * for all groups are combined into a new [[DataFrame]].
+   *
+   * This function does not support partial aggregation, and requires shuffling all the data in
+   * the [[DataFrame]].
+   *
+   * This function uses Apache Arrow as serialization format between Java executors and Python
+   * workers.
+   *
+   * @param expr a grouped map Arrow UDF (eval type `SQL_GROUPED_MAP_ARROW_UDF`) whose return
+   *             type is a [[StructType]]
+   * @return a [[DataFrame]] whose schema is the UDF's return type
+   */
+  private[sql] def flatMapGroupsInArrow(expr: PythonUDF): DataFrame = {
+    require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+      "Must pass a grouped map arrow udf")
+    require(expr.dataType.isInstanceOf[StructType],
+      s"The returnType of the udf must be a ${StructType.simpleString}")
+
+    // Non-named grouping expressions are aliased so they can be projected as columns.
+    val groupingNamedExpressions = groupingExprs.map {
+      case ne: NamedExpression => ne
+      case other => Alias(other, other.toString)()
+    }
+    val child = df.logicalPlan
+    // The grouping columns are prepended to the child's output, so the first
+    // `groupingNamedExpressions.length` attributes of `project` are the grouping attributes.
+    val project = df.sparkSession.sessionState.executePlan(
+      Project(groupingNamedExpressions ++ child.output, child)).analyzed
+    val groupingAttributes = project.output.take(groupingNamedExpressions.length)
+    val output = toAttributes(expr.dataType.asInstanceOf[StructType])
+    val plan = FlatMapGroupsInArrow(groupingAttributes, expr, output, project)
+
+    Dataset.ofRows(df.sparkSession, plan)
+  }
+
   /**
    * Applies a vectorized python user-defined function to each cogrouped data.
    * The user-defined function defines a transformation:
@@ -589,7 +621,7 @@ class RelationalGroupedDataset protected[sql](
       r: RelationalGroupedDataset,
       expr: PythonUDF): DataFrame = {
     require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
-      "Must pass a cogrouped map udf")
+      "Must pass a cogrouped map pandas udf")
     require(this.groupingExprs.length == r.groupingExprs.length,
       "Cogroup keys must have same size: " +
         s"${this.groupingExprs.length} != ${r.groupingExprs.length}")
@@ -621,6 +653,52 @@ class RelationalGroupedDataset protected[sql](
     Dataset.ofRows(df.sparkSession, plan)
   }
 
+  /**
+   * Applies a vectorized python user-defined function to each cogrouped data.
+   * The user-defined function operates on `pyarrow.Table`s: for each group in the cogrouped
+   * data, all elements of each side of the group are passed as `pyarrow.Table`s and the
+   * results for all cogroups are combined into a new [[DataFrame]].
+   *
+   * This function uses Apache Arrow as serialization format between Java executors and Python
+   * workers.
+   *
+   * @param r the other grouped dataset; it must have the same number of grouping expressions
+   * @param expr a cogrouped map Arrow UDF (eval type `SQL_COGROUPED_MAP_ARROW_UDF`) whose
+   *             return type is a [[StructType]]
+   * @return a [[DataFrame]] whose schema is the UDF's return type
+   */
+  private[sql] def flatMapCoGroupsInArrow(
+      r: RelationalGroupedDataset,
+      expr: PythonUDF): DataFrame = {
+    require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
+      "Must pass a cogrouped map arrow udf")
+    require(this.groupingExprs.length == r.groupingExprs.length,
+      "Cogroup keys must have same size: " +
+        s"${this.groupingExprs.length} != ${r.groupingExprs.length}")
+    require(expr.dataType.isInstanceOf[StructType],
+      s"The returnType of the udf must be a ${StructType.simpleString}")
+
+    // Non-named grouping expressions are aliased so they can be projected as columns.
+    def toNamed(exprs: Seq[Expression]): Seq[NamedExpression] = exprs.map {
+      case ne: NamedExpression => ne
+      case other => Alias(other, other.toString)()
+    }
+
+    val leftGroupingNamedExpressions = toNamed(groupingExprs)
+    val rightGroupingNamedExpressions = toNamed(r.groupingExprs)
+
+    val leftChild = df.logicalPlan
+    val rightChild = r.df.logicalPlan
+
+    // Each side gets its grouping columns prepended, so the leading columns of `left` and
+    // `right` are their grouping attributes (see FlatMapCoGroupsInArrow).
+    val left = df.sparkSession.sessionState.executePlan(
+      Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild)).analyzed
+    val right = r.df.sparkSession.sessionState.executePlan(
+      Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)).analyzed
+
+    val output = toAttributes(expr.dataType.asInstanceOf[StructType])
+    val plan = FlatMapCoGroupsInArrow(
+      leftGroupingNamedExpressions.length, rightGroupingNamedExpressions.length,
+      expr, output, left, right)
+    Dataset.ofRows(df.sparkSession, plan)
+  }
+
   /**
    * Applies a grouped vectorized python user-defined function to each group 
of data.
    * The user-defined function defines a transformation: iterator of 
`pandas.DataFrame` ->
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 0e25a9539af..df770bd5eee 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -826,10 +826,16 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
           f, p, b, is, ot, planLater(child)) :: Nil
       case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
         execution.python.FlatMapGroupsInPandasExec(grouping, func, output, 
planLater(child)) :: Nil
+      case logical.FlatMapGroupsInArrow(grouping, func, output, child) =>
+        execution.python.FlatMapGroupsInArrowExec(grouping, func, output, 
planLater(child)) :: Nil
       case f @ logical.FlatMapCoGroupsInPandas(_, _, func, output, left, 
right) =>
         execution.python.FlatMapCoGroupsInPandasExec(
           f.leftAttributes, f.rightAttributes,
           func, output, planLater(left), planLater(right)) :: Nil
+      case f @ logical.FlatMapCoGroupsInArrow(_, _, func, output, left, right) 
=>
+        execution.python.FlatMapCoGroupsInArrowExec(
+          f.leftAttributes, f.rightAttributes,
+          func, output, planLater(left), planLater(right)) :: Nil
       case logical.MapInPandas(func, output, child, isBarrier) =>
         execution.python.MapInPandasExec(func, output, planLater(child), 
isBarrier) :: Nil
       case logical.PythonMapInArrow(func, output, child, isBarrier) =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala
new file mode 100644
index 00000000000..17c68a86b75
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+
+
+/**
+ * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapCoGroupsInArrow]]
+ *
+ * The input dataframes are first Cogrouped.  Rows from each side of the cogroup are passed to the
+ * Python worker via Arrow.  As each side of the cogroup may have a different schema we send every
+ * group in its own Arrow stream.
+ * The Python worker turns the received record batches into `pyarrow.Table`s, invokes the
+ * user-defined function, and passes the result back as Arrow record batches.
+ * Finally, each record batch is turned to Iterator[InternalRow] using ColumnarBatch.
+ *
+ * Note on memory usage:
+ * Both the Python worker and the Java executor need to have enough memory to
+ * hold the largest cogroup. The memory on the Java side is used to construct the
+ * record batches (off heap memory). The memory on the Python side is used for
+ * holding the `pyarrow.Table`s. It's possible to further split one group into
+ * multiple record batches to reduce the memory footprint on the Java side, this
+ * is left as future work.
+ *
+ * All execution logic is shared with the Pandas variant via [[FlatMapCoGroupsInPythonExec]];
+ * this node only selects the Arrow eval type.
+ */
+case class FlatMapCoGroupsInArrowExec(
+    leftGroup: Seq[Attribute],
+    rightGroup: Seq[Attribute],
+    func: Expression,
+    output: Seq[Attribute],
+    left: SparkPlan,
+    right: SparkPlan)
+  extends FlatMapCoGroupsInPythonExec {
+
+  protected val pythonEvalType: Int = PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): FlatMapCoGroupsInArrowExec =
+    copy(left = newLeft, right = newRight)
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index bbfe97d1947..32d7748bcaa 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -17,15 +17,9 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.JobArtifactSet
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, 
ClusteredDistribution, Distribution, Partitioning}
-import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, 
SparkPlan}
-import org.apache.spark.sql.execution.python.PandasGroupUtils._
+import org.apache.spark.sql.execution.SparkPlan
 
 
 /**
@@ -54,57 +48,9 @@ case class FlatMapCoGroupsInPandasExec(
     output: Seq[Attribute],
     left: SparkPlan,
     right: SparkPlan)
-  extends SparkPlan with BinaryExecNode with PythonSQLMetrics {
+  extends FlatMapCoGroupsInPythonExec {
 
-  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
-  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-  private val pandasFunction = func.asInstanceOf[PythonUDF].func
-  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
-
-  override def producedAttributes: AttributeSet = AttributeSet(output)
-
-  override def outputPartitioning: Partitioning = left.outputPartitioning
-
-  override def requiredChildDistribution: Seq[Distribution] = {
-    val leftDist = if (leftGroup.isEmpty) AllTuples else 
ClusteredDistribution(leftGroup)
-    val rightDist = if (rightGroup.isEmpty) AllTuples else 
ClusteredDistribution(rightGroup)
-    leftDist :: rightDist :: Nil
-  }
-
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
-    leftGroup
-      .map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) 
:: Nil
-  }
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup)
-    val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, 
rightGroup)
-    val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
-
-    // Map cogrouped rows to ArrowPythonRunner results, Only execute if 
partition is not empty
-    left.execute().zipPartitions(right.execute())  { (leftData, rightData) =>
-      if (leftData.isEmpty && rightData.isEmpty) Iterator.empty else {
-
-        val leftGrouped = groupAndProject(leftData, leftGroup, left.output, 
leftDedup)
-        val rightGrouped = groupAndProject(rightData, rightGroup, 
right.output, rightDedup)
-        val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup)
-          .map { case (_, l, r) => (l, r) }
-
-        val runner = new CoGroupedArrowPythonRunner(
-          chainedFunc,
-          PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
-          Array(leftArgOffsets ++ rightArgOffsets),
-          DataTypeUtils.fromAttributes(leftDedup),
-          DataTypeUtils.fromAttributes(rightDedup),
-          sessionLocalTimeZone,
-          pythonRunnerConf,
-          pythonMetrics,
-          jobArtifactUUID)
-
-        executePython(data, output, runner)
-      }
-    }
-  }
+  protected val pythonEvalType: Int = 
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
 
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan, newRight: SparkPlan): FlatMapCoGroupsInPandasExec =
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPythonExec.scala
similarity index 68%
copy from 
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
copy to 
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPythonExec.scala
index bbfe97d1947..f75b0019f10 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPythonExec.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.JobArtifactSet
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -29,32 +29,17 @@ import 
org.apache.spark.sql.execution.python.PandasGroupUtils._
 
 
 /**
- * Physical node for 
[[org.apache.spark.sql.catalyst.plans.logical.FlatMapCoGroupsInPandas]]
- *
- * The input dataframes are first Cogrouped.  Rows from each side of the 
cogroup are passed to the
- * Python worker via Arrow.  As each side of the cogroup may have a different 
schema we send every
- * group in its own Arrow stream.
- * The Python worker turns the resulting record batches to 
`pandas.DataFrame`s, invokes the
- * user-defined function, and passes the resulting `pandas.DataFrame`
- * as an Arrow record batch. Finally, each record batch is turned to
- * Iterator[InternalRow] using ColumnarBatch.
- *
- * Note on memory usage:
- * Both the Python worker and the Java executor need to have enough memory to
- * hold the largest cogroup. The memory on the Java side is used to construct 
the
- * record batches (off heap memory). The memory on the Python side is used for
- * holding the `pandas.DataFrame`. It's possible to further split one group 
into
- * multiple record batches to reduce the memory footprint on the Java side, 
this
- * is left as future work.
+ * Base class for Python-based FlatMapCoGroupsIn*Exec.
  */
-case class FlatMapCoGroupsInPandasExec(
-    leftGroup: Seq[Attribute],
-    rightGroup: Seq[Attribute],
-    func: Expression,
-    output: Seq[Attribute],
-    left: SparkPlan,
-    right: SparkPlan)
-  extends SparkPlan with BinaryExecNode with PythonSQLMetrics {
+trait FlatMapCoGroupsInPythonExec extends SparkPlan with BinaryExecNode with 
PythonSQLMetrics {
+  val leftGroup: Seq[Attribute]
+  val rightGroup: Seq[Attribute]
+  val func: Expression
+  val output: Seq[Attribute]
+  val left: SparkPlan
+  val right: SparkPlan
+
+  protected val pythonEvalType: Int
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
@@ -92,7 +77,7 @@ case class FlatMapCoGroupsInPandasExec(
 
         val runner = new CoGroupedArrowPythonRunner(
           chainedFunc,
-          PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+          pythonEvalType,
           Array(leftArgOffsets ++ rightArgOffsets),
           DataTypeUtils.fromAttributes(leftDedup),
           DataTypeUtils.fromAttributes(rightDedup),
@@ -105,8 +90,4 @@ case class FlatMapCoGroupsInPandasExec(
       }
     }
   }
-
-  override protected def withNewChildrenInternal(
-      newLeft: SparkPlan, newRight: SparkPlan): FlatMapCoGroupsInPandasExec =
-    copy(left = newLeft, right = newRight)
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala
new file mode 100644
index 00000000000..b0dd800af8f
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{StructField, StructType}
+
+
+/**
+ * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInArrow]]
+ *
+ * Rows in each group are passed to the Python worker as Arrow record batches.
+ * The Python worker turns the record batches into a `pyarrow.Table`, invokes the
+ * user-defined function, and passes the resulting `pyarrow.Table` back
+ * as Arrow record batches. Finally, each record batch is turned to
+ * Iterator[InternalRow] using ColumnarBatch.
+ *
+ * Note on memory usage:
+ * Both the Python worker and the Java executor need to have enough memory to
+ * hold the largest group. The memory on the Java side is used to construct the
+ * record batch (off heap memory). The memory on the Python side is used for
+ * holding the `pyarrow.Table`. It's possible to further split one group into
+ * multiple record batches to reduce the memory footprint on the Java side, this
+ * is left as future work.
+ */
+case class FlatMapGroupsInArrowExec(
+    groupingAttributes: Seq[Attribute],
+    func: Expression,
+    output: Seq[Attribute],
+    child: SparkPlan)
+  extends FlatMapGroupsInPythonExec {
+
+  protected val pythonEvalType: Int = PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+
+  /**
+   * Wraps every row of every group into an outer single-column row, matching the nested
+   * schema returned by [[groupedSchema]]. NOTE(review): presumably the Python-side group
+   * serializer unwraps this struct column again — confirm against ArrowStreamGroupUDFSerializer.
+   */
+  override protected def groupedData(iter: Iterator[InternalRow], attrs: Seq[Attribute]):
+      Iterator[Iterator[InternalRow]] =
+    super.groupedData(iter, attrs)
+      .map(_.map(InternalRow(_)))
+
+  /** The parent's grouped schema nested inside a single struct field named "struct". */
+  override protected def groupedSchema(attrs: Seq[Attribute]): StructType =
+    StructType(StructField("struct", super.groupedSchema(attrs)) :: Nil)
+
+  override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsInArrowExec =
+    copy(child = newChild)
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index f2d21ce8e96..88747899720 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -17,15 +17,9 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.JobArtifactSet
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, 
ClusteredDistribution, Distribution, Partitioning}
-import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.python.PandasGroupUtils._
+import org.apache.spark.sql.execution.SparkPlan
 
 
 /**
@@ -50,55 +44,9 @@ case class FlatMapGroupsInPandasExec(
     func: Expression,
     output: Seq[Attribute],
     child: SparkPlan)
-  extends SparkPlan with UnaryExecNode with PythonSQLMetrics {
+  extends FlatMapGroupsInPythonExec {
 
-  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
-  private val largeVarTypes = conf.arrowUseLargeVarTypes
-  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-  private val pandasFunction = func.asInstanceOf[PythonUDF].func
-  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
-  private[this] val jobArtifactUUID = 
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
-
-  override def producedAttributes: AttributeSet = AttributeSet(output)
-
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-
-  override def requiredChildDistribution: Seq[Distribution] = {
-    if (groupingAttributes.isEmpty) {
-      AllTuples :: Nil
-    } else {
-      ClusteredDistribution(groupingAttributes) :: Nil
-    }
-  }
-
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    Seq(groupingAttributes.map(SortOrder(_, Ascending)))
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    val inputRDD = child.execute()
-
-    val (dedupAttributes, argOffsets) = resolveArgOffsets(child.output, 
groupingAttributes)
-
-    // Map grouped rows to ArrowPythonRunner results, Only execute if 
partition is not empty
-    inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
-
-      val data = groupAndProject(iter, groupingAttributes, child.output, 
dedupAttributes)
-        .map { case (_, x) => x }
-
-      val runner = new ArrowPythonRunner(
-        chainedFunc,
-        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
-        Array(argOffsets),
-        DataTypeUtils.fromAttributes(dedupAttributes),
-        sessionLocalTimeZone,
-        largeVarTypes,
-        pythonRunnerConf,
-        pythonMetrics,
-        jobArtifactUUID)
-
-      executePython(data, output, runner)
-    }}
-  }
+  protected val pythonEvalType: Int = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
 
   override protected def withNewChildInternal(newChild: SparkPlan): 
FlatMapGroupsInPandasExec =
     copy(child = newChild)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala
similarity index 61%
copy from sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
copy to sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala
index f2d21ce8e96..0c18206a825 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.JobArtifactSet
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -26,37 +26,25 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.python.PandasGroupUtils._
+import org.apache.spark.sql.types.StructType
 
 
 /**
- * Physical node for 
[[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]]
- *
- * Rows in each group are passed to the Python worker as an Arrow record batch.
- * The Python worker turns the record batch to a `pandas.DataFrame`, invoke the
- * user-defined function, and passes the resulting `pandas.DataFrame`
- * as an Arrow record batch. Finally, each record batch is turned to
- * Iterator[InternalRow] using ColumnarBatch.
- *
- * Note on memory usage:
- * Both the Python worker and the Java executor need to have enough memory to
- * hold the largest group. The memory on the Java side is used to construct the
- * record batch (off heap memory). The memory on the Python side is used for
- * holding the `pandas.DataFrame`. It's possible to further split one group 
into
- * multiple record batches to reduce the memory footprint on the Java side, 
this
- * is left as future work.
+ * Base class for Python-based FlatMapGroupsIn*Exec.
  */
-case class FlatMapGroupsInPandasExec(
-    groupingAttributes: Seq[Attribute],
-    func: Expression,
-    output: Seq[Attribute],
-    child: SparkPlan)
-  extends SparkPlan with UnaryExecNode with PythonSQLMetrics {
+trait FlatMapGroupsInPythonExec extends SparkPlan with UnaryExecNode with 
PythonSQLMetrics {
+  val groupingAttributes: Seq[Attribute]
+  val func: Expression
+  val output: Seq[Attribute]
+  val child: SparkPlan
+
+  protected val pythonEvalType: Int
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val largeVarTypes = conf.arrowUseLargeVarTypes
   private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-  private val pandasFunction = func.asInstanceOf[PythonUDF].func
-  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
+  private val pythonFunction = func.asInstanceOf[PythonUDF].func
+  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
   private[this] val jobArtifactUUID = 
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
 
   override def producedAttributes: AttributeSet = AttributeSet(output)
@@ -74,6 +62,14 @@ case class FlatMapGroupsInPandasExec(
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
     Seq(groupingAttributes.map(SortOrder(_, Ascending)))
 
+  protected def groupedData(iter: Iterator[InternalRow], attrs: 
Seq[Attribute]):
+      Iterator[Iterator[InternalRow]] =
+    groupAndProject(iter, groupingAttributes, child.output, attrs)
+      .map { case (_, x) => x }
+
+  protected def groupedSchema(attrs: Seq[Attribute]): StructType =
+    DataTypeUtils.fromAttributes(attrs)
+
   override protected def doExecute(): RDD[InternalRow] = {
     val inputRDD = child.execute()
 
@@ -82,14 +78,13 @@ case class FlatMapGroupsInPandasExec(
     // Map grouped rows to ArrowPythonRunner results, Only execute if 
partition is not empty
     inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
 
-      val data = groupAndProject(iter, groupingAttributes, child.output, 
dedupAttributes)
-        .map { case (_, x) => x }
+      val data = groupedData(iter, dedupAttributes)
 
       val runner = new ArrowPythonRunner(
         chainedFunc,
-        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        pythonEvalType,
         Array(argOffsets),
-        DataTypeUtils.fromAttributes(dedupAttributes),
+        groupedSchema(dedupAttributes),
         sessionLocalTimeZone,
         largeVarTypes,
         pythonRunnerConf,
@@ -99,7 +94,4 @@ case class FlatMapGroupsInPandasExec(
       executePython(data, output, runner)
     }}
   }
-
-  override protected def withNewChildInternal(newChild: SparkPlan): 
FlatMapGroupsInPandasExec =
-    copy(child = newChild)
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to