This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9f0b2383a4a6 [SPARK-55053][PYTHON] Refactor data source/udtf analysis 
workers so they have a unified entry
9f0b2383a4a6 is described below

commit 9f0b2383a4a611ea348494ea24c95c37db642f34
Author: Tian Gao <[email protected]>
AuthorDate: Tue Jan 20 12:45:49 2026 +0900

    [SPARK-55053][PYTHON] Refactor data source/udtf analysis workers so they 
have a unified entry
    
    ### What changes were proposed in this pull request?
    
    * A new `utils.py` file is added in `pyspark.sql.worker` to handle the 
common logic for all planner workers.
    * All the worker modules in that folder, including the data source and UDTF 
analysis modules, now use this unified entry point.
    
    ### Why are the changes needed?
    
    * Reduce duplicated code. The same worker communication logic is repeated 
in every file, so any change to it currently has to be made in all of them.
    * Make it easier to add future instrumentation, such as a data source profiler.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    test_datasource passed locally
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, cursor (claude-4.5-opus-high)
    
    Closes #53818 from gaogaotiantian/unify-data-source-workers.
    
    Authored-by: Tian Gao <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/worker/analyze_udtf.py          | 286 +++++++++----------
 .../pyspark/sql/worker/commit_data_source_write.py | 114 +++-----
 python/pyspark/sql/worker/create_data_source.py    | 186 +++++--------
 .../sql/worker/data_source_pushdown_filters.py     | 200 ++++++--------
 python/pyspark/sql/worker/lookup_data_sources.py   |  75 ++---
 python/pyspark/sql/worker/plan_data_source_read.py | 208 ++++++--------
 .../sql/worker/python_streaming_sink_runner.py     | 159 +++++------
 .../worker/{lookup_data_sources.py => utils.py}    |  52 +---
 .../pyspark/sql/worker/write_into_data_source.py   | 305 +++++++++------------
 9 files changed, 624 insertions(+), 961 deletions(-)

diff --git a/python/pyspark/sql/worker/analyze_udtf.py 
b/python/pyspark/sql/worker/analyze_udtf.py
index 526cb316862c..7265138202cd 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -17,11 +17,9 @@
 
 import inspect
 import os
-import sys
 from textwrap import dedent
 from typing import Dict, List, IO, Tuple
 
-from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkRuntimeError, PySparkValueError
 from pyspark.logger.worker_io import capture_outputs, context_provider as 
