This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f37cd07534e4 [SPARK-53976][PYTHON] Support logging in Pandas/Arrow UDFs
f37cd07534e4 is described below

commit f37cd07534e4eab267ffcdd8c40923b8a175c789
Author: Takuya Ueshin <[email protected]>
AuthorDate: Thu Oct 30 10:30:23 2025 -0700

    [SPARK-53976][PYTHON] Support logging in Pandas/Arrow UDFs
    
    ### What changes were proposed in this pull request?
    
    Supports logging in Pandas/Arrow UDFs.
    
    ### Why are the changes needed?
    
    The basic logging infrastructure was introduced in
    https://github.com/apache/spark/pull/52689; the other UDF types
    should support logging as well.
    
    This PR adds logging support to Pandas and Arrow UDFs.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the logging feature will be available in Pandas/Arrow UDFs.
    
    ### How was this patch tested?
    
    Added the related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #52785 from ueshin/issues/SPARK-53976/pandas_arrow_udfs.
    
    Authored-by: Takuya Ueshin <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../sql/tests/arrow/test_arrow_cogrouped_map.py    | 46 +++++++++++++
 .../sql/tests/arrow/test_arrow_grouped_map.py      | 77 +++++++++++++++++++++
 python/pyspark/sql/tests/arrow/test_arrow_map.py   | 55 +++++++++++++++
 .../sql/tests/arrow/test_arrow_python_udf.py       | 12 ----
 .../sql/tests/arrow/test_arrow_udf_grouped_agg.py  | 40 ++++++++++-
 .../sql/tests/arrow/test_arrow_udf_scalar.py       | 79 +++++++++++++++++++++-
 .../sql/tests/arrow/test_arrow_udf_window.py       | 49 +++++++++++++-
 .../sql/tests/pandas/test_pandas_cogrouped_map.py  | 45 ++++++++++++
 .../sql/tests/pandas/test_pandas_grouped_map.py    | 39 +++++++++++
 python/pyspark/sql/tests/pandas/test_pandas_map.py | 41 ++++++++++-
 .../tests/pandas/test_pandas_udf_grouped_agg.py    | 37 +++++++++-
 .../sql/tests/pandas/test_pandas_udf_scalar.py     | 75 +++++++++++++++++++-
 .../sql/tests/pandas/test_pandas_udf_window.py     | 46 +++++++++++++
 .../v2/python/UserDefinedPythonDataSource.scala    |  3 +-
 .../python/ArrowAggregatePythonExec.scala          |  7 ++
 .../sql/execution/python/ArrowEvalPythonExec.scala |  9 +++
 .../sql/execution/python/ArrowPythonRunner.scala   | 22 ++++--
 .../python/ArrowWindowPythonEvaluatorFactory.scala |  2 +
 .../execution/python/ArrowWindowPythonExec.scala   |  8 +++
 .../python/CoGroupedArrowPythonRunner.scala        |  9 +++
 .../python/FlatMapCoGroupsInBatchExec.scala        |  7 ++
 .../python/FlatMapGroupsInBatchExec.scala          |  7 ++
 .../python/MapInBatchEvaluatorFactory.scala        |  4 +-
 .../sql/execution/python/MapInBatchExec.scala      |  9 ++-
 24 files changed, 700 insertions(+), 28 deletions(-)

diff --git a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py 
b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
index 13edcec6b57f..2bdd7bda3bc2 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
@@ -17,6 +17,7 @@
 import os
 import time
 import unittest
+import logging
 
 from pyspark.errors import PythonException
 from pyspark.sql import Row
@@ -26,6 +27,8 @@ from pyspark.testing.sqlutils import (
     have_pyarrow,
     pyarrow_requirement_message,
 )
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
 
 if have_pyarrow:
     import pyarrow as pa
@@ -367,6 +370,49 @@ class CogroupedMapInArrowTestsMixin:
             with 
self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
                 CogroupedMapInArrowTestsMixin.test_apply_in_arrow(self)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_cogroup_apply_in_arrow_with_logging(self):
