This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f37cd07534e4 [SPARK-53976][PYTHON] Support logging in Pandas/Arrow UDFs
f37cd07534e4 is described below
commit f37cd07534e4eab267ffcdd8c40923b8a175c789
Author: Takuya Ueshin <[email protected]>
AuthorDate: Thu Oct 30 10:30:23 2025 -0700
[SPARK-53976][PYTHON] Support logging in Pandas/Arrow UDFs
### What changes were proposed in this pull request?
Supports logging in Pandas/Arrow UDFs.
### Why are the changes needed?
The basic logging infrastructure was introduced in
https://github.com/apache/spark/pull/52689, and other UDF types should also
support logging.
This PR adds that support for Pandas and Arrow UDFs.
### Does this PR introduce _any_ user-facing change?
Yes, the logging feature will be available in Pandas/Arrow UDFs.
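For example (a minimal, illustrative sketch based on the tests added below; the `greet` UDF and column names are hypothetical, the config key and log table follow this PR; assumes an active SparkSession bound to `spark`):
```python
import logging

import pandas as pd
from pyspark.sql.functions import pandas_udf

# Enable capturing of Python worker logs for this session.
spark.conf.set("spark.sql.pyspark.worker.logging.enabled", "true")

@pandas_udf("string")
def greet(s: pd.Series) -> pd.Series:
    # Records logged inside the UDF are captured by the Python worker
    # and surfaced on the driver side.
    logging.getLogger("my_udf").warning(f"processing {len(s)} rows")
    return "hello_" + s.astype(str)

spark.range(3).select(greet("id").alias("greeting")).show()

# The captured records are queryable as a table in the current session.
spark.table("system.session.python_worker_logs") \
    .select("level", "msg", "logger").show(truncate=False)
```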
### How was this patch tested?
Added the related tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52785 from ueshin/issues/SPARK-53976/pandas_arrow_udfs.
Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../sql/tests/arrow/test_arrow_cogrouped_map.py | 46 +++++++++++++
.../sql/tests/arrow/test_arrow_grouped_map.py | 77 +++++++++++++++++++++
python/pyspark/sql/tests/arrow/test_arrow_map.py | 55 +++++++++++++++
.../sql/tests/arrow/test_arrow_python_udf.py | 12 ----
.../sql/tests/arrow/test_arrow_udf_grouped_agg.py | 40 ++++++++++-
.../sql/tests/arrow/test_arrow_udf_scalar.py | 79 +++++++++++++++++++++-
.../sql/tests/arrow/test_arrow_udf_window.py | 49 +++++++++++++-
.../sql/tests/pandas/test_pandas_cogrouped_map.py | 45 ++++++++++++
.../sql/tests/pandas/test_pandas_grouped_map.py | 39 +++++++++++
python/pyspark/sql/tests/pandas/test_pandas_map.py | 41 ++++++++++-
.../tests/pandas/test_pandas_udf_grouped_agg.py | 37 +++++++++-
.../sql/tests/pandas/test_pandas_udf_scalar.py | 75 +++++++++++++++++++-
.../sql/tests/pandas/test_pandas_udf_window.py | 46 +++++++++++++
.../v2/python/UserDefinedPythonDataSource.scala | 3 +-
.../python/ArrowAggregatePythonExec.scala | 7 ++
.../sql/execution/python/ArrowEvalPythonExec.scala | 9 +++
.../sql/execution/python/ArrowPythonRunner.scala | 22 ++++--
.../python/ArrowWindowPythonEvaluatorFactory.scala | 2 +
.../execution/python/ArrowWindowPythonExec.scala | 8 +++
.../python/CoGroupedArrowPythonRunner.scala | 9 +++
.../python/FlatMapCoGroupsInBatchExec.scala | 7 ++
.../python/FlatMapGroupsInBatchExec.scala | 7 ++
.../python/MapInBatchEvaluatorFactory.scala | 4 +-
.../sql/execution/python/MapInBatchExec.scala | 9 ++-
24 files changed, 700 insertions(+), 28 deletions(-)
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
index 13edcec6b57f..2bdd7bda3bc2 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
@@ -17,6 +17,7 @@
import os
import time
import unittest
+import logging
from pyspark.errors import PythonException
from pyspark.sql import Row
@@ -26,6 +27,8 @@ from pyspark.testing.sqlutils import (
have_pyarrow,
pyarrow_requirement_message,
)
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
if have_pyarrow:
import pyarrow as pa
@@ -367,6 +370,49 @@ class CogroupedMapInArrowTestsMixin:
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
CogroupedMapInArrowTestsMixin.test_apply_in_arrow(self)
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_cogroup_apply_in_arrow_with_logging(self):
+ import pyarrow as pa
+
+ def func_with_logging(left, right):
+ assert isinstance(left, pa.Table)
+ assert isinstance(right, pa.Table)
+ logger = logging.getLogger("test_arrow_cogrouped_map")
+ logger.warning(
+ "arrow cogrouped map: "
+ + f"{dict(v1=left['v1'].to_pylist(),
v2=right['v2'].to_pylist())}"
+ )
+ return left.join(right, keys="id", join_type="inner")
+
+ left_df = self.spark.createDataFrame([(1, 10), (2, 20), (1, 30)], ["id", "v1"])
+ right_df = self.spark.createDataFrame([(1, 100), (2, 200), (1, 300)], ["id", "v2"])
+
+ grouped_left = left_df.groupBy("id")
+ grouped_right = right_df.groupBy("id")
+ cogrouped_df = grouped_left.cogroup(grouped_right)
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ assertDataFrameEqual(
+ cogrouped_df.applyInArrow(func_with_logging, "id long, v1 long, v2 long"),
+ [Row(id=1, v1=v1, v2=v2) for v1 in [10, 30] for v2 in [100, 300]]
+ + [Row(id=2, v1=20, v2=200)],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"arrow cogrouped map: {dict(v1=v1, v2=v2)}",
+ context={"func_name": func_with_logging.__name__},
+ logger="test_arrow_cogrouped_map",
+ )
+ for v1, v2 in [([10, 30], [100, 300]), ([20], [200])]
+ ],
+ )
+
class CogroupedMapInArrowTests(CogroupedMapInArrowTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
index 8d3d929096b1..829c38385bd0 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
@@ -17,6 +17,7 @@
import inspect
import os
import time
+import logging
from typing import Iterator, Tuple
import unittest
@@ -29,6 +30,8 @@ from pyspark.testing.sqlutils import (
have_pyarrow,
pyarrow_requirement_message,
)
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
if have_pyarrow:
import pyarrow as pa
@@ -394,6 +397,80 @@ class ApplyInArrowTestsMixin:
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
ApplyInArrowTestsMixin.test_apply_in_arrow(self)
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_apply_in_arrow_with_logging(self):
+ import pyarrow as pa
+
+ def func_with_logging(group):
+ assert isinstance(group, pa.Table)
+ logger = logging.getLogger("test_arrow_grouped_map")
+ logger.warning(f"arrow grouped map: {group.to_pydict()}")
+ return group
+
+ df = self.spark.range(9).withColumn("value", col("id") * 10)
+ grouped_df = df.groupBy((col("id") % 2).cast("int"))
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ assertDataFrameEqual(
+ grouped_df.applyInArrow(func_with_logging, "id long, value long"),
+ df,
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v
in lst])}",
+ context={"func_name": func_with_logging.__name__},
+ logger="test_arrow_grouped_map",
+ )
+ for lst in [[0, 2, 4, 6, 8], [1, 3, 5, 7]]
+ ],
+ )
+
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_apply_in_arrow_iter_with_logging(self):
+ import pyarrow as pa
+
+ def func_with_logging(group: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+ logger = logging.getLogger("test_arrow_grouped_map")
+ for batch in group:
+ assert isinstance(batch, pa.RecordBatch)
+ logger.warning(f"arrow grouped map: {batch.to_pydict()}")
+ yield batch
+
+ df = self.spark.range(9).withColumn("value", col("id") * 10)
+ grouped_df = df.groupBy((col("id") % 2).cast("int"))
+
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.maxRecordsPerBatch": 3,
+ "spark.sql.pyspark.worker.logging.enabled": "true",
+ }
+ ):
+ assertDataFrameEqual(
+ grouped_df.applyInArrow(func_with_logging, "id long, value long"),
+ df,
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v
in lst])}",
+ context={"func_name": func_with_logging.__name__},
+ logger="test_arrow_grouped_map",
+ )
+ for lst in [[0, 2, 4], [6, 8], [1, 3, 5], [7]]
+ ],
+ )
+
class ApplyInArrowTests(ApplyInArrowTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_map.py b/python/pyspark/sql/tests/arrow/test_arrow_map.py
index 0f9f5b422440..4a56a32fbcdd 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_map.py
@@ -17,6 +17,7 @@
import os
import time
import unittest
+import logging
from pyspark.sql.utils import PythonException
from pyspark.testing.sqlutils import (
@@ -26,6 +27,9 @@ from pyspark.testing.sqlutils import (
pandas_requirement_message,
pyarrow_requirement_message,
)
+from pyspark.sql import Row
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
if have_pyarrow:
import pyarrow as pa
@@ -221,6 +225,46 @@ class MapInArrowTestsMixin(object):
df = self.spark.range(1)
df.mapInArrow(func, "a int").collect()
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_map_in_arrow_with_logging(self):
+ import pyarrow as pa
+
+ def func_with_logging(iterator):
+ logger = logging.getLogger("test_arrow_map")
+ for batch in iterator:
+ assert isinstance(batch, pa.RecordBatch)
+ logger.warning(f"arrow map: {batch.to_pydict()}")
+ yield batch
+
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.maxRecordsPerBatch": "3",
+ "spark.sql.pyspark.worker.logging.enabled": "true",
+ }
+ ):
+ assertDataFrameEqual(
+ self.spark.range(9, numPartitions=2).mapInArrow(func_with_logging, "id long"),
+ [Row(id=i) for i in range(9)],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ self._expected_logs_for_test_map_in_arrow_with_logging(func_with_logging.__name__),
+ )
+
+ def _expected_logs_for_test_map_in_arrow_with_logging(self, func_name):
+ return [
+ Row(
+ level="WARNING",
+ msg=f"arrow map: {dict(id=lst)}",
+ context={"func_name": func_name},
+ logger="test_arrow_map",
+ )
+ for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+ ]
+
class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase):
@classmethod
@@ -253,6 +297,17 @@ class MapInArrowWithArrowBatchSlicingTestsAndReducedBatchSizeTests(MapInArrowTes
cls.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "3")
cls.spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", "10")
+ def _expected_logs_for_test_map_in_arrow_with_logging(self, func_name):
+ return [
+ Row(
+ level="WARNING",
+ msg=f"arrow map: {dict(id=[i])}",
+ context={"func_name": func_name},
+ logger="test_arrow_map",
+ )
+ for i in range(9)
+ ]
+
class MapInArrowWithOutputArrowBatchSlicingRecordsTests(MapInArrowTests):
@classmethod
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
index 55b4edd72d5d..90e05caf2180 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -60,18 +60,6 @@ class ArrowPythonUDFTestsMixin(BaseUDFTestsMixin):
def test_register_java_udaf(self):
super(ArrowPythonUDFTests, self).test_register_java_udaf()
- @unittest.skip(
- "TODO(SPARK-53976): Python worker logging is not supported for Arrow
Python UDFs."
- )
- def test_udf_with_logging(self):
- super().test_udf_with_logging()
-
- @unittest.skip(
- "TODO(SPARK-53976): Python worker logging is not supported for Arrow
Python UDFs."
- )
- def test_multiple_udfs_with_logging(self):
- super().test_multiple_udfs_with_logging()
-
def test_complex_input_types(self):
row = (
self.spark.range(1)
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index 136a99e19411..f719b4fb16bd 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -16,9 +16,10 @@
#
import unittest
+import logging
from pyspark.sql.functions import arrow_udf, ArrowUDFType
-from pyspark.util import PythonEvalType
+from pyspark.util import PythonEvalType, is_remote_only
from pyspark.sql import Row
from pyspark.sql.types import (
ArrayType,
@@ -35,6 +36,7 @@ from pyspark.testing.utils import (
numpy_requirement_message,
have_pyarrow,
pyarrow_requirement_message,
+ assertDataFrameEqual,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -1021,6 +1023,42 @@ class GroupedAggArrowUDFTestsMixin:
self.assertEqual(expected, result)
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_grouped_agg_arrow_udf_with_logging(self):
+ import pyarrow as pa
+
+ @arrow_udf("double", ArrowUDFType.GROUPED_AGG)
+ def my_grouped_agg_arrow_udf(x):
+ assert isinstance(x, pa.Array)
+ logger = logging.getLogger("test_grouped_agg_arrow")
+ logger.warning(f"grouped agg arrow udf: {len(x)}")
+ return pa.compute.sum(x)
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ assertDataFrameEqual(
+ df.groupby("id").agg(my_grouped_agg_arrow_udf("v").alias("result")),
+ [Row(id=1, result=3.0), Row(id=2, result=18.0)],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"grouped agg arrow udf: {n}",
+ context={"func_name": my_grouped_agg_arrow_udf.__name__},
+ logger="test_grouped_agg_arrow",
+ )
+ for n in [2, 3]
+ ],
+ )
+
class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
index a682c6515ef6..05f33a4ae42f 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
@@ -20,13 +20,14 @@ import random
import time
import unittest
import datetime
+import logging
from decimal import Decimal
from typing import Iterator, Tuple
from pyspark.util import PythonEvalType
from pyspark.sql.functions import arrow_udf, ArrowUDFType
-from pyspark.sql import functions as F
+from pyspark.sql import Row, functions as F
from pyspark.sql.types import (
IntegerType,
ByteType,
@@ -51,8 +52,10 @@ from pyspark.testing.utils import (
numpy_requirement_message,
have_pyarrow,
pyarrow_requirement_message,
+ assertDataFrameEqual,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.util import is_remote_only
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@@ -1179,6 +1182,80 @@ class ScalarArrowUDFTestsMixin:
def func_a(a: pa.Array) -> pa.Array:
return a
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_scalar_arrow_udf_with_logging(self):
+ import pyarrow as pa
+
+ @arrow_udf("string")
+ def my_scalar_arrow_udf(x):
+ assert isinstance(x, pa.Array)
+ logger = logging.getLogger("test_scalar_arrow")
+ logger.warning(f"scalar arrow udf: {x.to_pylist()}")
+ return pa.array(["scalar_arrow_" + str(val.as_py()) for val in x])
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ assertDataFrameEqual(
+ self.spark.range(3, numPartitions=2).select(
+ my_scalar_arrow_udf("id").alias("result")
+ ),
+ [Row(result=f"scalar_arrow_{i}") for i in range(3)],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"scalar arrow udf: {lst}",
+ context={"func_name": my_scalar_arrow_udf.__name__},
+ logger="test_scalar_arrow",
+ )
+ for lst in [[0], [1, 2]]
+ ],
+ )
+
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_scalar_iter_arrow_udf_with_logging(self):
+ import pyarrow as pa
+
+ @arrow_udf("string", ArrowUDFType.SCALAR_ITER)
+ def my_scalar_iter_arrow_udf(it):
+ logger = logging.getLogger("test_scalar_iter_arrow")
+ for x in it:
+ assert isinstance(x, pa.Array)
+ logger.warning(f"scalar iter arrow udf: {x.to_pylist()}")
+ yield pa.array(["scalar_iter_arrow_" + str(val.as_py()) for val in x])
+
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.maxRecordsPerBatch": "3",
+ "spark.sql.pyspark.worker.logging.enabled": "true",
+ }
+ ):
+ assertDataFrameEqual(
+ self.spark.range(9, numPartitions=2).select(
+ my_scalar_iter_arrow_udf("id").alias("result")
+ ),
+ [Row(result=f"scalar_iter_arrow_{i}") for i in range(9)],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"scalar iter arrow udf: {lst}",
+ context={"func_name": my_scalar_iter_arrow_udf.__name__},
+ logger="test_scalar_iter_arrow",
+ )
+ for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+ ],
+ )
+
class ScalarArrowUDFTests(ScalarArrowUDFTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index d67b99475bf8..240e34487b00 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -16,10 +16,11 @@
#
import unittest
+import logging
from pyspark.sql.functions import arrow_udf, ArrowUDFType
-from pyspark.util import PythonEvalType
-from pyspark.sql import functions as sf
+from pyspark.util import PythonEvalType, is_remote_only
+from pyspark.sql import Row, functions as sf
from pyspark.sql.window import Window
from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
from pyspark.testing.utils import (
@@ -27,6 +28,7 @@ from pyspark.testing.utils import (
numpy_requirement_message,
have_pyarrow,
pyarrow_requirement_message,
+ assertDataFrameEqual,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -804,6 +806,49 @@ class WindowArrowUDFTestsMixin:
)
self.assertEqual(expected2, result2)
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_window_arrow_udf_with_logging(self):
+ import pyarrow as pa
+
+ @arrow_udf("double", ArrowUDFType.GROUPED_AGG)
+ def my_window_arrow_udf(x):
+ assert isinstance(x, pa.Array)
+ logger = logging.getLogger("test_window_arrow")
+ logger.warning(f"window arrow udf: {x.to_pylist()}")
+ return pa.compute.sum(x)
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+ w = Window.partitionBy("id").orderBy("v").rangeBetween(Window.unboundedPreceding, 0)
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ assertDataFrameEqual(
+ df.select("id",
my_window_arrow_udf("v").over(w).alias("result")),
+ [
+ Row(id=1, result=1.0),
+ Row(id=1, result=3.0),
+ Row(id=2, result=3.0),
+ Row(id=2, result=8.0),
+ Row(id=2, result=18.0),
+ ],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"window arrow udf: {lst}",
+ context={"func_name": my_window_arrow_udf.__name__},
+ logger="test_window_arrow",
+ )
+ for lst in [[1.0], [1.0, 2.0], [3.0], [3.0, 5.0], [3.0, 5.0, 10.0]]
+ ],
+ )
+
class WindowArrowUDFTests(WindowArrowUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index 44bd8a6fa9df..ab954dd133f3 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -16,6 +16,7 @@
#
import unittest
+import logging
from typing import cast
from pyspark.sql import functions as sf
@@ -38,6 +39,8 @@ from pyspark.testing.sqlutils import (
pandas_requirement_message,
pyarrow_requirement_message,
)
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
if have_pandas:
import pandas as pd
@@ -714,6 +717,48 @@ class CogroupedApplyInPandasTestsMixin:
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
CogroupedApplyInPandasTestsMixin.test_with_key_right(self)
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_cogroup_apply_in_pandas_with_logging(self):
+ import pandas as pd
+
+ def func_with_logging(left_pdf, right_pdf):
+ assert isinstance(left_pdf, pd.DataFrame)
+ assert isinstance(right_pdf, pd.DataFrame)
+ logger = logging.getLogger("test_pandas_cogrouped_map")
+ logger.warning(
+ f"pandas cogrouped map: {dict(v1=list(left_pdf['v1']),
v2=list(right_pdf['v2']))}"
+ )
+ return pd.merge(left_pdf, right_pdf, on=["id"])
+
+ left_df = self.spark.createDataFrame([(1, 10), (2, 20), (1, 30)], ["id", "v1"])
+ right_df = self.spark.createDataFrame([(1, 100), (2, 200), (1, 300)], ["id", "v2"])
+
+ grouped_left = left_df.groupBy("id")
+ grouped_right = right_df.groupBy("id")
+ cogrouped_df = grouped_left.cogroup(grouped_right)
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ assertDataFrameEqual(
+ cogrouped_df.applyInPandas(func_with_logging, "id long, v1 long, v2 long"),
+ [Row(id=1, v1=v1, v2=v2) for v1 in [10, 30] for v2 in [100, 300]]
+ + [Row(id=2, v1=20, v2=200)],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"pandas cogrouped map: {dict(v1=v1, v2=v2)}",
+ context={"func_name": func_with_logging.__name__},
+ logger="test_pandas_cogrouped_map",
+ )
+ for v1, v2 in [([10, 30], [100, 300]), ([20], [200])]
+ ],
+ )
+
class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 4c52303481fa..0e922d072871 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -17,6 +17,7 @@
import datetime
import unittest
+import logging
from collections import OrderedDict
from decimal import Decimal
@@ -60,6 +61,8 @@ from pyspark.testing.sqlutils import (
pandas_requirement_message,
pyarrow_requirement_message,
)
+from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.util import is_remote_only
if have_pandas:
import pandas as pd
@@ -985,6 +988,42 @@ class ApplyInPandasTestsMixin:
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
ApplyInPandasTestsMixin.test_complex_groupby(self)
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_apply_in_pandas_with_logging(self):
+ import pandas as pd
+
+ def func_with_logging(pdf):
+ assert isinstance(pdf, pd.DataFrame)
+ logger = logging.getLogger("test_pandas_grouped_map")
+ logger.warning(
+ f"pandas grouped map: {dict(id=list(pdf['id']),
value=list(pdf['value']))}"
+ )
+ return pdf
+
+ df = self.spark.range(9).withColumn("value", col("id") * 10)
+ grouped_df = df.groupBy((col("id") % 2).cast("int"))
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ assertDataFrameEqual(
+ grouped_df.applyInPandas(func_with_logging, "id long, value long"),
+ df,
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"pandas grouped map: {dict(id=lst, value=[v*10 for v
in lst])}",
+ context={"func_name": func_with_logging.__name__},
+ logger="test_pandas_grouped_map",
+ )
+ for lst in [[0, 2, 4, 6, 8], [1, 3, 5, 7]]
+ ],
+ )
+
class ApplyInPandasTests(ApplyInPandasTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index b241b91e02a2..5e0e33a05b22 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -19,6 +19,7 @@ import shutil
import tempfile
import time
import unittest
+import logging
from typing import cast
from pyspark.sql import Row
@@ -33,7 +34,8 @@ from pyspark.testing.sqlutils import (
pandas_requirement_message,
pyarrow_requirement_message,
)
-from pyspark.testing.utils import eventually
+from pyspark.testing.utils import assertDataFrameEqual, eventually
+from pyspark.util import is_remote_only
if have_pandas:
import pandas as pd
@@ -486,6 +488,43 @@ class MapInPandasTestsMixin:
df = self.spark.range(1)
self.assertEqual([Row(a=2, b=1)], df.mapInPandas(func, "a int, b int").collect())
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_map_in_pandas_with_logging(self):
+ import pandas as pd
+
+ def func_with_logging(iterator):
+ logger = logging.getLogger("test_pandas_map")
+ for pdf in iterator:
+ assert isinstance(pdf, pd.DataFrame)
+ logger.warning(f"pandas map: {list(pdf['id'])}")
+ yield pdf
+
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.maxRecordsPerBatch": "3",
+ "spark.sql.pyspark.worker.logging.enabled": "true",
+ }
+ ):
+ assertDataFrameEqual(
+ self.spark.range(9, numPartitions=2).mapInPandas(func_with_logging, "id long"),
+ [Row(id=i) for i in range(9)],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"pandas map: {lst}",
+ context={"func_name": func_with_logging.__name__},
+ logger="test_pandas_map",
+ )
+ for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+ ],
+ )
+
class MapInPandasTests(ReusedSQLTestCase, MapInPandasTestsMixin):
@classmethod
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index 3fd970061b30..2b3e42312df9 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -16,9 +16,10 @@
#
import unittest
+import logging
from typing import cast
-from pyspark.util import PythonEvalType
+from pyspark.util import PythonEvalType, is_remote_only
from pyspark.sql import Row, functions as sf
from pyspark.sql.functions import (
array,
@@ -826,6 +827,40 @@ class GroupedAggPandasUDFTestsMixin:
self.assertEqual(expected, result)
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_grouped_agg_pandas_udf_with_logging(self):
+ @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+ def my_grouped_agg_pandas_udf(x):
+ assert isinstance(x, pd.Series)
+ logger = logging.getLogger("test_grouped_agg_pandas")
+ logger.warning(f"grouped agg pandas udf: {len(x)}")
+ return x.sum()
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ assertDataFrameEqual(
+ df.groupby("id").agg(my_grouped_agg_pandas_udf("v").alias("result")),
+ [Row(id=1, result=3.0), Row(id=2, result=18.0)],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"grouped agg pandas udf: {n}",
+ context={"func_name": my_grouped_agg_pandas_udf.__name__},
+ logger="test_grouped_agg_pandas",
+ )
+ for n in [2, 3]
+ ],
+ )
+
class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index 3c2ae56067ae..fbfe1a226b5e 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -20,13 +20,14 @@ import shutil
import tempfile
import time
import unittest
+import logging
from datetime import date, datetime
from decimal import Decimal
from typing import cast
from pyspark import TaskContext
-from pyspark.util import PythonEvalType
-from pyspark.sql import Column
+from pyspark.util import PythonEvalType, is_remote_only
+from pyspark.sql import Column, Row
from pyspark.sql.functions import (
array,
col,
@@ -1917,6 +1918,76 @@ class ScalarPandasUDFTestsMixin:
row = df.select(pandas_udf(lambda _: pd.Series(["123"]), t)(df.id)).first()
assert row[0] == 123
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_scalar_pandas_udf_with_logging(self):
+ @pandas_udf("string", PandasUDFType.SCALAR)
+ def my_scalar_pandas_udf(x):
+ assert isinstance(x, pd.Series)
+ logger = logging.getLogger("test_scalar_pandas")
+ logger.warning(f"scalar pandas udf: {list(x)}")
+ return pd.Series(["scalar_pandas_" + str(val) for val in x])
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ assertDataFrameEqual(
+ self.spark.range(3, numPartitions=2).select(
+ my_scalar_pandas_udf("id").alias("result")
+ ),
+ [Row(result=f"scalar_pandas_{i}") for i in range(3)],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"scalar pandas udf: {lst}",
+ context={"func_name": my_scalar_pandas_udf.__name__},
+ logger="test_scalar_pandas",
+ )
+ for lst in [[0], [1, 2]]
+ ],
+ )
+
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_scalar_iter_pandas_udf_with_logging(self):
+ @pandas_udf("string", PandasUDFType.SCALAR_ITER)
+ def my_scalar_iter_pandas_udf(it):
+ logger = logging.getLogger("test_scalar_iter_pandas")
+ for x in it:
+ assert isinstance(x, pd.Series)
+ logger.warning(f"scalar iter pandas udf: {list(x)}")
+ yield pd.Series(["scalar_iter_pandas_" + str(val) for val in x])
+
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.maxRecordsPerBatch": "3",
+ "spark.sql.pyspark.worker.logging.enabled": "true",
+ }
+ ):
+ assertDataFrameEqual(
+ self.spark.range(9, numPartitions=2).select(
+ my_scalar_iter_pandas_udf("id").alias("result")
+ ),
+ [Row(result=f"scalar_iter_pandas_{i}") for i in range(9)],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"scalar iter pandas udf: {lst}",
+ context={"func_name": my_scalar_iter_pandas_udf.__name__},
+ logger="test_scalar_iter_pandas",
+ )
+ for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+ ],
+ )
+
class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 547e237902b3..6fa7e9063836 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -16,6 +16,7 @@
#
import unittest
+import logging
from typing import cast
from decimal import Decimal
@@ -38,6 +39,8 @@ from pyspark.testing.sqlutils import (
pyarrow_requirement_message,
)
from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.sql import Row
+from pyspark.util import is_remote_only
if have_pandas:
from pandas.testing import assert_frame_equal
@@ -633,6 +636,49 @@ class WindowPandasUDFTestsMixin:
)
self.assertEqual(expected2, result2)
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_window_pandas_udf_with_logging(self):
+ import pandas as pd
+
+ @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+ def my_window_pandas_udf(x):
+ assert isinstance(x, pd.Series)
+ logger = logging.getLogger("test_window_pandas")
+ logger.warning(f"window pandas udf: {list(x)}")
+ return x.sum()
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+ w = Window.partitionBy("id").orderBy("v").rangeBetween(Window.unboundedPreceding, 0)
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ assertDataFrameEqual(
+ df.select("id",
my_window_pandas_udf("v").over(w).alias("result")),
+ [
+ Row(id=1, result=1.0),
+ Row(id=1, result=3.0),
+ Row(id=2, result=3.0),
+ Row(id=2, result=8.0),
+ Row(id=2, result=18.0),
+ ],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"window pandas udf: {lst}",
+ context={"func_name": my_window_pandas_udf.__name__},
+ logger="test_window_pandas",
+ )
+ for lst in [[1.0], [1.0, 2.0], [3.0], [3.0, 5.0], [3.0, 5.0, 10.0]]
+ ],
+ )
+
class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 26bd5368e6f9..63e7e32c1c7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -170,7 +170,8 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
conf.arrowUseLargeVarTypes,
pythonRunnerConf,
metrics,
- jobArtifactUUID)
+ jobArtifactUUID,
+ None) // TODO: Python worker logging
}
def createPythonMetrics(): Array[CustomMetric] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
index f4e8831f23b8..a3d6c57c58bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
@@ -144,6 +144,12 @@ case class ArrowAggregatePythonExec(
val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+ val sessionUUID = {
+ Option(session).collect {
+ case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+ session.sessionUUID
+ }
+ }
// Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
@@ -190,6 +196,7 @@ case class ArrowAggregatePythonExec(
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID,
+ sessionUUID,
conf.pythonUDFProfiler) with GroupedPythonArrowInput
val columnarBatchIter = runner.compute(projectedRowIter, context.partitionId(), context)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 92236ca42b2d..7498815cda4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -82,6 +82,12 @@ case class ArrowEvalPythonExec(
}
private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+ private[this] val sessionUUID = {
+ Option(session).collect {
+ case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+ session.sessionUUID
+ }
+ }
override protected def evaluatorFactory: EvalPythonEvaluatorFactory = {
new ArrowEvalPythonEvaluatorFactory(
@@ -95,6 +101,7 @@ case class ArrowEvalPythonExec(
ArrowPythonRunner.getPythonRunnerConfMap(conf),
pythonMetrics,
jobArtifactUUID,
+ sessionUUID,
conf.pythonUDFProfiler)
}
@@ -121,6 +128,7 @@ class ArrowEvalPythonEvaluatorFactory(
pythonRunnerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
+ sessionUUID: Option[String],
profiler: Option[String])
extends EvalPythonEvaluatorFactory(childOutput, udfs, output) {
@@ -147,6 +155,7 @@ class ArrowEvalPythonEvaluatorFactory(
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID,
+ sessionUUID,
profiler) with BatchedPythonArrowInput
val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 77aec2a35f21..b94e00bc11ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.python
import java.io.DataOutputStream
+import java.util
import org.apache.spark.api.python._
import org.apache.spark.sql.catalyst.InternalRow
@@ -36,12 +37,20 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef](
protected override val largeVarTypes: Boolean,
protected override val workerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
- jobArtifactUUID: Option[String])
+ jobArtifactUUID: Option[String],
+ sessionUUID: Option[String])
extends BasePythonRunner[IN, OUT](
funcs.map(_._1), evalType, argOffsets, jobArtifactUUID, pythonMetrics)
with PythonArrowInput[IN]
with PythonArrowOutput[OUT] {
+ override val envVars: util.Map[String, String] = {
+ val envVars = new util.HashMap(funcs.head._1.funcs.head.envVars)
+ sessionUUID.foreach { uuid =>
+ envVars.put("PYSPARK_SPARK_SESSION_UUID", uuid)
+ }
+ envVars
+ }
override val pythonExec: String =
SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
funcs.head._1.funcs.head.pythonExec)
@@ -77,10 +86,11 @@ abstract class RowInputArrowPythonRunner(
largeVarTypes: Boolean,
workerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
- jobArtifactUUID: Option[String])
+ jobArtifactUUID: Option[String],
+ sessionUUID: Option[String])
extends BaseArrowPythonRunner[Iterator[InternalRow], ColumnarBatch](
funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
- pythonMetrics, jobArtifactUUID)
+ pythonMetrics, jobArtifactUUID, sessionUUID)
with BasicPythonArrowInput
with BasicPythonArrowOutput
@@ -97,10 +107,11 @@ class ArrowPythonRunner(
workerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
+ sessionUUID: Option[String],
profiler: Option[String])
extends RowInputArrowPythonRunner(
funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
- pythonMetrics, jobArtifactUUID) {
+ pythonMetrics, jobArtifactUUID, sessionUUID) {
override protected def writeUDF(dataOut: DataOutputStream): Unit =
PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
@@ -120,10 +131,11 @@ class ArrowPythonWithNamedArgumentRunner(
workerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
+ sessionUUID: Option[String],
profiler: Option[String])
extends RowInputArrowPythonRunner(
funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf,
- pythonMetrics, jobArtifactUUID) {
+ pythonMetrics, jobArtifactUUID, sessionUUID) {
override protected def writeUDF(dataOut: DataOutputStream): Unit = {
if (evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
index 82c03b1d0229..2bf974d9026f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
@@ -45,6 +45,7 @@ class ArrowWindowPythonEvaluatorFactory(
val evalType: Int,
val spillSize: SQLMetric,
pythonMetrics: Map[String, SQLMetric],
+ sessionUUID: Option[String],
profiler: Option[String])
extends PartitionEvaluatorFactory[InternalRow, InternalRow] with WindowEvaluatorFactoryBase {
@@ -378,6 +379,7 @@ class ArrowWindowPythonEvaluatorFactory(
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID,
+ sessionUUID,
profiler) with GroupedPythonArrowInput
val windowFunctionResult = runner.compute(pythonInput, context.partitionId(), context)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonExec.scala
index c8259c10dbd9..ba3ffe7639eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonExec.scala
@@ -91,6 +91,13 @@ case class ArrowWindowPythonExec(
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")
)
+ private[this] val sessionUUID = {
+ Option(session).collect {
+ case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+ session.sessionUUID
+ }
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val evaluatorFactory =
new ArrowWindowPythonEvaluatorFactory(
@@ -101,6 +108,7 @@ case class ArrowWindowPythonExec(
evalType,
longMetric("spillSize"),
pythonMetrics,
+ sessionUUID,
conf.pythonUDFProfiler)
// Start processing.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 9dbdd285338e..00eb9039d05c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.python
import java.io.DataOutputStream
+import java.util
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions,
PythonRDD, PythonWorker}
@@ -43,12 +44,20 @@ class CoGroupedArrowPythonRunner(
conf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
+ sessionUUID: Option[String],
profiler: Option[String])
extends BasePythonRunner[
(Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](
funcs.map(_._1), evalType, argOffsets, jobArtifactUUID, pythonMetrics)
with BasicPythonArrowOutput {
+ override val envVars: util.Map[String, String] = {
+ val envVars = new util.HashMap(funcs.head._1.funcs.head.envVars)
+ sessionUUID.foreach { uuid =>
+ envVars.put("PYSPARK_SPARK_SESSION_UUID", uuid)
+ }
+ envVars
+ }
override val pythonExec: String =
SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
funcs.head._1.funcs.head.pythonExec)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
index af487218391e..38427866458e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
@@ -68,6 +68,12 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup)
val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup)
val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+ val sessionUUID = {
+ Option(session).collect {
+ case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+ session.sessionUUID
+ }
+ }
// Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty
left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
@@ -89,6 +95,7 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID,
+ sessionUUID,
conf.pythonUDFProfiler)
executePython(data, output, runner)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
index 57a50a8fc857..7d221552226d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
@@ -48,6 +48,12 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with UnaryExecNode with PythonS
private val chainedFunc =
Seq((ChainedPythonFunctions(Seq(pythonFunction)), pythonUDF.resultId.id))
private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+ private[this] val sessionUUID = {
+ Option(session).collect {
+ case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+ session.sessionUUID
+ }
+ }
override def producedAttributes: AttributeSet = AttributeSet(output)
@@ -92,6 +98,7 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with UnaryExecNode with PythonS
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID,
+ sessionUUID,
conf.pythonUDFProfiler) with GroupedPythonArrowInput
executePython(data, output, runner)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 9e3e8610ed37..4e78b3035a7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -40,7 +40,8 @@ class MapInBatchEvaluatorFactory(
largeVarTypes: Boolean,
pythonRunnerConf: Map[String, String],
val pythonMetrics: Map[String, SQLMetric],
- jobArtifactUUID: Option[String])
+ jobArtifactUUID: Option[String],
+ sessionUUID: Option[String])
extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] =
@@ -72,6 +73,7 @@ class MapInBatchEvaluatorFactory(
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID,
+ sessionUUID,
None) with BatchedPythonArrowInput
val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index c003d503c7ca..1d03c0cf7603 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -44,6 +44,12 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
override def producedAttributes: AttributeSet = AttributeSet(output)
private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+ private[this] val sessionUUID = {
+ Option(session).collect {
+ case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+ session.sessionUUID
+ }
+ }
override def outputPartitioning: Partitioning = child.outputPartitioning
@@ -63,7 +69,8 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
conf.arrowUseLargeVarTypes,
pythonRunnerConf,
pythonMetrics,
- jobArtifactUUID)
+ jobArtifactUUID,
+ sessionUUID)
val rdd = if (isBarrier) {
val rddBarrier = child.execute().barrier()