default_context_provider
 from pyspark.serializers import (
@@ -29,25 +27,15 @@ from pyspark.serializers import (
     read_int,
     write_int,
     write_with_length,
-    SpecialLengths,
 )
 from pyspark.sql.functions import OrderingColumn, PartitioningColumn, 
SelectedColumn
 from pyspark.sql.types import _parse_datatype_json_string, StructType
 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
-from pyspark.util import (
-    handle_worker_exception,
-    local_connect_and_auth,
-    with_faulthandler,
-    start_faulthandler_periodic_traceback,
-)
+from pyspark.sql.worker.utils import worker_run
+from pyspark.util import local_connect_and_auth
 from pyspark.worker_util import (
-    check_python_version,
     read_command,
     pickleSer,
-    send_accumulator_updates,
-    setup_broadcasts,
-    setup_memory_limits,
-    setup_spark_files,
     utf8_deserializer,
 )
 
@@ -104,8 +92,7 @@ def read_arguments(infile: IO) -> 
Tuple[List[AnalyzeArgument], Dict[str, Analyze
     return args, kwargs
 
 
-@with_faulthandler
-def main(infile: IO, outfile: IO) -> None:
+def _main(infile: IO, outfile: IO) -> None:
     """
     Runs the Python UDTF's `analyze` static method.
 
@@ -113,166 +100,141 @@ def main(infile: IO, outfile: IO) -> None:
     in JVM and receive the Python UDTF and its arguments for the `analyze` 
static method,
     and call the `analyze` static method, and send back a AnalyzeResult as a 
result of the method.
     """
-    try:
-        check_python_version(infile)
-
-        start_faulthandler_periodic_traceback()
-
-        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", 
"-1"))
-        setup_memory_limits(memory_limit_mb)
-
-        setup_spark_files(infile)
-        setup_broadcasts(infile)
-
-        _accumulatorRegistry.clear()
+    udtf_name = utf8_deserializer.loads(infile)
+    handler = read_udtf(infile)
+    args, kwargs = read_arguments(infile)
 
-        udtf_name = utf8_deserializer.loads(infile)
-        handler = read_udtf(infile)
-        args, kwargs = read_arguments(infile)
+    error_prefix = f"Failed to evaluate the user-defined table function 
'{udtf_name}'"
 
-        error_prefix = f"Failed to evaluate the user-defined table function 
'{udtf_name}'"
+    def format_error(msg: str) -> str:
+        return dedent(msg).replace("\n", " ")
 
-        def format_error(msg: str) -> str:
-            return dedent(msg).replace("\n", " ")
-
-        # Check that the arguments provided to the UDTF call match the 
expected parameters defined
-        # in the static 'analyze' method signature.
-        try:
-            inspect.signature(handler.analyze).bind(*args, **kwargs)  # type: 
ignore[attr-defined]
-        except TypeError as e:
-            # The UDTF call's arguments did not match the expected signature.
-            raise PySparkValueError(
-                format_error(
-                    f"""
-                    {error_prefix} because the function arguments did not 
match the expected
-                    signature of the static 'analyze' method ({e}). Please 
update the query so that
-                    this table function call provides arguments matching the 
expected signature, or
-                    else update the table function so that its static 
'analyze' method accepts the
-                    provided arguments, and then try the query again."""
-                )
+    # Check that the arguments provided to the UDTF call match the expected 
parameters defined
+    # in the static 'analyze' method signature.
+    try:
+        inspect.signature(handler.analyze).bind(*args, **kwargs)  # type: 
ignore[attr-defined]
+    except TypeError as e:
+        # The UDTF call's arguments did not match the expected signature.
+        raise PySparkValueError(
+            format_error(
+                f"""
+                {error_prefix} because the function arguments did not match 
the expected
+                signature of the static 'analyze' method ({e}). Please update 
the query so that
+                this table function call provides arguments matching the 
expected signature, or
+                else update the table function so that its static 'analyze' 
method accepts the
+                provided arguments, and then try the query again."""
             )
+        )
 
-        # The default context provider can't detect the class name from static 
methods.
-        def context_provider() -> dict[str, str]:
-            context = default_context_provider()
-            context["class_name"] = handler.__name__
-            return context
-
-        with capture_outputs(context_provider):
-            # Invoke the UDTF's 'analyze' method.
-            result = handler.analyze(*args, **kwargs)  # type: 
ignore[attr-defined]
-
-        # Check invariants about the 'analyze' method after running it.
-        if not isinstance(result, AnalyzeResult):
-            raise PySparkValueError(
-                format_error(
-                    f"""
-                    {error_prefix} because the static 'analyze' method expects 
a result of type
-                    pyspark.sql.udtf.AnalyzeResult, but instead this method 
returned a value of
-                    type: {type(result)}"""
-                )
+    # The default context provider can't detect the class name from static 
methods.
+    def context_provider() -> dict[str, str]:
+        context = default_context_provider()
+        context["class_name"] = handler.__name__
+        return context
+
+    with capture_outputs(context_provider):
+        # Invoke the UDTF's 'analyze' method.
+        result = handler.analyze(*args, **kwargs)  # type: ignore[attr-defined]
+
+    # Check invariants about the 'analyze' method after running it.
+    if not isinstance(result, AnalyzeResult):
+        raise PySparkValueError(
+            format_error(
+                f"""
+                {error_prefix} because the static 'analyze' method expects a 
result of type
+                pyspark.sql.udtf.AnalyzeResult, but instead this method 
returned a value of
+                type: {type(result)}"""
             )
-        elif not isinstance(result.schema, StructType):
-            raise PySparkValueError(
-                format_error(
-                    f"""
-                    {error_prefix} because the static 'analyze' method expects 
a result of type
-                    pyspark.sql.udtf.AnalyzeResult with a 'schema' field 
comprising a StructType,
-                    but the 'schema' field had the wrong type: 
{type(result.schema)}"""
-                )
+        )
+    elif not isinstance(result.schema, StructType):
+        raise PySparkValueError(
+            format_error(
+                f"""
+                {error_prefix} because the static 'analyze' method expects a 
result of type
+                pyspark.sql.udtf.AnalyzeResult with a 'schema' field 
comprising a StructType,
+                but the 'schema' field had the wrong type: 
{type(result.schema)}"""
             )
+        )
 
-        def invalid_analyze_result_field(field_name: str, expected_field: str) 
-> PySparkValueError:
-            return PySparkValueError(
-                format_error(
-                    f"""
-                    {error_prefix} because the static 'analyze' method 
returned an
-                    'AnalyzeResult' object with the '{field_name}' field set 
to a value besides a
-                    list or tuple of '{expected_field}' objects. Please update 
the table function
-                    and then try the query again."""
-                )
+    def invalid_analyze_result_field(field_name: str, expected_field: str) -> 
PySparkValueError:
+        return PySparkValueError(
+            format_error(
+                f"""
+                {error_prefix} because the static 'analyze' method returned an
+                'AnalyzeResult' object with the '{field_name}' field set to a 
value besides a
+                list or tuple of '{expected_field}' objects. Please update the 
table function
+                and then try the query again."""
             )
-
-        has_table_arg = any(arg.isTable for arg in args) or any(
-            arg.isTable for arg in kwargs.values()
         )
-        if not has_table_arg and result.withSinglePartition:
-            raise PySparkValueError(
-                format_error(
-                    f"""
-                    {error_prefix} because the static 'analyze' method 
returned an
-                    'AnalyzeResult' object with the 'withSinglePartition' 
field set to 'true', but
-                    the function call did not provide any table argument. 
Please update the query so
-                    that it provides a table argument, or else update the 
table function so that its
-                    'analyze' method returns an 'AnalyzeResult' object with the
-                    'withSinglePartition' field set to 'false', and then try 
the query again."""
-                )
+
+    has_table_arg = any(arg.isTable for arg in args) or any(arg.isTable for 
arg in kwargs.values())
+    if not has_table_arg and result.withSinglePartition:
+        raise PySparkValueError(
+            format_error(
+                f"""
+                {error_prefix} because the static 'analyze' method returned an
+                'AnalyzeResult' object with the 'withSinglePartition' field 
set to 'true', but
+                the function call did not provide any table argument. Please 
update the query so
+                that it provides a table argument, or else update the table 
function so that its
+                'analyze' method returns an 'AnalyzeResult' object with the
+                'withSinglePartition' field set to 'false', and then try the 
query again."""
             )
-        elif not has_table_arg and len(result.partitionBy) > 0:
-            raise PySparkValueError(
-                format_error(
-                    f"""
-                    {error_prefix} because the static 'analyze' method 
returned an
-                    'AnalyzeResult' object with the 'partitionBy' list set to 
non-empty, but the
-                    function call did not provide any table argument. Please 
update the query so
-                    that it provides a table argument, or else update the 
table function so that its
-                    'analyze' method returns an 'AnalyzeResult' object with 
the 'partitionBy' list
-                    set to empty, and then try the query again."""
-                )
+        )
+    elif not has_table_arg and len(result.partitionBy) > 0:
+        raise PySparkValueError(
+            format_error(
+                f"""
+                {error_prefix} because the static 'analyze' method returned an
+                'AnalyzeResult' object with the 'partitionBy' list set to 
non-empty, but the
+                function call did not provide any table argument. Please 
update the query so
+                that it provides a table argument, or else update the table 
function so that its
+                'analyze' method returns an 'AnalyzeResult' object with the 
'partitionBy' list
+                set to empty, and then try the query again."""
             )
-        elif not isinstance(result.partitionBy, (list, tuple)) or not all(
-            isinstance(val, PartitioningColumn) for val in result.partitionBy
-        ):
-            raise invalid_analyze_result_field("partitionBy", 
"PartitioningColumn")
-        elif not isinstance(result.orderBy, (list, tuple)) or not all(
-            isinstance(val, OrderingColumn) for val in result.orderBy
-        ):
-            raise invalid_analyze_result_field("orderBy", "OrderingColumn")
-        elif not isinstance(result.select, (list, tuple)) or not all(
-            isinstance(val, SelectedColumn) for val in result.select
-        ):
-            raise invalid_analyze_result_field("select", "SelectedColumn")
-
-        # Return the analyzed schema.
-        write_with_length(result.schema.json().encode("utf-8"), outfile)
-        # Return the pickled 'AnalyzeResult' class instance.
-        pickleSer._write_with_length(result, outfile)
-        # Return whether the "with single partition" property is requested.
-        write_int(1 if result.withSinglePartition else 0, outfile)
-        # Return the list of partitioning columns, if any.
-        write_int(len(result.partitionBy), outfile)
-        for partitioning_col in result.partitionBy:
-            write_with_length(partitioning_col.name.encode("utf-8"), outfile)
-        # Return the requested input table ordering, if any.
-        write_int(len(result.orderBy), outfile)
-        for ordering_col in result.orderBy:
-            write_with_length(ordering_col.name.encode("utf-8"), outfile)
-            write_int(1 if ordering_col.ascending else 0, outfile)
-            if ordering_col.overrideNullsFirst is None:
-                write_int(0, outfile)
-            elif ordering_col.overrideNullsFirst:
-                write_int(1, outfile)
-            else:
-                write_int(2, outfile)
-        # Return the requested selected input table columns, if specified.
-        write_int(len(result.select), outfile)
-        for col in result.select:
-            write_with_length(col.name.encode("utf-8"), outfile)
-            write_with_length(col.alias.encode("utf-8"), outfile)
-
-    except BaseException as e:
-        handle_worker_exception(e, outfile)
-        sys.exit(-1)
+        )
+    elif not isinstance(result.partitionBy, (list, tuple)) or not all(
+        isinstance(val, PartitioningColumn) for val in result.partitionBy
+    ):
+        raise invalid_analyze_result_field("partitionBy", "PartitioningColumn")
+    elif not isinstance(result.orderBy, (list, tuple)) or not all(
+        isinstance(val, OrderingColumn) for val in result.orderBy
+    ):
+        raise invalid_analyze_result_field("orderBy", "OrderingColumn")
+    elif not isinstance(result.select, (list, tuple)) or not all(
+        isinstance(val, SelectedColumn) for val in result.select
+    ):
+        raise invalid_analyze_result_field("select", "SelectedColumn")
+
+    # Return the analyzed schema.
+    write_with_length(result.schema.json().encode("utf-8"), outfile)
+    # Return the pickled 'AnalyzeResult' class instance.
+    pickleSer._write_with_length(result, outfile)
+    # Return whether the "with single partition" property is requested.
+    write_int(1 if result.withSinglePartition else 0, outfile)
+    # Return the list of partitioning columns, if any.
+    write_int(len(result.partitionBy), outfile)
+    for partitioning_col in result.partitionBy:
+        write_with_length(partitioning_col.name.encode("utf-8"), outfile)
+    # Return the requested input table ordering, if any.
+    write_int(len(result.orderBy), outfile)
+    for ordering_col in result.orderBy:
+        write_with_length(ordering_col.name.encode("utf-8"), outfile)
+        write_int(1 if ordering_col.ascending else 0, outfile)
+        if ordering_col.overrideNullsFirst is None:
+            write_int(0, outfile)
+        elif ordering_col.overrideNullsFirst:
+            write_int(1, outfile)
+        else:
+            write_int(2, outfile)
+    # Return the requested selected input table columns, if specified.
+    write_int(len(result.select), outfile)
+    for col in result.select:
+        write_with_length(col.name.encode("utf-8"), outfile)
+        write_with_length(col.alias.encode("utf-8"), outfile)
 
-    send_accumulator_updates(outfile)
 
-    # check end of stream
-    if read_int(infile) == SpecialLengths.END_OF_STREAM:
-        write_int(SpecialLengths.END_OF_STREAM, outfile)
-    else:
-        # write a different value to tell JVM to not reuse this worker
-        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
-        sys.exit(-1)
+def main(infile: IO, outfile: IO) -> None:
+    worker_run(_main, infile, outfile)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/worker/commit_data_source_write.py 
b/python/pyspark/sql/worker/commit_data_source_write.py
index 37fee6ad8357..6838a32db398 100644
--- a/python/pyspark/sql/worker/commit_data_source_write.py
+++ b/python/pyspark/sql/worker/commit_data_source_write.py
@@ -15,37 +15,22 @@
 # limitations under the License.
 #
 import os
-import sys
 from typing import IO
 
-from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError
 from pyspark.logger.worker_io import capture_outputs
 from pyspark.serializers import (
     read_bool,
     read_int,
     write_int,
-    SpecialLengths,
 )
 from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage
-from pyspark.util import (
-    handle_worker_exception,
-    local_connect_and_auth,
-    with_faulthandler,
-    start_faulthandler_periodic_traceback,
-)
-from pyspark.worker_util import (
-    check_python_version,
-    pickleSer,
-    send_accumulator_updates,
-    setup_broadcasts,
-    setup_memory_limits,
-    setup_spark_files,
-)
+from pyspark.sql.worker.utils import worker_run
+from pyspark.util import local_connect_and_auth
+from pyspark.worker_util import pickleSer
 
 
-@with_faulthandler
-def main(infile: IO, outfile: IO) -> None:
+def _main(infile: IO, outfile: IO) -> None:
     """
     Main method for committing or aborting a data source write operation.
 
@@ -54,65 +39,42 @@ def main(infile: IO, outfile: IO) -> None:
     responsible for invoking either the `commit` or the `abort` method on a 
data source
     writer instance, given a list of commit messages.
     """
-    try:
-        check_python_version(infile)
-
-        start_faulthandler_periodic_traceback()
-
-        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", 
"-1"))
-        setup_memory_limits(memory_limit_mb)
-
-        setup_spark_files(infile)
-        setup_broadcasts(infile)
-
-        _accumulatorRegistry.clear()
+    # Receive the data source writer instance.
+    writer = pickleSer._read_with_length(infile)
+    assert isinstance(writer, DataSourceWriter)
+
+    # Receive the commit messages.
+    num_messages = read_int(infile)
+    commit_messages = []
+    for _ in range(num_messages):
+        message = pickleSer._read_with_length(infile)
+        if message is not None and not isinstance(message, 
WriterCommitMessage):
+            raise PySparkAssertionError(
+                errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                messageParameters={
+                    "expected": "an instance of WriterCommitMessage",
+                    "actual": f"'{type(message).__name__}'",
+                },
+            )
+        commit_messages.append(message)
+
+    # Receive a boolean to indicate whether to invoke `abort`.
+    abort = read_bool(infile)
+
+    with capture_outputs():
+        # Commit or abort the Python data source write.
+        # Note the commit messages can be None if there are failed tasks.
+        if abort:
+            writer.abort(commit_messages)
+        else:
+            writer.commit(commit_messages)
+
+    # Send a status code back to JVM.
+    write_int(0, outfile)
 
-        # Receive the data source writer instance.
-        writer = pickleSer._read_with_length(infile)
-        assert isinstance(writer, DataSourceWriter)
 
-        # Receive the commit messages.
-        num_messages = read_int(infile)
-        commit_messages = []
-        for _ in range(num_messages):
-            message = pickleSer._read_with_length(infile)
-            if message is not None and not isinstance(message, 
WriterCommitMessage):
-                raise PySparkAssertionError(
-                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                    messageParameters={
-                        "expected": "an instance of WriterCommitMessage",
-                        "actual": f"'{type(message).__name__}'",
-                    },
-                )
-            commit_messages.append(message)
-
-        # Receive a boolean to indicate whether to invoke `abort`.
-        abort = read_bool(infile)
-
-        with capture_outputs():
-            # Commit or abort the Python data source write.
-            # Note the commit messages can be None if there are failed tasks.
-            if abort:
-                writer.abort(commit_messages)
-            else:
-                writer.commit(commit_messages)
-
-        # Send a status code back to JVM.
-        write_int(0, outfile)
-
-    except BaseException as e:
-        handle_worker_exception(e, outfile)
-        sys.exit(-1)
-
-    send_accumulator_updates(outfile)
-
-    # check end of stream
-    if read_int(infile) == SpecialLengths.END_OF_STREAM:
-        write_int(SpecialLengths.END_OF_STREAM, outfile)
-    else:
-        # write a different value to tell JVM to not reuse this worker
-        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
-        sys.exit(-1)
+def main(infile: IO, outfile: IO) -> None:
+    worker_run(_main, infile, outfile)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/worker/create_data_source.py 
b/python/pyspark/sql/worker/create_data_source.py
index bf6ceda41ffb..625b08088e60 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -16,10 +16,8 @@
 #
 import inspect
 import os
-import sys
 from typing import IO
 
-from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkTypeError
 from pyspark.logger.worker_io import capture_outputs
 from pyspark.serializers import (
@@ -27,30 +25,19 @@ from pyspark.serializers import (
     read_int,
     write_int,
     write_with_length,
-    SpecialLengths,
 )
 from pyspark.sql.datasource import DataSource, CaseInsensitiveDict
 from pyspark.sql.types import _parse_datatype_json_string, StructType
-from pyspark.util import (
-    handle_worker_exception,
-    local_connect_and_auth,
-    with_faulthandler,
-    start_faulthandler_periodic_traceback,
-)
+from pyspark.sql.worker.utils import worker_run
+from pyspark.util import local_connect_and_auth
 from pyspark.worker_util import (
-    check_python_version,
     read_command,
     pickleSer,
-    send_accumulator_updates,
-    setup_broadcasts,
-    setup_memory_limits,
-    setup_spark_files,
     utf8_deserializer,
 )
 
 
-@with_faulthandler
-def main(infile: IO, outfile: IO) -> None:
+def _main(infile: IO, outfile: IO) -> None:
     """
     Main method for creating a Python data source instance.
 
@@ -67,118 +54,95 @@ def main(infile: IO, outfile: IO) -> None:
     This process then creates a `DataSource` instance using the above 
information and
     sends the pickled instance as well as the schema back to the JVM.
     """
-    try:
-        check_python_version(infile)
-
-        start_faulthandler_periodic_traceback()
-
-        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", 
"-1"))
-        setup_memory_limits(memory_limit_mb)
-
-        setup_spark_files(infile)
-        setup_broadcasts(infile)
-
-        _accumulatorRegistry.clear()
-
-        # Receive the data source class.
-        data_source_cls = read_command(pickleSer, infile)
-        if not (isinstance(data_source_cls, type) and 
issubclass(data_source_cls, DataSource)):
+    # Receive the data source class.
+    data_source_cls = read_command(pickleSer, infile)
+    if not (isinstance(data_source_cls, type) and issubclass(data_source_cls, 
DataSource)):
+        raise PySparkAssertionError(
+            errorClass="DATA_SOURCE_TYPE_MISMATCH",
+            messageParameters={
+                "expected": "a subclass of DataSource",
+                "actual": f"'{type(data_source_cls).__name__}'",
+            },
+        )
+
+    # Check the name method is a class method.
+    if not inspect.ismethod(data_source_cls.name):
+        raise PySparkTypeError(
+            errorClass="DATA_SOURCE_TYPE_MISMATCH",
+            messageParameters={
+                "expected": "'name()' method to be a classmethod",
+                "actual": f"'{type(data_source_cls.name).__name__}'",
+            },
+        )
+
+    # Receive the provider name.
+    provider = utf8_deserializer.loads(infile)
+
+    with capture_outputs():
+        # Check if the provider name matches the data source's name.
+        name = data_source_cls.name()
+        if provider.lower() != name.lower():
             raise PySparkAssertionError(
                 errorClass="DATA_SOURCE_TYPE_MISMATCH",
                 messageParameters={
-                    "expected": "a subclass of DataSource",
-                    "actual": f"'{type(data_source_cls).__name__}'",
-                },
-            )
-
-        # Check the name method is a class method.
-        if not inspect.ismethod(data_source_cls.name):
-            raise PySparkTypeError(
-                errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                messageParameters={
-                    "expected": "'name()' method to be a classmethod",
-                    "actual": f"'{type(data_source_cls.name).__name__}'",
+                    "expected": f"provider with name {name}",
+                    "actual": f"'{provider}'",
                 },
             )
 
-        # Receive the provider name.
-        provider = utf8_deserializer.loads(infile)
-
-        with capture_outputs():
-            # Check if the provider name matches the data source's name.
-            name = data_source_cls.name()
-            if provider.lower() != name.lower():
+        # Receive the user-specified schema
+        user_specified_schema = None
+        if read_bool(infile):
+            user_specified_schema = 
_parse_datatype_json_string(utf8_deserializer.loads(infile))
+            if not isinstance(user_specified_schema, StructType):
                 raise PySparkAssertionError(
                     errorClass="DATA_SOURCE_TYPE_MISMATCH",
                     messageParameters={
-                        "expected": f"provider with name {name}",
-                        "actual": f"'{provider}'",
+                        "expected": "the user-defined schema to be a 
'StructType'",
+                        "actual": f"'{type(data_source_cls).__name__}'",
                     },
                 )
 
-            # Receive the user-specified schema
-            user_specified_schema = None
-            if read_bool(infile):
-                user_specified_schema = 
_parse_datatype_json_string(utf8_deserializer.loads(infile))
-                if not isinstance(user_specified_schema, StructType):
-                    raise PySparkAssertionError(
-                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                        messageParameters={
-                            "expected": "the user-defined schema to be a 
'StructType'",
-                            "actual": f"'{type(data_source_cls).__name__}'",
-                        },
-                    )
-
-            # Receive the options.
-            options = CaseInsensitiveDict()
-            num_options = read_int(infile)
-            for _ in range(num_options):
-                key = utf8_deserializer.loads(infile)
-                value = utf8_deserializer.loads(infile)
-                options[key] = value
-
-            # Instantiate a data source.
-            data_source = data_source_cls(options=options)  # type: ignore
-
-            # Get the schema of the data source.
-            # If user_specified_schema is not None, use user_specified_schema.
-            # Otherwise, use the schema of the data source.
-            # Throw exception if the data source does not implement schema().
-            is_ddl_string = False
-            if user_specified_schema is None:
-                schema = data_source.schema()
-                if isinstance(schema, str):
-                    # Here we cannot use _parse_datatype_string to parse the 
DDL string schema.
-                    # as it requires an active Spark session.
-                    is_ddl_string = True
-            else:
-                schema = user_specified_schema  # type: ignore
-
-            assert schema is not None
-
-        # Return the pickled data source instance.
-        pickleSer._write_with_length(data_source, outfile)
-
-        # Return the schema of the data source.
-        write_int(int(is_ddl_string), outfile)
-        if is_ddl_string:
-            write_with_length(schema.encode("utf-8"), outfile)  # type: ignore
+        # Receive the options.
+        options = CaseInsensitiveDict()
+        num_options = read_int(infile)
+        for _ in range(num_options):
+            key = utf8_deserializer.loads(infile)
+            value = utf8_deserializer.loads(infile)
+            options[key] = value
+
+        # Instantiate a data source.
+        data_source = data_source_cls(options=options)  # type: ignore
+
+        # Get the schema of the data source.
+        # If user_specified_schema is not None, use user_specified_schema.
+        # Otherwise, use the schema of the data source.
+        # Throw exception if the data source does not implement schema().
+        is_ddl_string = False
+        if user_specified_schema is None:
+            schema = data_source.schema()
+            if isinstance(schema, str):
+                # Here we cannot use _parse_datatype_string to parse the DDL 
string schema.
+                # as it requires an active Spark session.
+                is_ddl_string = True
         else:
-            write_with_length(schema.json().encode("utf-8"), outfile)  # type: 
ignore
+            schema = user_specified_schema  # type: ignore
 
-    except BaseException as e:
-        handle_worker_exception(e, outfile)
-        sys.exit(-1)
+        assert schema is not None
 
-    send_accumulator_updates(outfile)
+    # Return the pickled data source instance.
+    pickleSer._write_with_length(data_source, outfile)
 
-    # check end of stream
-    if read_int(infile) == SpecialLengths.END_OF_STREAM:
-        write_int(SpecialLengths.END_OF_STREAM, outfile)
+    # Return the schema of the data source.
+    write_int(int(is_ddl_string), outfile)
+    if is_ddl_string:
+        write_with_length(schema.encode("utf-8"), outfile)  # type: ignore
     else:
-        # write a different value to tell JVM to not reuse this worker
-        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
-        sys.exit(-1)
+        write_with_length(schema.json().encode("utf-8"), outfile)  # type: 
ignore
+
+
+def main(infile: IO, outfile: IO) -> None:
+    worker_run(_main, infile, outfile)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/worker/data_source_pushdown_filters.py 
b/python/pyspark/sql/worker/data_source_pushdown_filters.py
index 7d255e1dbf77..06b71e1ca8f1 100644
--- a/python/pyspark/sql/worker/data_source_pushdown_filters.py
+++ b/python/pyspark/sql/worker/data_source_pushdown_filters.py
@@ -18,16 +18,14 @@
 import base64
 import json
 import os
-import sys
 import typing
 from dataclasses import dataclass, field
 from typing import IO, Type, Union
 
-from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkValueError
 from pyspark.errors.exceptions.base import PySparkNotImplementedError
 from pyspark.logger.worker_io import capture_outputs
-from pyspark.serializers import SpecialLengths, UTF8Deserializer, read_int, 
read_bool, write_int
+from pyspark.serializers import UTF8Deserializer, read_int, read_bool, 
write_int
 from pyspark.sql.datasource import (
     DataSource,
     DataSourceReader,
@@ -48,20 +46,11 @@ from pyspark.sql.datasource import (
 )
 from pyspark.sql.types import StructType, VariantVal, 
_parse_datatype_json_string
 from pyspark.sql.worker.plan_data_source_read import 
write_read_func_and_partitions
-from pyspark.util import (
-    handle_worker_exception,
-    local_connect_and_auth,
-    with_faulthandler,
-    start_faulthandler_periodic_traceback,
-)
+from pyspark.sql.worker.utils import worker_run
+from pyspark.util import local_connect_and_auth
 from pyspark.worker_util import (
-    check_python_version,
     pickleSer,
     read_command,
-    send_accumulator_updates,
-    setup_broadcasts,
-    setup_memory_limits,
-    setup_spark_files,
 )
 
 utf8_deserializer = UTF8Deserializer()
@@ -123,8 +112,7 @@ def deserializeFilter(jsonDict: dict) -> Filter:
     return filter
 
 
-@with_faulthandler
-def main(infile: IO, outfile: IO) -> None:
+def _main(infile: IO, outfile: IO) -> None:
     """
     Main method for planning a data source read with filter pushdown.
 
@@ -145,126 +133,96 @@ def main(infile: IO, outfile: IO) -> None:
     on the reader and determines which filters are supported. The indices of 
the supported
     filters are sent back to the JVM, along with the list of partitions and 
the read function.
     """
-    try:
-        check_python_version(infile)
-
-        start_faulthandler_periodic_traceback()
-
-        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", 
"-1"))
-        setup_memory_limits(memory_limit_mb)
-
-        setup_spark_files(infile)
-        setup_broadcasts(infile)
-
-        _accumulatorRegistry.clear()
+    # Receive the data source instance.
+    data_source = read_command(pickleSer, infile)
+    if not isinstance(data_source, DataSource):
+        raise PySparkAssertionError(
+            errorClass="DATA_SOURCE_TYPE_MISMATCH",
+            messageParameters={
+                "expected": "a Python data source instance of type 
'DataSource'",
+                "actual": f"'{type(data_source).__name__}'",
+            },
+        )
 
-        # 
----------------------------------------------------------------------
-        # Start of worker logic
-        # 
----------------------------------------------------------------------
+    # Receive the data source output schema.
+    schema_json = utf8_deserializer.loads(infile)
+    schema = _parse_datatype_json_string(schema_json)
+    if not isinstance(schema, StructType):
+        raise PySparkAssertionError(
+            errorClass="DATA_SOURCE_TYPE_MISMATCH",
+            messageParameters={
+                "expected": "an output schema of type 'StructType'",
+                "actual": f"'{type(schema).__name__}'",
+            },
+        )
 
-        # Receive the data source instance.
-        data_source = read_command(pickleSer, infile)
-        if not isinstance(data_source, DataSource):
+    with capture_outputs():
+        # Get the reader.
+        reader = data_source.reader(schema=schema)
+        # Validate the reader.
+        if not isinstance(reader, DataSourceReader):
             raise PySparkAssertionError(
                 errorClass="DATA_SOURCE_TYPE_MISMATCH",
                 messageParameters={
-                    "expected": "a Python data source instance of type 
'DataSource'",
-                    "actual": f"'{type(data_source).__name__}'",
+                    "expected": "an instance of DataSourceReader",
+                    "actual": f"'{type(reader).__name__}'",
                 },
             )
 
-        # Receive the data source output schema.
-        schema_json = utf8_deserializer.loads(infile)
-        schema = _parse_datatype_json_string(schema_json)
-        if not isinstance(schema, StructType):
-            raise PySparkAssertionError(
-                errorClass="DATA_SOURCE_TYPE_MISMATCH",
+        # Receive the pushdown filters.
+        json_str = utf8_deserializer.loads(infile)
+        filter_dicts = json.loads(json_str)
+        filters = [FilterRef(deserializeFilter(f)) for f in filter_dicts]
+
+        # Push down the filters and get the indices of the unsupported filters.
+        unsupported_filters = set(
+            FilterRef(f) for f in reader.pushFilters([ref.filter for ref in 
filters])
+        )
+        supported_filter_indices = []
+        for i, filter in enumerate(filters):
+            if filter in unsupported_filters:
+                unsupported_filters.remove(filter)
+            else:
+                supported_filter_indices.append(i)
+
+        # If it returned any filters that are not in the original filters, 
raise an error.
+        if len(unsupported_filters) > 0:
+            raise PySparkValueError(
+                errorClass="DATA_SOURCE_EXTRANEOUS_FILTERS",
                 messageParameters={
-                    "expected": "an output schema of type 'StructType'",
-                    "actual": f"'{type(schema).__name__}'",
+                    "type": type(reader).__name__,
+                    "input": str(list(filters)),
+                    "extraneous": str(list(unsupported_filters)),
                 },
             )
 
-        with capture_outputs():
-            # Get the reader.
-            reader = data_source.reader(schema=schema)
-            # Validate the reader.
-            if not isinstance(reader, DataSourceReader):
-                raise PySparkAssertionError(
-                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                    messageParameters={
-                        "expected": "an instance of DataSourceReader",
-                        "actual": f"'{type(reader).__name__}'",
-                    },
-                )
-
-            # Receive the pushdown filters.
-            json_str = utf8_deserializer.loads(infile)
-            filter_dicts = json.loads(json_str)
-            filters = [FilterRef(deserializeFilter(f)) for f in filter_dicts]
-
-            # Push down the filters and get the indices of the unsupported 
filters.
-            unsupported_filters = set(
-                FilterRef(f) for f in reader.pushFilters([ref.filter for ref 
in filters])
-            )
-            supported_filter_indices = []
-            for i, filter in enumerate(filters):
-                if filter in unsupported_filters:
-                    unsupported_filters.remove(filter)
-                else:
-                    supported_filter_indices.append(i)
-
-            # If it returned any filters that are not in the original filters, 
raise an error.
-            if len(unsupported_filters) > 0:
-                raise PySparkValueError(
-                    errorClass="DATA_SOURCE_EXTRANEOUS_FILTERS",
-                    messageParameters={
-                        "type": type(reader).__name__,
-                        "input": str(list(filters)),
-                        "extraneous": str(list(unsupported_filters)),
-                    },
-                )
-
-            # Receive the max arrow batch size.
-            max_arrow_batch_size = read_int(infile)
-            assert max_arrow_batch_size > 0, (
-                "The maximum arrow batch size should be greater than 0, but 
got "
-                f"'{max_arrow_batch_size}'"
-            )
-            binary_as_bytes = read_bool(infile)
-
-            # Return the read function and partitions. Doing this in the same 
worker
-            # as filter pushdown helps reduce the number of Python worker 
calls.
-            write_read_func_and_partitions(
-                outfile,
-                reader=reader,
-                data_source=data_source,
-                schema=schema,
-                max_arrow_batch_size=max_arrow_batch_size,
-                binary_as_bytes=binary_as_bytes,
-            )
-
-        # Return the supported filter indices.
-        write_int(len(supported_filter_indices), outfile)
-        for index in supported_filter_indices:
-            write_int(index, outfile)
+        # Receive the max arrow batch size.
+        max_arrow_batch_size = read_int(infile)
+        assert max_arrow_batch_size > 0, (
+            "The maximum arrow batch size should be greater than 0, but got "
+            f"'{max_arrow_batch_size}'"
+        )
+        binary_as_bytes = read_bool(infile)
+
+        # Return the read function and partitions. Doing this in the same 
worker
+        # as filter pushdown helps reduce the number of Python worker calls.
+        write_read_func_and_partitions(
+            outfile,
+            reader=reader,
+            data_source=data_source,
+            schema=schema,
+            max_arrow_batch_size=max_arrow_batch_size,
+            binary_as_bytes=binary_as_bytes,
+        )
 
-        # 
----------------------------------------------------------------------
-        # End of worker logic
-        # 
----------------------------------------------------------------------
-    except BaseException as e:
-        handle_worker_exception(e, outfile)
-        sys.exit(-1)
+    # Return the supported filter indices.
+    write_int(len(supported_filter_indices), outfile)
+    for index in supported_filter_indices:
+        write_int(index, outfile)
 
-    send_accumulator_updates(outfile)
 
-    # check end of stream
-    if read_int(infile) == SpecialLengths.END_OF_STREAM:
-        write_int(SpecialLengths.END_OF_STREAM, outfile)
-    else:
-        # write a different value to tell JVM to not reuse this worker
-        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
-        sys.exit(-1)
+def main(infile: IO, outfile: IO) -> None:
+    worker_run(_main, infile, outfile)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/worker/lookup_data_sources.py 
b/python/pyspark/sql/worker/lookup_data_sources.py
index b23903cac8cb..e432f40d6904 100644
--- a/python/pyspark/sql/worker/lookup_data_sources.py
+++ b/python/pyspark/sql/worker/lookup_data_sources.py
@@ -17,35 +17,19 @@
 from importlib import import_module
 from pkgutil import iter_modules
 import os
-import sys
 from typing import IO
 
-from pyspark.accumulators import _accumulatorRegistry
 from pyspark.serializers import (
-    read_int,
     write_int,
     write_with_length,
-    SpecialLengths,
 )
 from pyspark.sql.datasource import DataSource
-from pyspark.util import (
-    handle_worker_exception,
-    local_connect_and_auth,
-    with_faulthandler,
-    start_faulthandler_periodic_traceback,
-)
-from pyspark.worker_util import (
-    check_python_version,
-    pickleSer,
-    send_accumulator_updates,
-    setup_broadcasts,
-    setup_memory_limits,
-    setup_spark_files,
-)
+from pyspark.sql.worker.utils import worker_run
+from pyspark.util import local_connect_and_auth
+from pyspark.worker_util import pickleSer
 
 
-@with_faulthandler
-def main(infile: IO, outfile: IO) -> None:
+def _main(infile: IO, outfile: IO) -> None:
     """
     Main method for looking up the available Python Data Sources in Python 
path.
 
@@ -56,46 +40,23 @@ def main(infile: IO, outfile: IO) -> None:
     This is responsible for searching the available Python Data Sources so 
they can be
     statically registered automatically.
     """
-    try:
-        check_python_version(infile)
-
-        start_faulthandler_periodic_traceback()
-
-        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", 
"-1"))
-        setup_memory_limits(memory_limit_mb)
+    infos = {}
+    for info in iter_modules():
+        if info.name.startswith("pyspark_"):
+            mod = import_module(info.name)
+            if hasattr(mod, "DefaultSource") and issubclass(mod.DefaultSource, 
DataSource):
+                infos[mod.DefaultSource.name()] = mod.DefaultSource
 
-        setup_spark_files(infile)
-        setup_broadcasts(infile)
+    # Writes name -> pickled data source to JVM side to be registered
+    # as a Data Source.
+    write_int(len(infos), outfile)
+    for name, dataSource in infos.items():
+        write_with_length(name.encode("utf-8"), outfile)
+        pickleSer._write_with_length(dataSource, outfile)
 
-        _accumulatorRegistry.clear()
 
-        infos = {}
-        for info in iter_modules():
-            if info.name.startswith("pyspark_"):
-                mod = import_module(info.name)
-                if hasattr(mod, "DefaultSource") and 
issubclass(mod.DefaultSource, DataSource):
-                    infos[mod.DefaultSource.name()] = mod.DefaultSource
-
-        # Writes name -> pickled data source to JVM side to be registered
-        # as a Data Source.
-        write_int(len(infos), outfile)
-        for name, dataSource in infos.items():
-            write_with_length(name.encode("utf-8"), outfile)
-            pickleSer._write_with_length(dataSource, outfile)
-
-    except BaseException as e:
-        handle_worker_exception(e, outfile)
-        sys.exit(-1)
-
-    send_accumulator_updates(outfile)
-
-    # check end of stream
-    if read_int(infile) == SpecialLengths.END_OF_STREAM:
-        write_int(SpecialLengths.END_OF_STREAM, outfile)
-    else:
-        # write a different value to tell JVM to not reuse this worker
-        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
-        sys.exit(-1)
+def main(infile: IO, outfile: IO) -> None:
+    worker_run(_main, infile, outfile)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py 
b/python/pyspark/sql/worker/plan_data_source_read.py
index f4a1231955b4..ed1c602b0af4 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -16,20 +16,17 @@
 #
 
 import os
-import sys
 import functools
 import pyarrow as pa
 from itertools import islice, chain
 from typing import IO, List, Iterator, Iterable, Tuple, Union
 
-from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
 from pyspark.logger.worker_io import capture_outputs
 from pyspark.serializers import (
     read_bool,
     read_int,
     write_int,
-    SpecialLengths,
 )
 from pyspark.sql import Row
 from pyspark.sql.conversion import ArrowTableToRowsConversion, 
LocalDataToArrowConversion
@@ -46,20 +43,11 @@ from pyspark.sql.types import (
     BinaryType,
     StructType,
 )
-from pyspark.util import (
-    handle_worker_exception,
-    local_connect_and_auth,
-    with_faulthandler,
-    start_faulthandler_periodic_traceback,
-)
+from pyspark.sql.worker.utils import worker_run
+from pyspark.util import local_connect_and_auth
 from pyspark.worker_util import (
-    check_python_version,
     read_command,
     pickleSer,
-    send_accumulator_updates,
-    setup_broadcasts,
-    setup_memory_limits,
-    setup_spark_files,
     utf8_deserializer,
 )
 
@@ -271,8 +259,7 @@ def write_read_func_and_partitions(
         write_int(0, outfile)
 
 
-@with_faulthandler
-def main(infile: IO, outfile: IO) -> None:
+def _main(infile: IO, outfile: IO) -> None:
     """
     Main method for planning a data source read.
 
@@ -292,123 +279,100 @@ def main(infile: IO, outfile: IO) -> None:
     The partition values and the Arrow Batch are then serialized and sent back 
to the JVM
     via the socket.
     """
-    try:
-        check_python_version(infile)
-
-        start_faulthandler_periodic_traceback()
-
-        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", 
"-1"))
-        setup_memory_limits(memory_limit_mb)
-
-        setup_spark_files(infile)
-        setup_broadcasts(infile)
-
-        _accumulatorRegistry.clear()
-
-        # Receive the data source instance.
-        data_source = read_command(pickleSer, infile)
-        if not isinstance(data_source, DataSource):
-            raise PySparkAssertionError(
-                errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                messageParameters={
-                    "expected": "a Python data source instance of type 
'DataSource'",
-                    "actual": f"'{type(data_source).__name__}'",
-                },
-            )
-
-        # Receive the output schema from its child plan.
-        input_schema_json = utf8_deserializer.loads(infile)
-        input_schema = _parse_datatype_json_string(input_schema_json)
-        if not isinstance(input_schema, StructType):
-            raise PySparkAssertionError(
-                errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                messageParameters={
-                    "expected": "an input schema of type 'StructType'",
-                    "actual": f"'{type(input_schema).__name__}'",
-                },
-            )
-        assert len(input_schema) == 1 and isinstance(input_schema[0].dataType, 
BinaryType), (
-            "The input schema of Python data source read should contain only 
one column of type "
-            f"'BinaryType', but got '{input_schema}'"
+    # Receive the data source instance.
+    data_source = read_command(pickleSer, infile)
+    if not isinstance(data_source, DataSource):
+        raise PySparkAssertionError(
+            errorClass="DATA_SOURCE_TYPE_MISMATCH",
+            messageParameters={
+                "expected": "a Python data source instance of type 
'DataSource'",
+                "actual": f"'{type(data_source).__name__}'",
+            },
         )
 
-        # Receive the data source output schema.
-        schema_json = utf8_deserializer.loads(infile)
-        schema = _parse_datatype_json_string(schema_json)
-        if not isinstance(schema, StructType):
-            raise PySparkAssertionError(
-                errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                messageParameters={
-                    "expected": "an output schema of type 'StructType'",
-                    "actual": f"'{type(schema).__name__}'",
-                },
-            )
+    # Receive the output schema from its child plan.
+    input_schema_json = utf8_deserializer.loads(infile)
+    input_schema = _parse_datatype_json_string(input_schema_json)
+    if not isinstance(input_schema, StructType):
+        raise PySparkAssertionError(
+            errorClass="DATA_SOURCE_TYPE_MISMATCH",
+            messageParameters={
+                "expected": "an input schema of type 'StructType'",
+                "actual": f"'{type(input_schema).__name__}'",
+            },
+        )
+    assert len(input_schema) == 1 and isinstance(input_schema[0].dataType, 
BinaryType), (
+        "The input schema of Python data source read should contain only one 
column of type "
+        f"'BinaryType', but got '{input_schema}'"
+    )
 
-        # Receive the configuration values.
-        max_arrow_batch_size = read_int(infile)
-        assert max_arrow_batch_size > 0, (
-            "The maximum arrow batch size should be greater than 0, but got "
-            f"'{max_arrow_batch_size}'"
+    # Receive the data source output schema.
+    schema_json = utf8_deserializer.loads(infile)
+    schema = _parse_datatype_json_string(schema_json)
+    if not isinstance(schema, StructType):
+        raise PySparkAssertionError(
+            errorClass="DATA_SOURCE_TYPE_MISMATCH",
+            messageParameters={
+                "expected": "an output schema of type 'StructType'",
+                "actual": f"'{type(schema).__name__}'",
+            },
         )
-        enable_pushdown = read_bool(infile)
 
-        is_streaming = read_bool(infile)
-        binary_as_bytes = read_bool(infile)
+    # Receive the configuration values.
+    max_arrow_batch_size = read_int(infile)
+    assert max_arrow_batch_size > 0, (
+        "The maximum arrow batch size should be greater than 0, but got "
+        f"'{max_arrow_batch_size}'"
+    )
+    enable_pushdown = read_bool(infile)
+
+    is_streaming = read_bool(infile)
+    binary_as_bytes = read_bool(infile)
 
-        with capture_outputs():
-            # Instantiate data source reader.
-            if is_streaming:
-                reader: Union[DataSourceReader, DataSourceStreamReader] = 
_streamReader(
-                    data_source, schema
+    with capture_outputs():
+        # Instantiate data source reader.
+        if is_streaming:
+            reader: Union[DataSourceReader, DataSourceStreamReader] = 
_streamReader(
+                data_source, schema
+            )
+        else:
+            reader = data_source.reader(schema=schema)
+            # Validate the reader.
+            if not isinstance(reader, DataSourceReader):
+                raise PySparkAssertionError(
+                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                    messageParameters={
+                        "expected": "an instance of DataSourceReader",
+                        "actual": f"'{type(reader).__name__}'",
+                    },
                 )
-            else:
-                reader = data_source.reader(schema=schema)
-                # Validate the reader.
-                if not isinstance(reader, DataSourceReader):
-                    raise PySparkAssertionError(
-                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                        messageParameters={
-                            "expected": "an instance of DataSourceReader",
-                            "actual": f"'{type(reader).__name__}'",
-                        },
-                    )
-                is_pushdown_implemented = (
-                    getattr(reader.pushFilters, "__func__", None)
-                    is not DataSourceReader.pushFilters
+            is_pushdown_implemented = (
+                getattr(reader.pushFilters, "__func__", None) is not 
DataSourceReader.pushFilters
+            )
+            if is_pushdown_implemented and not enable_pushdown:
+                # Do not silently ignore pushFilters when pushdown is disabled.
+                # Raise an error to ask the user to enable pushdown.
+                raise PySparkAssertionError(
+                    errorClass="DATA_SOURCE_PUSHDOWN_DISABLED",
+                    messageParameters={
+                        "type": type(reader).__name__,
+                        "conf": "spark.sql.python.filterPushdown.enabled",
+                    },
                 )
-                if is_pushdown_implemented and not enable_pushdown:
-                    # Do not silently ignore pushFilters when pushdown is 
disabled.
-                    # Raise an error to ask the user to enable pushdown.
-                    raise PySparkAssertionError(
-                        errorClass="DATA_SOURCE_PUSHDOWN_DISABLED",
-                        messageParameters={
-                            "type": type(reader).__name__,
-                            "conf": "spark.sql.python.filterPushdown.enabled",
-                        },
-                    )
 
-            # Send the read function and partitions to the JVM.
-            write_read_func_and_partitions(
-                outfile,
-                reader=reader,
-                data_source=data_source,
-                schema=schema,
-                max_arrow_batch_size=max_arrow_batch_size,
-                binary_as_bytes=binary_as_bytes,
-            )
-    except BaseException as e:
-        handle_worker_exception(e, outfile)
-        sys.exit(-1)
+        # Send the read function and partitions to the JVM.
+        write_read_func_and_partitions(
+            outfile,
+            reader=reader,
+            data_source=data_source,
+            schema=schema,
+            max_arrow_batch_size=max_arrow_batch_size,
+            binary_as_bytes=binary_as_bytes,
+        )
 
-    send_accumulator_updates(outfile)
 
-    # check end of stream
-    if read_int(infile) == SpecialLengths.END_OF_STREAM:
-        write_int(SpecialLengths.END_OF_STREAM, outfile)
-    else:
-        # write a different value to tell JVM to not reuse this worker
-        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
-        sys.exit(-1)
+def main(infile: IO, outfile: IO) -> None:
+    worker_run(_main, infile, outfile)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py 
b/python/pyspark/sql/worker/python_streaming_sink_runner.py
index 5ca3307fca33..952722d0d946 100644
--- a/python/pyspark/sql/worker/python_streaming_sink_runner.py
+++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py
@@ -16,10 +16,8 @@
 #
 
 import os
-import sys
 from typing import IO
 
-from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError
 from pyspark.logger.worker_io import capture_outputs
 from pyspark.serializers import (
@@ -27,33 +25,22 @@ from pyspark.serializers import (
     read_int,
     read_long,
     write_int,
-    SpecialLengths,
 )
 from pyspark.sql.datasource import DataSource, WriterCommitMessage
 from pyspark.sql.types import (
     _parse_datatype_json_string,
     StructType,
 )
-from pyspark.util import (
-    handle_worker_exception,
-    local_connect_and_auth,
-    with_faulthandler,
-    start_faulthandler_periodic_traceback,
-)
+from pyspark.sql.worker.utils import worker_run
+from pyspark.util import local_connect_and_auth
 from pyspark.worker_util import (
-    check_python_version,
     read_command,
     pickleSer,
-    send_accumulator_updates,
-    setup_broadcasts,
-    setup_memory_limits,
-    setup_spark_files,
     utf8_deserializer,
 )
 
 
-@with_faulthandler
-def main(infile: IO, outfile: IO) -> None:
+def _main(infile: IO, outfile: IO) -> None:
     """
     Main method for committing or aborting a data source streaming write 
operation.
 
@@ -62,89 +49,67 @@ def main(infile: IO, outfile: IO) -> None:
     responsible for invoking either the `commit` or the `abort` method on a 
data source
     writer instance, given a list of commit messages.
     """
-    try:
-        check_python_version(infile)
-
-        start_faulthandler_periodic_traceback()
-
-        setup_spark_files(infile)
-        setup_broadcasts(infile)
-
-        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", 
"-1"))
-        setup_memory_limits(memory_limit_mb)
-
-        _accumulatorRegistry.clear()
-
-        # Receive the data source instance.
-        data_source = read_command(pickleSer, infile)
+    # Receive the data source instance.
+    data_source = read_command(pickleSer, infile)
+
+    if not isinstance(data_source, DataSource):
+        raise PySparkAssertionError(
+            errorClass="DATA_SOURCE_TYPE_MISMATCH",
+            messageParameters={
+                "expected": "a Python data source instance of type 
'DataSource'",
+                "actual": f"'{type(data_source).__name__}'",
+            },
+        )
+    # Receive the data source output schema.
+    schema_json = utf8_deserializer.loads(infile)
+    schema = _parse_datatype_json_string(schema_json)
+    if not isinstance(schema, StructType):
+        raise PySparkAssertionError(
+            errorClass="DATA_SOURCE_TYPE_MISMATCH",
+            messageParameters={
+                "expected": "an output schema of type 'StructType'",
+                "actual": f"'{type(schema).__name__}'",
+            },
+        )
+    # Receive the `overwrite` flag.
+    overwrite = read_bool(infile)
+
+    with capture_outputs():
+        # Create the data source writer instance.
+        writer = data_source.streamWriter(schema=schema, overwrite=overwrite)
+        # Receive the commit messages.
+        num_messages = read_int(infile)
+
+        commit_messages = []
+        for _ in range(num_messages):
+            message = pickleSer._read_with_length(infile)
+            if message is not None and not isinstance(message, 
WriterCommitMessage):
+                raise PySparkAssertionError(
+                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                    messageParameters={
+                        "expected": "an instance of WriterCommitMessage",
+                        "actual": f"'{type(message).__name__}'",
+                    },
+                )
+            commit_messages.append(message)
+
+        batch_id = read_long(infile)
+        abort = read_bool(infile)
+
+        # Commit or abort the Python data source write.
+        # Note the commit messages can be None if there are failed tasks.
+        if abort:
+            writer.abort(commit_messages, batch_id)
+        else:
+            writer.commit(commit_messages, batch_id)
+
+    # Send a status code back to JVM.
+    write_int(0, outfile)
+    outfile.flush()
 
-        if not isinstance(data_source, DataSource):
-            raise PySparkAssertionError(
-                errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                messageParameters={
-                    "expected": "a Python data source instance of type 
'DataSource'",
-                    "actual": f"'{type(data_source).__name__}'",
-                },
-            )
-        # Receive the data source output schema.
-        schema_json = utf8_deserializer.loads(infile)
-        schema = _parse_datatype_json_string(schema_json)
-        if not isinstance(schema, StructType):
-            raise PySparkAssertionError(
-                errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                messageParameters={
-                    "expected": "an output schema of type 'StructType'",
-                    "actual": f"'{type(schema).__name__}'",
-                },
-            )
-        # Receive the `overwrite` flag.
-        overwrite = read_bool(infile)
 
-        with capture_outputs():
-            # Create the data source writer instance.
-            writer = data_source.streamWriter(schema=schema, 
overwrite=overwrite)
-            # Receive the commit messages.
-            num_messages = read_int(infile)
-
-            commit_messages = []
-            for _ in range(num_messages):
-                message = pickleSer._read_with_length(infile)
-                if message is not None and not isinstance(message, 
WriterCommitMessage):
-                    raise PySparkAssertionError(
-                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                        messageParameters={
-                            "expected": "an instance of WriterCommitMessage",
-                            "actual": f"'{type(message).__name__}'",
-                        },
-                    )
-                commit_messages.append(message)
-
-            batch_id = read_long(infile)
-            abort = read_bool(infile)
-
-            # Commit or abort the Python data source write.
-            # Note the commit messages can be None if there are failed tasks.
-            if abort:
-                writer.abort(commit_messages, batch_id)
-            else:
-                writer.commit(commit_messages, batch_id)
-
-        # Send a status code back to JVM.
-        write_int(0, outfile)
-        outfile.flush()
-    except BaseException as e:
-        handle_worker_exception(e, outfile)
-        sys.exit(-1)
-
-    send_accumulator_updates(outfile)
-
-    # check end of stream
-    if read_int(infile) == SpecialLengths.END_OF_STREAM:
-        write_int(SpecialLengths.END_OF_STREAM, outfile)
-    else:
-        # write a different value to tell JVM to not reuse this worker
-        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
-        sys.exit(-1)
+def main(infile: IO, outfile: IO) -> None:
+    worker_run(_main, infile, outfile)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/worker/lookup_data_sources.py 
b/python/pyspark/sql/worker/utils.py
similarity index 52%
copy from python/pyspark/sql/worker/lookup_data_sources.py
copy to python/pyspark/sql/worker/utils.py
index b23903cac8cb..bd5c6ffda9ee 100644
--- a/python/pyspark/sql/worker/lookup_data_sources.py
+++ b/python/pyspark/sql/worker/utils.py
@@ -14,48 +14,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from importlib import import_module
-from pkgutil import iter_modules
+
 import os
 import sys
-from typing import IO
+from typing import Callable, IO
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.serializers import (
     read_int,
     write_int,
-    write_with_length,
     SpecialLengths,
 )
-from pyspark.sql.datasource import DataSource
 from pyspark.util import (
+    start_faulthandler_periodic_traceback,
     handle_worker_exception,
-    local_connect_and_auth,
     with_faulthandler,
-    start_faulthandler_periodic_traceback,
 )
 from pyspark.worker_util import (
     check_python_version,
-    pickleSer,
     send_accumulator_updates,
-    setup_broadcasts,
     setup_memory_limits,
     setup_spark_files,
+    setup_broadcasts,
 )
 
 
 @with_faulthandler
-def main(infile: IO, outfile: IO) -> None:
-    """
-    Main method for looking up the available Python Data Sources in Python 
path.
-
-    This process is invoked from the 
`UserDefinedPythonDataSourceLookupRunner.runInPython`
-    method in `UserDefinedPythonDataSource.lookupAllDataSourcesInPython` when 
the first
-    call related to Python Data Source happens via `DataSourceManager`.
-
-    This is responsible for searching the available Python Data Sources so 
they can be
-    statically registered automatically.
-    """
+def worker_run(main: Callable, infile: IO, outfile: IO) -> None:
     try:
         check_python_version(infile)
 
@@ -69,20 +54,7 @@ def main(infile: IO, outfile: IO) -> None:
 
         _accumulatorRegistry.clear()
 
-        infos = {}
-        for info in iter_modules():
-            if info.name.startswith("pyspark_"):
-                mod = import_module(info.name)
-                if hasattr(mod, "DefaultSource") and 
issubclass(mod.DefaultSource, DataSource):
-                    infos[mod.DefaultSource.name()] = mod.DefaultSource
-
-        # Writes name -> pickled data source to JVM side to be registered
-        # as a Data Source.
-        write_int(len(infos), outfile)
-        for name, dataSource in infos.items():
-            write_with_length(name.encode("utf-8"), outfile)
-            pickleSer._write_with_length(dataSource, outfile)
-
+        main(infile, outfile)
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
@@ -96,15 +68,3 @@ def main(infile: IO, outfile: IO) -> None:
         # write a different value to tell JVM to not reuse this worker
         write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
         sys.exit(-1)
-
-
-if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/write_into_data_source.py 
b/python/pyspark/sql/worker/write_into_data_source.py
index b8a54f8397dc..111829bb7d58 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -16,10 +16,8 @@
 #
 import inspect
 import os
-import sys
 from typing import IO, Iterable, Iterator
 
-from pyspark.accumulators import _accumulatorRegistry
 from pyspark.sql.conversion import ArrowTableToRowsConversion
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, 
PySparkTypeError
 from pyspark.logger.worker_io import capture_outputs
@@ -27,7 +25,6 @@ from pyspark.serializers import (
     read_bool,
     read_int,
     write_int,
-    SpecialLengths,
 )
 from pyspark.sql import Row
 from pyspark.sql.datasource import (
@@ -45,26 +42,18 @@ from pyspark.sql.types import (
     BinaryType,
     _create_row,
 )
+from pyspark.sql.worker.utils import worker_run
 from pyspark.util import (
-    handle_worker_exception,
     local_connect_and_auth,
-    with_faulthandler,
-    start_faulthandler_periodic_traceback,
 )
 from pyspark.worker_util import (
-    check_python_version,
     read_command,
     pickleSer,
-    send_accumulator_updates,
-    setup_broadcasts,
-    setup_memory_limits,
-    setup_spark_files,
     utf8_deserializer,
 )
 
 
-@with_faulthandler
-def main(infile: IO, outfile: IO) -> None:
+def _main(infile: IO, outfile: IO) -> None:
     """
     Main method for saving into a Python data source.
 
@@ -83,194 +72,172 @@ def main(infile: IO, outfile: IO) -> None:
     instance and send a function using the writer instance that can be used
     in mapInPandas/mapInArrow back to the JVM.
     """
-    try:
-        check_python_version(infile)
-
-        start_faulthandler_periodic_traceback()
-
-        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", 
"-1"))
-        setup_memory_limits(memory_limit_mb)
 
-        setup_spark_files(infile)
-        setup_broadcasts(infile)
-
-        _accumulatorRegistry.clear()
+    # Receive the data source class.
+    data_source_cls = read_command(pickleSer, infile)
+    if not (isinstance(data_source_cls, type) and issubclass(data_source_cls, 
DataSource)):
+        raise PySparkAssertionError(
+            errorClass="DATA_SOURCE_TYPE_MISMATCH",
+            messageParameters={
+                "expected": "a subclass of DataSource",
+                "actual": f"'{type(data_source_cls).__name__}'",
+            },
+        )
+
+    # Check the name method is a class method.
+    if not inspect.ismethod(data_source_cls.name):
+        raise PySparkTypeError(
+            errorClass="DATA_SOURCE_TYPE_MISMATCH",
+            messageParameters={
+                "expected": "'name()' method to be a classmethod",
+                "actual": f"'{type(data_source_cls.name).__name__}'",
+            },
+        )
+
+    # Receive the provider name.
+    provider = utf8_deserializer.loads(infile)
+
+    with capture_outputs():
+        # Check if the provider name matches the data source's name.
+        name = data_source_cls.name()
+        if provider.lower() != name.lower():
+            raise PySparkAssertionError(
+                errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                messageParameters={
+                    "expected": f"provider with name {name}",
+                    "actual": f"'{provider}'",
+                },
+            )
 
-        # Receive the data source class.
-        data_source_cls = read_command(pickleSer, infile)
-        if not (isinstance(data_source_cls, type) and 
issubclass(data_source_cls, DataSource)):
+        # Receive the input schema
+        schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+        if not isinstance(schema, StructType):
             raise PySparkAssertionError(
                 errorClass="DATA_SOURCE_TYPE_MISMATCH",
                 messageParameters={
-                    "expected": "a subclass of DataSource",
+                    "expected": "the schema to be a 'StructType'",
                     "actual": f"'{type(data_source_cls).__name__}'",
                 },
             )
 
-        # Check the name method is a class method.
-        if not inspect.ismethod(data_source_cls.name):
-            raise PySparkTypeError(
+        # Receive the return type
+        return_type = 
_parse_datatype_json_string(utf8_deserializer.loads(infile))
+        if not isinstance(return_type, StructType):
+            raise PySparkAssertionError(
                 errorClass="DATA_SOURCE_TYPE_MISMATCH",
                 messageParameters={
-                    "expected": "'name()' method to be a classmethod",
-                    "actual": f"'{type(data_source_cls.name).__name__}'",
+                    "expected": "a return type of type 'StructType'",
+                    "actual": f"'{type(return_type).__name__}'",
                 },
             )
-
-        # Receive the provider name.
-        provider = utf8_deserializer.loads(infile)
-
-        with capture_outputs():
-            # Check if the provider name matches the data source's name.
-            name = data_source_cls.name()
-            if provider.lower() != name.lower():
-                raise PySparkAssertionError(
-                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                    messageParameters={
-                        "expected": f"provider with name {name}",
-                        "actual": f"'{provider}'",
-                    },
-                )
-
-            # Receive the input schema
-            schema = 
_parse_datatype_json_string(utf8_deserializer.loads(infile))
-            if not isinstance(schema, StructType):
+        assert len(return_type) == 1 and isinstance(return_type[0].dataType, 
BinaryType), (
+            "The output schema of Python data source write should contain only 
one column "
+            f"of type 'BinaryType', but got '{return_type}'"
+        )
+        return_col_name = return_type[0].name
+
+        # Receive the options.
+        options = CaseInsensitiveDict()
+        num_options = read_int(infile)
+        for _ in range(num_options):
+            key = utf8_deserializer.loads(infile)
+            value = utf8_deserializer.loads(infile)
+            options[key] = value
+
+        # Receive the `overwrite` flag.
+        overwrite = read_bool(infile)
+
+        is_streaming = read_bool(infile)
+        binary_as_bytes = read_bool(infile)
+
+        # Instantiate a data source.
+        data_source = data_source_cls(options=options)  # type: ignore
+
+        if is_streaming:
+            # Instantiate the streaming data source writer.
+            writer = data_source.streamWriter(schema, overwrite)
+            if not isinstance(writer, (DataSourceStreamWriter, 
DataSourceStreamArrowWriter)):
                 raise PySparkAssertionError(
                     errorClass="DATA_SOURCE_TYPE_MISMATCH",
                     messageParameters={
-                        "expected": "the schema to be a 'StructType'",
-                        "actual": f"'{type(data_source_cls).__name__}'",
+                        "expected": (
+                            "an instance of DataSourceStreamWriter or "
+                            "DataSourceStreamArrowWriter"
+                        ),
+                        "actual": f"'{type(writer).__name__}'",
                     },
                 )
+        else:
+            # Instantiate the data source writer.
 
-            # Receive the return type
-            return_type = 
_parse_datatype_json_string(utf8_deserializer.loads(infile))
-            if not isinstance(return_type, StructType):
+            writer = data_source.writer(schema, overwrite)  # type: 
ignore[assignment]
+            if not isinstance(writer, DataSourceWriter):
                 raise PySparkAssertionError(
                     errorClass="DATA_SOURCE_TYPE_MISMATCH",
                     messageParameters={
-                        "expected": "a return type of type 'StructType'",
-                        "actual": f"'{type(return_type).__name__}'",
+                        "expected": "an instance of DataSourceWriter",
+                        "actual": f"'{type(writer).__name__}'",
                     },
                 )
-            assert len(return_type) == 1 and 
isinstance(return_type[0].dataType, BinaryType), (
-                "The output schema of Python data source write should contain 
only one column "
-                f"of type 'BinaryType', but got '{return_type}'"
-            )
-            return_col_name = return_type[0].name
 
-            # Receive the options.
-            options = CaseInsensitiveDict()
-            num_options = read_int(infile)
-            for _ in range(num_options):
-                key = utf8_deserializer.loads(infile)
-                value = utf8_deserializer.loads(infile)
-                options[key] = value
-
-            # Receive the `overwrite` flag.
-            overwrite = read_bool(infile)
-
-            is_streaming = read_bool(infile)
-            binary_as_bytes = read_bool(infile)
-
-            # Instantiate a data source.
-            data_source = data_source_cls(options=options)  # type: ignore
-
-            if is_streaming:
-                # Instantiate the streaming data source writer.
-                writer = data_source.streamWriter(schema, overwrite)
-                if not isinstance(writer, (DataSourceStreamWriter, 
DataSourceStreamArrowWriter)):
-                    raise PySparkAssertionError(
-                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                        messageParameters={
-                            "expected": (
-                                "an instance of DataSourceStreamWriter or "
-                                "DataSourceStreamArrowWriter"
-                            ),
-                            "actual": f"'{type(writer).__name__}'",
-                        },
-                    )
-            else:
-                # Instantiate the data source writer.
-
-                writer = data_source.writer(schema, overwrite)  # type: 
ignore[assignment]
-                if not isinstance(writer, DataSourceWriter):
-                    raise PySparkAssertionError(
-                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                        messageParameters={
-                            "expected": "an instance of DataSourceWriter",
-                            "actual": f"'{type(writer).__name__}'",
-                        },
-                    )
-
-        # Create a function that can be used in mapInArrow.
-        import pyarrow as pa
-
-        converters = [
-            ArrowTableToRowsConversion._create_converter(
-                f.dataType, none_on_identity=False, 
binary_as_bytes=binary_as_bytes
+    # Create a function that can be used in mapInArrow.
+    import pyarrow as pa
+
+    converters = [
+        ArrowTableToRowsConversion._create_converter(
+            f.dataType, none_on_identity=False, binary_as_bytes=binary_as_bytes
+        )
+        for f in schema.fields
+    ]
+    fields = schema.fieldNames()
+
+    def data_source_write_func(iterator: Iterable[pa.RecordBatch]) -> 
Iterable[pa.RecordBatch]:
+        def batch_to_rows() -> Iterator[Row]:
+            for batch in iterator:
+                columns = [column.to_pylist() for column in batch.columns]
+                for row in range(0, batch.num_rows):
+                    values = [
+                        converters[col](columns[col][row])  # type: 
ignore[misc]
+                        for col in range(batch.num_columns)
+                    ]
+                    yield _create_row(fields=fields, values=values)
+
+        if isinstance(writer, DataSourceArrowWriter):
+            res = writer.write(iterator)
+        elif isinstance(writer, DataSourceStreamArrowWriter):
+            res = writer.write(iterator)  # type: ignore[arg-type]
+        else:
+            res = writer.write(batch_to_rows())
+
+        # Check the commit message has the right type.
+        if not isinstance(res, WriterCommitMessage):
+            raise PySparkRuntimeError(
+                errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                messageParameters={
+                    "expected": (
+                        "'WriterCommitMessage' as the return type of " "the 
`write` method"
+                    ),
+                    "actual": type(res).__name__,
+                },
             )
-            for f in schema.fields
-        ]
-        fields = schema.fieldNames()
 
-        def data_source_write_func(iterator: Iterable[pa.RecordBatch]) -> 
Iterable[pa.RecordBatch]:
-            def batch_to_rows() -> Iterator[Row]:
-                for batch in iterator:
-                    columns = [column.to_pylist() for column in batch.columns]
-                    for row in range(0, batch.num_rows):
-                        values = [
-                            converters[col](columns[col][row])  # type: 
ignore[misc]
-                            for col in range(batch.num_columns)
-                        ]
-                        yield _create_row(fields=fields, values=values)
+        # Serialize the commit message and return it.
+        pickled = pickleSer.dumps(res)
 
-            if isinstance(writer, DataSourceArrowWriter):
-                res = writer.write(iterator)
-            elif isinstance(writer, DataSourceStreamArrowWriter):
-                res = writer.write(iterator)  # type: ignore[arg-type]
-            else:
-                res = writer.write(batch_to_rows())
+        # Return the commit message.
+        messages = pa.array([pickled])
+        yield pa.record_batch([messages], names=[return_col_name])
 
-            # Check the commit message has the right type.
-            if not isinstance(res, WriterCommitMessage):
-                raise PySparkRuntimeError(
-                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                    messageParameters={
-                        "expected": (
-                            "'WriterCommitMessage' as the return type of " 
"the `write` method"
-                        ),
-                        "actual": type(res).__name__,
-                    },
-                )
+    # Return the pickled write UDF.
+    command = (data_source_write_func, return_type)
+    pickleSer._write_with_length(command, outfile)
 
-            # Serialize the commit message and return it.
-            pickled = pickleSer.dumps(res)
+    # Return the picked writer.
+    pickleSer._write_with_length(writer, outfile)
 
-            # Return the commit message.
-            messages = pa.array([pickled])
-            yield pa.record_batch([messages], names=[return_col_name])
 
-        # Return the pickled write UDF.
-        command = (data_source_write_func, return_type)
-        pickleSer._write_with_length(command, outfile)
-
-        # Return the picked writer.
-        pickleSer._write_with_length(writer, outfile)
-
-    except BaseException as e:
-        handle_worker_exception(e, outfile)
-        sys.exit(-1)
-
-    send_accumulator_updates(outfile)
-
-    # check end of stream
-    if read_int(infile) == SpecialLengths.END_OF_STREAM:
-        write_int(SpecialLengths.END_OF_STREAM, outfile)
-    else:
-        # write a different value to tell JVM to not reuse this worker
-        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
-        sys.exit(-1)
+def main(infile: IO, outfile: IO) -> None:
+    worker_run(_main, infile, outfile)
 
 
 if __name__ == "__main__":


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to