+        import pyarrow as pa
+
+        def func_with_logging(left, right):
+            assert isinstance(left, pa.Table)
+            assert isinstance(right, pa.Table)
+            logger = logging.getLogger("test_arrow_cogrouped_map")
+            logger.warning(
+                "arrow cogrouped map: "
+                + f"{dict(v1=left['v1'].to_pylist(), 
v2=right['v2'].to_pylist())}"
+            )
+            return left.join(right, keys="id", join_type="inner")
+
+        left_df = self.spark.createDataFrame([(1, 10), (2, 20), (1, 30)], 
["id", "v1"])
+        right_df = self.spark.createDataFrame([(1, 100), (2, 200), (1, 300)], 
["id", "v2"])
+
+        grouped_left = left_df.groupBy("id")
+        grouped_right = right_df.groupBy("id")
+        cogrouped_df = grouped_left.cogroup(grouped_right)
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": 
"true"}):
+            assertDataFrameEqual(
+                cogrouped_df.applyInArrow(func_with_logging, "id long, v1 
long, v2 long"),
+                [Row(id=1, v1=v1, v2=v2) for v1 in [10, 30] for v2 in [100, 
300]]
+                + [Row(id=2, v1=20, v2=200)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"arrow cogrouped map: {dict(v1=v1, v2=v2)}",
+                    context={"func_name": func_with_logging.__name__},
+                    logger="test_arrow_cogrouped_map",
+                )
+                for v1, v2 in [([10, 30], [100, 300]), ([20], [200])]
+            ],
+        )
+
 
 class CogroupedMapInArrowTests(CogroupedMapInArrowTestsMixin, 
ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py 
b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
index 8d3d929096b1..829c38385bd0 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
@@ -17,6 +17,7 @@
 import inspect
 import os
 import time
+import logging
 from typing import Iterator, Tuple
 import unittest
 
@@ -29,6 +30,8 @@ from pyspark.testing.sqlutils import (
     have_pyarrow,
     pyarrow_requirement_message,
 )
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
 
 if have_pyarrow:
     import pyarrow as pa
@@ -394,6 +397,80 @@ class ApplyInArrowTestsMixin:
             with 
self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
                 ApplyInArrowTestsMixin.test_apply_in_arrow(self)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_apply_in_arrow_with_logging(self):
+        import pyarrow as pa
+
+        def func_with_logging(group):
+            assert isinstance(group, pa.Table)
+            logger = logging.getLogger("test_arrow_grouped_map")
+            logger.warning(f"arrow grouped map: {group.to_pydict()}")
+            return group
+
+        df = self.spark.range(9).withColumn("value", col("id") * 10)
+        grouped_df = df.groupBy((col("id") % 2).cast("int"))
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": 
"true"}):
+            assertDataFrameEqual(
+                grouped_df.applyInArrow(func_with_logging, "id long, value 
long"),
+                df,
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v 
in lst])}",
+                    context={"func_name": func_with_logging.__name__},
+                    logger="test_arrow_grouped_map",
+                )
+                for lst in [[0, 2, 4, 6, 8], [1, 3, 5, 7]]
+            ],
+        )
+
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_apply_in_arrow_iter_with_logging(self):
+        import pyarrow as pa
+
+        def func_with_logging(group: Iterator[pa.RecordBatch]) -> 
Iterator[pa.RecordBatch]:
+            logger = logging.getLogger("test_arrow_grouped_map")
+            for batch in group:
+                assert isinstance(batch, pa.RecordBatch)
+                logger.warning(f"arrow grouped map: {batch.to_pydict()}")
+                yield batch
+
+        df = self.spark.range(9).withColumn("value", col("id") * 10)
+        grouped_df = df.groupBy((col("id") % 2).cast("int"))
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.maxRecordsPerBatch": 3,
+                "spark.sql.pyspark.worker.logging.enabled": "true",
+            }
+        ):
+            assertDataFrameEqual(
+                grouped_df.applyInArrow(func_with_logging, "id long, value 
long"),
+                df,
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v 
in lst])}",
+                    context={"func_name": func_with_logging.__name__},
+                    logger="test_arrow_grouped_map",
+                )
+                for lst in [[0, 2, 4], [6, 8], [1, 3, 5], [7]]
+            ],
+        )
+
 
 class ApplyInArrowTests(ApplyInArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_map.py 
b/python/pyspark/sql/tests/arrow/test_arrow_map.py
index 0f9f5b422440..4a56a32fbcdd 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_map.py
@@ -17,6 +17,7 @@
 import os
 import time
 import unittest
+import logging
 
 from pyspark.sql.utils import PythonException
 from pyspark.testing.sqlutils import (
@@ -26,6 +27,9 @@ from pyspark.testing.sqlutils import (
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
+from pyspark.sql import Row
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
 
 if have_pyarrow:
     import pyarrow as pa
@@ -221,6 +225,46 @@ class MapInArrowTestsMixin(object):
             df = self.spark.range(1)
             df.mapInArrow(func, "a int").collect()
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_map_in_arrow_with_logging(self):
+        import pyarrow as pa
+
+        def func_with_logging(iterator):
+            logger = logging.getLogger("test_arrow_map")
+            for batch in iterator:
+                assert isinstance(batch, pa.RecordBatch)
+                logger.warning(f"arrow map: {batch.to_pydict()}")
+                yield batch
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.maxRecordsPerBatch": "3",
+                "spark.sql.pyspark.worker.logging.enabled": "true",
+            }
+        ):
+            assertDataFrameEqual(
+                self.spark.range(9, 
numPartitions=2).mapInArrow(func_with_logging, "id long"),
+                [Row(id=i) for i in range(9)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            
self._expected_logs_for_test_map_in_arrow_with_logging(func_with_logging.__name__),
+        )
+
+    def _expected_logs_for_test_map_in_arrow_with_logging(self, func_name):
+        return [
+            Row(
+                level="WARNING",
+                msg=f"arrow map: {dict(id=lst)}",
+                context={"func_name": func_name},
+                logger="test_arrow_map",
+            )
+            for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+        ]
+
 
 class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
@@ -253,6 +297,17 @@ class 
MapInArrowWithArrowBatchSlicingTestsAndReducedBatchSizeTests(MapInArrowTes
         cls.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "3")
         cls.spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", "10")
 
+    def _expected_logs_for_test_map_in_arrow_with_logging(self, func_name):
+        return [
+            Row(
+                level="WARNING",
+                msg=f"arrow map: {dict(id=[i])}",
+                context={"func_name": func_name},
+                logger="test_arrow_map",
+            )
+            for i in range(9)
+        ]
+
 
 class MapInArrowWithOutputArrowBatchSlicingRecordsTests(MapInArrowTests):
     @classmethod
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py 
b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
index 55b4edd72d5d..90e05caf2180 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -60,18 +60,6 @@ class ArrowPythonUDFTestsMixin(BaseUDFTestsMixin):
     def test_register_java_udaf(self):
         super(ArrowPythonUDFTests, self).test_register_java_udaf()
 
-    @unittest.skip(
-        "TODO(SPARK-53976): Python worker logging is not supported for Arrow 
Python UDFs."
-    )
-    def test_udf_with_logging(self):
-        super().test_udf_with_logging()
-
-    @unittest.skip(
-        "TODO(SPARK-53976): Python worker logging is not supported for Arrow 
Python UDFs."
-    )
-    def test_multiple_udfs_with_logging(self):
-        super().test_multiple_udfs_with_logging()
-
     def test_complex_input_types(self):
         row = (
             self.spark.range(1)
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py 
b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index 136a99e19411..f719b4fb16bd 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -16,9 +16,10 @@
 #
 
 import unittest
+import logging
 
 from pyspark.sql.functions import arrow_udf, ArrowUDFType
-from pyspark.util import PythonEvalType
+from pyspark.util import PythonEvalType, is_remote_only
 from pyspark.sql import Row
 from pyspark.sql.types import (
     ArrayType,
@@ -35,6 +36,7 @@ from pyspark.testing.utils import (
     numpy_requirement_message,
     have_pyarrow,
     pyarrow_requirement_message,
+    assertDataFrameEqual,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
@@ -1021,6 +1023,42 @@ class GroupedAggArrowUDFTestsMixin:
 
                     self.assertEqual(expected, result)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_grouped_agg_arrow_udf_with_logging(self):
+        import pyarrow as pa
+
+        @arrow_udf("double", ArrowUDFType.GROUPED_AGG)
+        def my_grouped_agg_arrow_udf(x):
+            assert isinstance(x, pa.Array)
+            logger = logging.getLogger("test_grouped_agg_arrow")
+            logger.warning(f"grouped agg arrow udf: {len(x)}")
+            return pa.compute.sum(x)
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": 
"true"}):
+            assertDataFrameEqual(
+                
df.groupby("id").agg(my_grouped_agg_arrow_udf("v").alias("result")),
+                [Row(id=1, result=3.0), Row(id=2, result=18.0)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"grouped agg arrow udf: {n}",
+                    context={"func_name": my_grouped_agg_arrow_udf.__name__},
+                    logger="test_grouped_agg_arrow",
+                )
+                for n in [2, 3]
+            ],
+        )
+
 
 class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py 
b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
index a682c6515ef6..05f33a4ae42f 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
@@ -20,13 +20,14 @@ import random
 import time
 import unittest
 import datetime
+import logging
 from decimal import Decimal
 from typing import Iterator, Tuple
 
 from pyspark.util import PythonEvalType
 
 from pyspark.sql.functions import arrow_udf, ArrowUDFType
-from pyspark.sql import functions as F
+from pyspark.sql import Row, functions as F
 from pyspark.sql.types import (
     IntegerType,
     ByteType,
@@ -51,8 +52,10 @@ from pyspark.testing.utils import (
     numpy_requirement_message,
     have_pyarrow,
     pyarrow_requirement_message,
+    assertDataFrameEqual,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.util import is_remote_only
 
 
 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@@ -1179,6 +1182,80 @@ class ScalarArrowUDFTestsMixin:
                     def func_a(a: pa.Array) -> pa.Array:
                         return a
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_scalar_arrow_udf_with_logging(self):
+        import pyarrow as pa
+
+        @arrow_udf("string")
+        def my_scalar_arrow_udf(x):
+            assert isinstance(x, pa.Array)
+            logger = logging.getLogger("test_scalar_arrow")
+            logger.warning(f"scalar arrow udf: {x.to_pylist()}")
+            return pa.array(["scalar_arrow_" + str(val.as_py()) for val in x])
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": 
"true"}):
+            assertDataFrameEqual(
+                self.spark.range(3, numPartitions=2).select(
+                    my_scalar_arrow_udf("id").alias("result")
+                ),
+                [Row(result=f"scalar_arrow_{i}") for i in range(3)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"scalar arrow udf: {lst}",
+                    context={"func_name": my_scalar_arrow_udf.__name__},
+                    logger="test_scalar_arrow",
+                )
+                for lst in [[0], [1, 2]]
+            ],
+        )
+
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_scalar_iter_arrow_udf_with_logging(self):
+        import pyarrow as pa
+
+        @arrow_udf("string", ArrowUDFType.SCALAR_ITER)
+        def my_scalar_iter_arrow_udf(it):
+            logger = logging.getLogger("test_scalar_iter_arrow")
+            for x in it:
+                assert isinstance(x, pa.Array)
+                logger.warning(f"scalar iter arrow udf: {x.to_pylist()}")
+                yield pa.array(["scalar_iter_arrow_" + str(val.as_py()) for 
val in x])
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.maxRecordsPerBatch": "3",
+                "spark.sql.pyspark.worker.logging.enabled": "true",
+            }
+        ):
+            assertDataFrameEqual(
+                self.spark.range(9, numPartitions=2).select(
+                    my_scalar_iter_arrow_udf("id").alias("result")
+                ),
+                [Row(result=f"scalar_iter_arrow_{i}") for i in range(9)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"scalar iter arrow udf: {lst}",
+                    context={"func_name": my_scalar_iter_arrow_udf.__name__},
+                    logger="test_scalar_iter_arrow",
+                )
+                for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+            ],
+        )
+
 
 class ScalarArrowUDFTests(ScalarArrowUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py 
b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index d67b99475bf8..240e34487b00 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -16,10 +16,11 @@
 #
 
 import unittest
+import logging
 
 from pyspark.sql.functions import arrow_udf, ArrowUDFType
-from pyspark.util import PythonEvalType
-from pyspark.sql import functions as sf
+from pyspark.util import PythonEvalType, is_remote_only
+from pyspark.sql import Row, functions as sf
 from pyspark.sql.window import Window
 from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
 from pyspark.testing.utils import (
@@ -27,6 +28,7 @@ from pyspark.testing.utils import (
     numpy_requirement_message,
     have_pyarrow,
     pyarrow_requirement_message,
+    assertDataFrameEqual,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
@@ -804,6 +806,49 @@ class WindowArrowUDFTestsMixin:
                     )
                     self.assertEqual(expected2, result2)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_window_arrow_udf_with_logging(self):
+        import pyarrow as pa
+
+        @arrow_udf("double", ArrowUDFType.GROUPED_AGG)
+        def my_window_arrow_udf(x):
+            assert isinstance(x, pa.Array)
+            logger = logging.getLogger("test_window_arrow")
+            logger.warning(f"window arrow udf: {x.to_pylist()}")
+            return pa.compute.sum(x)
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+        w = 
Window.partitionBy("id").orderBy("v").rangeBetween(Window.unboundedPreceding, 0)
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": 
"true"}):
+            assertDataFrameEqual(
+                df.select("id", 
my_window_arrow_udf("v").over(w).alias("result")),
+                [
+                    Row(id=1, result=1.0),
+                    Row(id=1, result=3.0),
+                    Row(id=2, result=3.0),
+                    Row(id=2, result=8.0),
+                    Row(id=2, result=18.0),
+                ],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"window arrow udf: {lst}",
+                    context={"func_name": my_window_arrow_udf.__name__},
+                    logger="test_window_arrow",
+                )
+                for lst in [[1.0], [1.0, 2.0], [3.0], [3.0, 5.0], [3.0, 5.0, 
10.0]]
+            ],
+        )
+
 
 class WindowArrowUDFTests(WindowArrowUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index 44bd8a6fa9df..ab954dd133f3 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -16,6 +16,7 @@
 #
 
 import unittest
+import logging
 from typing import cast
 
 from pyspark.sql import functions as sf
@@ -38,6 +39,8 @@ from pyspark.testing.sqlutils import (
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
 
 if have_pandas:
     import pandas as pd
@@ -714,6 +717,48 @@ class CogroupedApplyInPandasTestsMixin:
             with 
self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
                 CogroupedApplyInPandasTestsMixin.test_with_key_right(self)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_cogroup_apply_in_pandas_with_logging(self):
+        import pandas as pd
+
+        def func_with_logging(left_pdf, right_pdf):
+            assert isinstance(left_pdf, pd.DataFrame)
+            assert isinstance(right_pdf, pd.DataFrame)
+            logger = logging.getLogger("test_pandas_cogrouped_map")
+            logger.warning(
+                f"pandas cogrouped map: {dict(v1=list(left_pdf['v1']), 
v2=list(right_pdf['v2']))}"
+            )
+            return pd.merge(left_pdf, right_pdf, on=["id"])
+
+        left_df = self.spark.createDataFrame([(1, 10), (2, 20), (1, 30)], 
["id", "v1"])
+        right_df = self.spark.createDataFrame([(1, 100), (2, 200), (1, 300)], 
["id", "v2"])
+
+        grouped_left = left_df.groupBy("id")
+        grouped_right = right_df.groupBy("id")
+        cogrouped_df = grouped_left.cogroup(grouped_right)
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": 
"true"}):
+            assertDataFrameEqual(
+                cogrouped_df.applyInPandas(func_with_logging, "id long, v1 
long, v2 long"),
+                [Row(id=1, v1=v1, v2=v2) for v1 in [10, 30] for v2 in [100, 
300]]
+                + [Row(id=2, v1=20, v2=200)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"pandas cogrouped map: {dict(v1=v1, v2=v2)}",
+                    context={"func_name": func_with_logging.__name__},
+                    logger="test_pandas_cogrouped_map",
+                )
+                for v1, v2 in [([10, 30], [100, 300]), ([20], [200])]
+            ],
+        )
+
 
 class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, 
ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 4c52303481fa..0e922d072871 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -17,6 +17,7 @@
 
 import datetime
 import unittest
+import logging
 
 from collections import OrderedDict
 from decimal import Decimal
@@ -60,6 +61,8 @@ from pyspark.testing.sqlutils import (
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
 
 if have_pandas:
     import pandas as pd
@@ -985,6 +988,42 @@ class ApplyInPandasTestsMixin:
             with 
self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
                 ApplyInPandasTestsMixin.test_complex_groupby(self)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_apply_in_pandas_with_logging(self):
+        import pandas as pd
+
+        def func_with_logging(pdf):
+            assert isinstance(pdf, pd.DataFrame)
+            logger = logging.getLogger("test_pandas_grouped_map")
+            logger.warning(
+                f"pandas grouped map: {dict(id=list(pdf['id']), 
value=list(pdf['value']))}"
+            )
+            return pdf
+
+        df = self.spark.range(9).withColumn("value", col("id") * 10)
+        grouped_df = df.groupBy((col("id") % 2).cast("int"))
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": 
"true"}):
+            assertDataFrameEqual(
+                grouped_df.applyInPandas(func_with_logging, "id long, value 
long"),
+                df,
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"pandas grouped map: {dict(id=lst, value=[v*10 for v 
in lst])}",
+                    context={"func_name": func_with_logging.__name__},
+                    logger="test_pandas_grouped_map",
+                )
+                for lst in [[0, 2, 4, 6, 8], [1, 3, 5, 7]]
+            ],
+        )
+
 
 class ApplyInPandasTests(ApplyInPandasTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index b241b91e02a2..5e0e33a05b22 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -19,6 +19,7 @@ import shutil
 import tempfile
 import time
 import unittest
+import logging
 from typing import cast
 
 from pyspark.sql import Row
@@ -33,7 +34,8 @@ from pyspark.testing.sqlutils import (
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
-from pyspark.testing.utils import eventually
+from pyspark.testing.utils import assertDataFrameEqual, eventually
+from pyspark.util import is_remote_only
 
 if have_pandas:
     import pandas as pd
@@ -486,6 +488,43 @@ class MapInPandasTestsMixin:
         df = self.spark.range(1)
         self.assertEqual([Row(a=2, b=1)], df.mapInPandas(func, "a int, b 
int").collect())
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_map_in_pandas_with_logging(self):
+        import pandas as pd
+
+        def func_with_logging(iterator):
+            logger = logging.getLogger("test_pandas_map")
+            for pdf in iterator:
+                assert isinstance(pdf, pd.DataFrame)
+                logger.warning(f"pandas map: {list(pdf['id'])}")
+                yield pdf
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.maxRecordsPerBatch": "3",
+                "spark.sql.pyspark.worker.logging.enabled": "true",
+            }
+        ):
+            assertDataFrameEqual(
+                self.spark.range(9, 
numPartitions=2).mapInPandas(func_with_logging, "id long"),
+                [Row(id=i) for i in range(9)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"pandas map: {lst}",
+                    context={"func_name": func_with_logging.__name__},
+                    logger="test_pandas_map",
+                )
+                for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+            ],
+        )
+
 
 class MapInPandasTests(ReusedSQLTestCase, MapInPandasTestsMixin):
     @classmethod
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index 3fd970061b30..2b3e42312df9 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -16,9 +16,10 @@
 #
 
 import unittest
+import logging
 from typing import cast
 
-from pyspark.util import PythonEvalType
+from pyspark.util import PythonEvalType, is_remote_only
 from pyspark.sql import Row, functions as sf
 from pyspark.sql.functions import (
     array,
@@ -826,6 +827,40 @@ class GroupedAggPandasUDFTestsMixin:
 
                     self.assertEqual(expected, result)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_grouped_agg_pandas_udf_with_logging(self):
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def my_grouped_agg_pandas_udf(x):
+            assert isinstance(x, pd.Series)
+            logger = logging.getLogger("test_grouped_agg_pandas")
+            logger.warning(f"grouped agg pandas udf: {len(x)}")
+            return x.sum()
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": 
"true"}):
+            assertDataFrameEqual(
+                
df.groupby("id").agg(my_grouped_agg_pandas_udf("v").alias("result")),
+                [Row(id=1, result=3.0), Row(id=2, result=18.0)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"grouped agg pandas udf: {n}",
+                    context={"func_name": my_grouped_agg_pandas_udf.__name__},
+                    logger="test_grouped_agg_pandas",
+                )
+                for n in [2, 3]
+            ],
+        )
+
 
 class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, 
ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index 3c2ae56067ae..fbfe1a226b5e 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -20,13 +20,14 @@ import shutil
 import tempfile
 import time
 import unittest
+import logging
 from datetime import date, datetime
 from decimal import Decimal
 from typing import cast
 
 from pyspark import TaskContext
-from pyspark.util import PythonEvalType
-from pyspark.sql import Column
+from pyspark.util import PythonEvalType, is_remote_only
+from pyspark.sql import Column, Row
 from pyspark.sql.functions import (
     array,
     col,
@@ -1917,6 +1918,76 @@ class ScalarPandasUDFTestsMixin:
                 row = df.select(pandas_udf(lambda _: pd.Series(["123"]), 
t)(df.id)).first()
                 assert row[0] == 123
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_scalar_pandas_udf_with_logging(self):
+        @pandas_udf("string", PandasUDFType.SCALAR)
+        def my_scalar_pandas_udf(x):
+            assert isinstance(x, pd.Series)
+            logger = logging.getLogger("test_scalar_pandas")
+            logger.warning(f"scalar pandas udf: {list(x)}")
+            return pd.Series(["scalar_pandas_" + str(val) for val in x])
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": 
"true"}):
+            assertDataFrameEqual(
+                self.spark.range(3, numPartitions=2).select(
+                    my_scalar_pandas_udf("id").alias("result")
+                ),
+                [Row(result=f"scalar_pandas_{i}") for i in range(3)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"scalar pandas udf: {lst}",
+                    context={"func_name": my_scalar_pandas_udf.__name__},
+                    logger="test_scalar_pandas",
+                )
+                for lst in [[0], [1, 2]]
+            ],
+        )
+
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_scalar_iter_pandas_udf_with_logging(self):
+        @pandas_udf("string", PandasUDFType.SCALAR_ITER)
+        def my_scalar_iter_pandas_udf(it):
+            logger = logging.getLogger("test_scalar_iter_pandas")
+            for x in it:
+                assert isinstance(x, pd.Series)
+                logger.warning(f"scalar iter pandas udf: {list(x)}")
+                yield pd.Series(["scalar_iter_pandas_" + str(val) for val in 
x])
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.maxRecordsPerBatch": "3",
+                "spark.sql.pyspark.worker.logging.enabled": "true",
+            }
+        ):
+            assertDataFrameEqual(
+                self.spark.range(9, numPartitions=2).select(
+                    my_scalar_iter_pandas_udf("id").alias("result")
+                ),
+                [Row(result=f"scalar_iter_pandas_{i}") for i in range(9)],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"scalar iter pandas udf: {lst}",
+                    context={"func_name": my_scalar_iter_pandas_udf.__name__},
+                    logger="test_scalar_iter_pandas",
+                )
+                for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+            ],
+        )
+
 
 class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 547e237902b3..6fa7e9063836 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -16,6 +16,7 @@
 #
 
 import unittest
+import logging
 from typing import cast
 from decimal import Decimal
 
@@ -38,6 +39,8 @@ from pyspark.testing.sqlutils import (
     pyarrow_requirement_message,
 )
 from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.sql import Row
+from pyspark.util import is_remote_only
 
 if have_pandas:
     from pandas.testing import assert_frame_equal
@@ -633,6 +636,49 @@ class WindowPandasUDFTestsMixin:
                     )
                     self.assertEqual(expected2, result2)
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_window_pandas_udf_with_logging(self):
+        import pandas as pd
+
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def my_window_pandas_udf(x):
+            assert isinstance(x, pd.Series)
+            logger = logging.getLogger("test_window_pandas")
+            logger.warning(f"window pandas udf: {list(x)}")
+            return x.sum()
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+        w = 
Window.partitionBy("id").orderBy("v").rangeBetween(Window.unboundedPreceding, 0)
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": 
"true"}):
+            assertDataFrameEqual(
+                df.select("id", 
my_window_pandas_udf("v").over(w).alias("result")),
+                [
+                    Row(id=1, result=1.0),
+                    Row(id=1, result=3.0),
+                    Row(id=2, result=3.0),
+                    Row(id=2, result=8.0),
+                    Row(id=2, result=18.0),
+                ],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=f"window pandas udf: {lst}",
+                    context={"func_name": my_window_pandas_udf.__name__},
+                    logger="test_window_pandas",
+                )
+                for lst in [[1.0], [1.0, 2.0], [3.0], [3.0, 5.0], [3.0, 5.0, 
10.0]]
+            ],
+        )
+
 
 class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 26bd5368e6f9..63e7e32c1c7b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -170,7 +170,8 @@ case class UserDefinedPythonDataSource(dataSourceCls: 
PythonFunction) {
       conf.arrowUseLargeVarTypes,
       pythonRunnerConf,
       metrics,
-      jobArtifactUUID)
+      jobArtifactUUID,
+      None) // TODO: Python worker logging
   }
 
   def createPythonMetrics(): Array[CustomMetric] = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
index f4e8831f23b8..a3d6c57c58bd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
@@ -144,6 +144,12 @@ case class ArrowAggregatePythonExec(
 
 
     val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+    val sessionUUID = {
+      // Forward the session UUID to the Python runner only when Python worker logging
+      // is enabled for this session; otherwise None. Option(...) guards a null `session`.
+      Option(session)
+        .filter(_.sessionState.conf.pythonWorkerLoggingEnabled)
+        .map(_.sessionUUID)
+    }
 
     // Map grouped rows to ArrowPythonRunner results, Only execute if 
partition is not empty
     inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
@@ -190,6 +196,7 @@ case class ArrowAggregatePythonExec(
         pythonRunnerConf,
         pythonMetrics,
         jobArtifactUUID,
+        sessionUUID,
         conf.pythonUDFProfiler) with GroupedPythonArrowInput
 
       val columnarBatchIter = runner.compute(projectedRowIter, 
context.partitionId(), context)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 92236ca42b2d..7498815cda4e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -82,6 +82,12 @@ case class ArrowEvalPythonExec(
   }
 
   private[this] val jobArtifactUUID = 
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+  private[this] val sessionUUID = {
+    // Forward the session UUID to the Python runner only when Python worker logging
+    // is enabled for this session; otherwise None. Option(...) guards a null `session`.
+    Option(session)
+      .filter(_.sessionState.conf.pythonWorkerLoggingEnabled)
+      .map(_.sessionUUID)
+  }
 
   override protected def evaluatorFactory: EvalPythonEvaluatorFactory = {
     new ArrowEvalPythonEvaluatorFactory(
@@ -95,6 +101,7 @@ case class ArrowEvalPythonExec(
       ArrowPythonRunner.getPythonRunnerConfMap(conf),
       pythonMetrics,
       jobArtifactUUID,
+      sessionUUID,
       conf.pythonUDFProfiler)
   }
 
@@ -121,6 +128,7 @@ class ArrowEvalPythonEvaluatorFactory(
     pythonRunnerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
+    sessionUUID: Option[String],
     profiler: Option[String])
   extends EvalPythonEvaluatorFactory(childOutput, udfs, output) {
 
@@ -147,6 +155,7 @@ class ArrowEvalPythonEvaluatorFactory(
       pythonRunnerConf,
       pythonMetrics,
       jobArtifactUUID,
+      sessionUUID,
       profiler) with BatchedPythonArrowInput
     val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), 
context)
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 77aec2a35f21..b94e00bc11ef 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import java.io.DataOutputStream
+import java.util
 
 import org.apache.spark.api.python._
 import org.apache.spark.sql.catalyst.InternalRow
@@ -36,12 +37,20 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef](
     protected override val largeVarTypes: Boolean,
     protected override val workerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    sessionUUID: Option[String])
   extends BasePythonRunner[IN, OUT](
     funcs.map(_._1), evalType, argOffsets, jobArtifactUUID, pythonMetrics)
   with PythonArrowInput[IN]
   with PythonArrowOutput[OUT] {
 
+  // Worker environment: a copy of the first UDF's env vars (the original map is left
+  // untouched), plus PYSPARK_SPARK_SESSION_UUID whenever a session UUID was supplied.
+  override val envVars: util.Map[String, String] = {
+    val workerEnv = new util.HashMap[String, String](funcs.head._1.funcs.head.envVars)
+    sessionUUID.foreach(uuid => workerEnv.put("PYSPARK_SPARK_SESSION_UUID", uuid))
+    workerEnv
+  }
   override val pythonExec: String =
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
       funcs.head._1.funcs.head.pythonExec)
@@ -77,10 +86,11 @@ abstract class RowInputArrowPythonRunner(
     largeVarTypes: Boolean,
     workerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    sessionUUID: Option[String])
   extends BaseArrowPythonRunner[Iterator[InternalRow], ColumnarBatch](
     funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, 
workerConf,
-    pythonMetrics, jobArtifactUUID)
+    pythonMetrics, jobArtifactUUID, sessionUUID)
   with BasicPythonArrowInput
   with BasicPythonArrowOutput
 
@@ -97,10 +107,11 @@ class ArrowPythonRunner(
     workerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
+    sessionUUID: Option[String],
     profiler: Option[String])
   extends RowInputArrowPythonRunner(
     funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, 
workerConf,
-    pythonMetrics, jobArtifactUUID) {
+    pythonMetrics, jobArtifactUUID, sessionUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit =
     PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
@@ -120,10 +131,11 @@ class ArrowPythonWithNamedArgumentRunner(
     workerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
+    sessionUUID: Option[String],
     profiler: Option[String])
   extends RowInputArrowPythonRunner(
     funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, 
largeVarTypes, workerConf,
-    pythonMetrics, jobArtifactUUID) {
+    pythonMetrics, jobArtifactUUID, sessionUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
     if (evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
index 82c03b1d0229..2bf974d9026f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
@@ -45,6 +45,7 @@ class ArrowWindowPythonEvaluatorFactory(
     val evalType: Int,
     val spillSize: SQLMetric,
     pythonMetrics: Map[String, SQLMetric],
+    sessionUUID: Option[String],
     profiler: Option[String])
   extends PartitionEvaluatorFactory[InternalRow, InternalRow] with 
WindowEvaluatorFactoryBase {
 
@@ -378,6 +379,7 @@ class ArrowWindowPythonEvaluatorFactory(
         pythonRunnerConf,
         pythonMetrics,
         jobArtifactUUID,
+        sessionUUID,
         profiler) with GroupedPythonArrowInput
 
       val windowFunctionResult = runner.compute(pythonInput, 
context.partitionId(), context)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonExec.scala
index c8259c10dbd9..ba3ffe7639eb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonExec.scala
@@ -91,6 +91,13 @@ case class ArrowWindowPythonExec(
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")
   )
 
+  private[this] val sessionUUID = {
+    // Forward the session UUID to the Python runner only when Python worker logging
+    // is enabled for this session; otherwise None. Option(...) guards a null `session`.
+    Option(session)
+      .filter(_.sessionState.conf.pythonWorkerLoggingEnabled)
+      .map(_.sessionUUID)
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val evaluatorFactory =
       new ArrowWindowPythonEvaluatorFactory(
@@ -101,6 +108,7 @@ case class ArrowWindowPythonExec(
         evalType,
         longMetric("spillSize"),
         pythonMetrics,
+        sessionUUID,
         conf.pythonUDFProfiler)
 
     // Start processing.
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 9dbdd285338e..00eb9039d05c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import java.io.DataOutputStream
+import java.util
 
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, 
PythonRDD, PythonWorker}
@@ -43,12 +44,20 @@ class CoGroupedArrowPythonRunner(
     conf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
+    sessionUUID: Option[String],
     profiler: Option[String])
   extends BasePythonRunner[
     (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](
     funcs.map(_._1), evalType, argOffsets, jobArtifactUUID, pythonMetrics)
   with BasicPythonArrowOutput {
 
+  // Worker environment: a copy of the first UDF's env vars (the original map is left
+  // untouched), plus PYSPARK_SPARK_SESSION_UUID whenever a session UUID was supplied.
+  override val envVars: util.Map[String, String] = {
+    val workerEnv = new util.HashMap[String, String](funcs.head._1.funcs.head.envVars)
+    sessionUUID.foreach(uuid => workerEnv.put("PYSPARK_SPARK_SESSION_UUID", uuid))
+    workerEnv
+  }
   override val pythonExec: String =
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
       funcs.head._1.funcs.head.pythonExec)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
index af487218391e..38427866458e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
@@ -68,6 +68,12 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with 
BinaryExecNode with Pyth
     val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup)
     val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, 
rightGroup)
     val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+    val sessionUUID = {
+      // Forward the session UUID to the Python runner only when Python worker logging
+      // is enabled for this session; otherwise None. Option(...) guards a null `session`.
+      Option(session)
+        .filter(_.sessionState.conf.pythonWorkerLoggingEnabled)
+        .map(_.sessionUUID)
+    }
 
     // Map cogrouped rows to ArrowPythonRunner results, Only execute if 
partition is not empty
     left.execute().zipPartitions(right.execute())  { (leftData, rightData) =>
@@ -89,6 +95,7 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with 
BinaryExecNode with Pyth
           pythonRunnerConf,
           pythonMetrics,
           jobArtifactUUID,
+          sessionUUID,
           conf.pythonUDFProfiler)
 
         executePython(data, output, runner)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
index 57a50a8fc857..7d221552226d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
@@ -48,6 +48,12 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with 
UnaryExecNode with PythonS
   private val chainedFunc =
     Seq((ChainedPythonFunctions(Seq(pythonFunction)), pythonUDF.resultId.id))
   private[this] val jobArtifactUUID = 
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+  private[this] val sessionUUID = {
+    // Forward the session UUID to the Python runner only when Python worker logging
+    // is enabled for this session; otherwise None. Option(...) guards a null `session`.
+    Option(session)
+      .filter(_.sessionState.conf.pythonWorkerLoggingEnabled)
+      .map(_.sessionUUID)
+  }
 
   override def producedAttributes: AttributeSet = AttributeSet(output)
 
@@ -92,6 +98,7 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with 
UnaryExecNode with PythonS
         pythonRunnerConf,
         pythonMetrics,
         jobArtifactUUID,
+        sessionUUID,
         conf.pythonUDFProfiler) with GroupedPythonArrowInput
 
       executePython(data, output, runner)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 9e3e8610ed37..4e78b3035a7e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -40,7 +40,8 @@ class MapInBatchEvaluatorFactory(
     largeVarTypes: Boolean,
     pythonRunnerConf: Map[String, String],
     val pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    sessionUUID: Option[String])
   extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
 
   override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] 
=
@@ -72,6 +73,7 @@ class MapInBatchEvaluatorFactory(
         pythonRunnerConf,
         pythonMetrics,
         jobArtifactUUID,
+        sessionUUID,
         None) with BatchedPythonArrowInput
       val columnarBatchIter = pyRunner.compute(batchIter, 
context.partitionId(), context)
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index c003d503c7ca..1d03c0cf7603 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -44,6 +44,12 @@ trait MapInBatchExec extends UnaryExecNode with 
PythonSQLMetrics {
   override def producedAttributes: AttributeSet = AttributeSet(output)
 
   private[this] val jobArtifactUUID = 
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+  private[this] val sessionUUID = {
+    // Forward the session UUID to the Python runner only when Python worker logging
+    // is enabled for this session; otherwise None. Option(...) guards a null `session`.
+    Option(session)
+      .filter(_.sessionState.conf.pythonWorkerLoggingEnabled)
+      .map(_.sessionUUID)
+  }
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
@@ -63,7 +69,8 @@ trait MapInBatchExec extends UnaryExecNode with 
PythonSQLMetrics {
       conf.arrowUseLargeVarTypes,
       pythonRunnerConf,
       pythonMetrics,
-      jobArtifactUUID)
+      jobArtifactUUID,
+      sessionUUID)
 
     val rdd = if (isBarrier) {
       val rddBarrier = child.execute().barrier()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to