This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 172c9412e4bfb1d9b5d017851485bd1263c493af Author: pawelkocinski <[email protected]> AuthorDate: Sun Jul 27 00:26:37 2025 +0200 SEDONA-738 Fix unit tests. --- pom.xml | 12 +- sedonaworker/__init__.py | 0 sedonaworker/worker.py | 643 ++++++++++++++++ .../scala/org/apache/spark/SedonaSparkEnv.scala | 495 +++++++++++++ .../spark/api/python/SedonaPythonRunner.scala | 811 +++++++++++++++++++++ .../execution/python/SedonaArrowPythonRunner.scala | 70 ++ .../execution/python/SedonaPythonArrowInput.scala | 148 ++++ .../execution/python/SedonaPythonArrowOutput.scala | 135 ++++ .../execution/python/SedonaPythonUDFRunner.scala | 147 ++++ .../apache/spark/sql/udf/SedonaArrowStrategy.scala | 4 +- .../apache/spark/sql/udf/TestScalarPandasUDF.scala | 21 +- 11 files changed, 2475 insertions(+), 11 deletions(-) diff --git a/pom.xml b/pom.xml index 44a1dcb16a..c8ae8b50e6 100644 --- a/pom.xml +++ b/pom.xml @@ -19,12 +19,12 @@ <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> <modelVersion>4.0.0</modelVersion> - <parent> - <groupId>org.apache</groupId> - <artifactId>apache</artifactId> - <version>23</version> - <relativePath /> - </parent> +<!-- <parent>--> +<!-- <groupId>org.apache</groupId>--> +<!-- <artifactId>apache</artifactId>--> +<!-- <version>23</version>--> +<!-- <relativePath />--> +<!-- </parent>--> <groupId>org.apache.sedona</groupId> <artifactId>sedona-parent</artifactId> <version>1.8.1-SNAPSHOT</version> diff --git a/sedonaworker/__init__.py b/sedonaworker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sedonaworker/worker.py b/sedonaworker/worker.py new file mode 100644 index 0000000000..42fb20beb3 --- /dev/null +++ b/sedonaworker/worker.py @@ -0,0 +1,643 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Worker that receives input from Piped RDD. +""" +import os +import sys +import time +from inspect import currentframe, getframeinfo, getfullargspec +import importlib +import json +from typing import Any, Iterable, Iterator + +# 'resource' is a Unix specific module. 
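+# If the import below fails (e.g. on Windows), has_resource_module stays False and
+# the memory-limit setup in main() is skipped.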
+has_resource_module = True +try: + import resource +except ImportError: + has_resource_module = False +import traceback +import warnings +import faulthandler + +from pyspark.accumulators import _accumulatorRegistry +from pyspark.broadcast import Broadcast, _broadcastRegistry +from pyspark.java_gateway import local_connect_and_auth +from pyspark.taskcontext import BarrierTaskContext, TaskContext +from pyspark.files import SparkFiles +from pyspark.resource import ResourceInformation +from pyspark.rdd import PythonEvalType +from pyspark.serializers import ( + write_with_length, + write_int, + read_long, + read_bool, + write_long, + read_int, + SpecialLengths, + UTF8Deserializer, + CPickleSerializer, + BatchedSerializer, +) +from pyspark.sql.pandas.serializers import ( + ArrowStreamPandasUDFSerializer, + ArrowStreamPandasUDTFSerializer, + CogroupUDFSerializer, + ArrowStreamUDFSerializer, + ApplyInPandasWithStateSerializer, +) +from pyspark.sql.pandas.types import to_arrow_type +from pyspark.sql.types import BinaryType, StringType, StructType, _parse_datatype_json_string +from pyspark.util import fail_on_stopiteration, try_simplify_traceback +from pyspark import shuffle +from pyspark.errors import PySparkRuntimeError, PySparkTypeError + +pickleSer = CPickleSerializer() +utf8_deserializer = UTF8Deserializer() + + +def report_times(outfile, boot, init, finish): + write_int(SpecialLengths.TIMING_DATA, outfile) + write_long(int(1000 * boot), outfile) + write_long(int(1000 * init), outfile) + write_long(int(1000 * finish), outfile) + + +def add_path(path): + # worker can be used, so do not add path multiple times + if path not in sys.path: + # overwrite system packages + sys.path.insert(1, path) + + +def read_command(serializer, file): + command = serializer._read_with_length(file) + if isinstance(command, Broadcast): + command = serializer.loads(command.value) + return command + + +def chain(f, g): + """chain two functions together""" + return lambda *a: g(f(*a)) + + +# def wrap_udf(f, return_type): +# if return_type.needConversion(): +# toInternal = return_type.toInternal +# return lambda *a: toInternal(f(*a)) +# else: +# return lambda *a: f(*a) + + +def wrap_scalar_pandas_udf(f, return_type): + arrow_return_type = to_arrow_type(return_type) + + def verify_result_type(result): + if not hasattr(result, "__len__"): + pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" + raise PySparkTypeError( + error_class="UDF_RETURN_TYPE", + message_parameters={ + "expected": pd_type, + "actual": type(result).__name__, + }, + ) + return result + + def verify_result_length(result, length): + if len(result) != length: + raise PySparkRuntimeError( + error_class="SCHEMA_MISMATCH_FOR_PANDAS_UDF", + message_parameters={ + "expected": str(length), + "actual": str(len(result)), + }, + ) + return result + + return lambda *a: ( + verify_result_length(verify_result_type(f(*a)), len(a[0])), + arrow_return_type, + ) + + +def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): + num_arg = read_int(infile) + arg_offsets = [read_int(infile) for i in range(num_arg)] + chained_func = None + for i in range(read_int(infile)): + f, return_type = read_command(pickleSer, infile) + if chained_func is None: + chained_func = f + else: + chained_func = chain(chained_func, f) + + if eval_type in ( + PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, + PythonEvalType.SQL_ARROW_BATCHED_UDF, + ): + func = chained_func + else: + # make sure StopIteration's raised in the user code are not ignored + # 
when they are processed in a for loop, raise them as RuntimeError's instead + func = fail_on_stopiteration(chained_func) + + # the last returnType will be the return type of UDF + if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: + return arg_offsets, wrap_scalar_pandas_udf(func, return_type) + else: + raise ValueError("Unknown eval type: {}".format(eval_type)) + + +# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF and SQL_ARROW_BATCHED_UDF when +# returning StructType +def assign_cols_by_name(runner_conf): + return ( + runner_conf.get( + "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true" + ).lower() + == "true" + ) + + +def read_udfs(pickleSer, infile, eval_type): + runner_conf = {} + + if eval_type in ( + PythonEvalType.SQL_SCALAR_PANDAS_UDF, + ): + + # Load conf used for pandas_udf evaluation + num_conf = read_int(infile) + for i in range(num_conf): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + runner_conf[k] = v + + state_object_schema = None + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) + + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True + timezone = runner_conf.get("spark.sql.session.timeZone", None) + safecheck = ( + runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower() + == "true" + ) + + if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name(runner_conf)) + elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: + ser = ArrowStreamUDFSerializer() + else: + # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of + # pandas Series. See SPARK-27240. + df_for_struct = ( + eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF + or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF + or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF + ) + # Arrow-optimized Python UDF takes a struct type argument as a Row + struct_in_pandas = ( + "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict" + ) + ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion + arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + ser = ArrowStreamPandasUDFSerializer( + timezone, + safecheck, + assign_cols_by_name(runner_conf), + df_for_struct, + struct_in_pandas, + ndarray_as_list, + arrow_cast, + ) + else: + ser = BatchedSerializer(CPickleSerializer(), 100) + + num_udfs = read_int(infile) + + is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF + is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF + is_map_arrow_iter = eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF + + if is_scalar_iter or is_map_pandas_iter or is_map_arrow_iter: + if is_scalar_iter: + assert num_udfs == 1, "One SCALAR_ITER UDF expected here." + if is_map_pandas_iter: + assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here." + if is_map_arrow_iter: + assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here." 
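+        # Iterator-based eval types carry exactly one UDF; read it together with the
+        # offsets of the input columns it consumes.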
+ + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + + def func(_, iterator): + num_input_rows = 0 + + def map_batch(batch): + nonlocal num_input_rows + + udf_args = [batch[offset] for offset in arg_offsets] + num_input_rows += len(udf_args[0]) + if len(udf_args) == 1: + return udf_args[0] + else: + return tuple(udf_args) + + iterator = map(map_batch, iterator) + result_iter = udf(iterator) + + num_output_rows = 0 + for result_batch, result_type in result_iter: + num_output_rows += len(result_batch) + # This assert is for Scalar Iterator UDF to fail fast. + # The length of the entire input can only be explicitly known + # by consuming the input iterator in user side. Therefore, + # it's very unlikely the output length is higher than + # input length. + assert ( + is_map_pandas_iter or is_map_arrow_iter or num_output_rows <= num_input_rows + ), "Pandas SCALAR_ITER UDF outputted more rows than input rows." + yield (result_batch, result_type) + + if is_scalar_iter: + try: + next(iterator) + except StopIteration: + pass + else: + raise PySparkRuntimeError( + error_class="STOP_ITERATION_OCCURRED_FROM_SCALAR_ITER_PANDAS_UDF", + message_parameters={}, + ) + + if num_output_rows != num_input_rows: + raise PySparkRuntimeError( + error_class="RESULT_LENGTH_MISMATCH_FOR_SCALAR_ITER_PANDAS_UDF", + message_parameters={ + "output_length": str(num_output_rows), + "input_length": str(num_input_rows), + }, + ) + + # profiling is not supported for UDF + return func, None, ser, ser + + def extract_key_value_indexes(grouped_arg_offsets): + """ + Helper function to extract the key and value indexes from arg_offsets for the grouped and + cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code. + + Parameters + ---------- + grouped_arg_offsets: list + List containing the key and value indexes of columns of the + DataFrames to be passed to the udf. It consists of n repeating groups where n is the + number of DataFrames. Each group has the following format: + group[0]: length of group + group[1]: length of key indexes + group[2.. group[1] +2]: key attributes + group[group[1] +3 group[0]]: value attributes + """ + parsed = [] + idx = 0 + while idx < len(grouped_arg_offsets): + offsets_len = grouped_arg_offsets[idx] + idx += 1 + offsets = grouped_arg_offsets[idx : idx + offsets_len] + split_index = offsets[0] + 1 + offset_keys = offsets[1:split_index] + offset_values = offsets[split_index:] + parsed.append([offset_keys, offset_values]) + idx += offsets_len + return parsed + + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + + # See FlatMapGroupsInPandasExec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes + arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + parsed_offsets = extract_key_value_indexes(arg_offsets) + + # Create function like this: + # mapper a: f([a[0]], [a[0], a[1]]) + def mapper(a): + keys = [a[o] for o in parsed_offsets[0][0]] + vals = [a[o] for o in parsed_offsets[0][1]] + return f(keys, vals) + + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. 
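+        # Unlike the stateless grouped-map branch above, the UDF here also receives the
+        # group state object (see `state = a[1]` in the mapper below).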
+ assert num_udfs == 1 + + # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes + arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + parsed_offsets = extract_key_value_indexes(arg_offsets) + + def mapper(a): + """ + The function receives (iterator of data, state) and performs extraction of key and + value from the data, with retaining lazy evaluation. + + See `load_stream` in `ApplyInPandasWithStateSerializer` for more details on the input + and see `wrap_grouped_map_pandas_udf_with_state` for more details on how output will + be used. + """ + from itertools import tee + + state = a[1] + data_gen = (x[0] for x in a[0]) + + # We know there should be at least one item in the iterator/generator. + # We want to peek the first element to construct the key, hence applying + # tee to construct the key while we retain another iterator/generator + # for values. + keys_gen, values_gen = tee(data_gen) + keys_elem = next(keys_gen) + keys = [keys_elem[o] for o in parsed_offsets[0][0]] + + # This must be generator comprehension - do not materialize. + vals = ([x[o] for o in parsed_offsets[0][1]] for x in values_gen) + + return f(keys, vals, state) + + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + # We assume there is only one UDF here because cogrouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + + parsed_offsets = extract_key_value_indexes(arg_offsets) + + def mapper(a): + df1_keys = [a[0][o] for o in parsed_offsets[0][0]] + df1_vals = [a[0][o] for o in parsed_offsets[0][1]] + df2_keys = [a[1][o] for o in parsed_offsets[1][0]] + df2_vals = [a[1][o] for o in parsed_offsets[1][1]] + return f(df1_keys, df1_vals, df2_keys, df2_vals) + + else: + udfs = [] + for i in range(num_udfs): + udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i)) + + def mapper(a): + result = tuple(f(*[a[o] for o in arg_offsets]) for (arg_offsets, f) in udfs) + # In the special case of a single UDF this will return a single result rather + # than a tuple of results; this is the format that the JVM side expects. 
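+            # Each UDF was evaluated against the same batch using its own arg_offsets,
+            # and the results were collected into the tuple above.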
+ if len(result) == 1: + return result[0] + else: + return result + + def func(_, it): + return map(mapper, it) + + # profiling is not supported for UDF + return func, None, ser, ser + + +def main(infile, outfile): + faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None) + try: + if faulthandler_log_path: + faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid())) + faulthandler_log_file = open(faulthandler_log_path, "w") + faulthandler.enable(file=faulthandler_log_file) + + boot_time = time.time() + split_index = read_int(infile) + if split_index == -1: # for unit tests + sys.exit(-1) + + version = utf8_deserializer.loads(infile) + if version != "%d.%d" % sys.version_info[:2]: + raise PySparkRuntimeError( + error_class="PYTHON_VERSION_MISMATCH", + message_parameters={ + "worker_version": str(sys.version_info[:2]), + "driver_version": str(version), + }, + ) + + # read inputs only for a barrier task + isBarrier = read_bool(infile) + boundPort = read_int(infile) + secret = UTF8Deserializer().loads(infile) + + # set up memory limits + memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1")) + if memory_limit_mb > 0 and has_resource_module: + total_memory = resource.RLIMIT_AS + try: + (soft_limit, hard_limit) = resource.getrlimit(total_memory) + msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit) + print(msg, file=sys.stderr) + + # convert to bytes + new_limit = memory_limit_mb * 1024 * 1024 + + if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit: + msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit) + print(msg, file=sys.stderr) + resource.setrlimit(total_memory, (new_limit, new_limit)) + + except (resource.error, OSError, ValueError) as e: + # not all systems support resource limits, so warn instead of failing + lineno = ( + getframeinfo(currentframe()).lineno + 1 if currentframe() is not None else 0 + ) + if "__file__" in globals(): + print( + warnings.formatwarning( + "Failed to set memory limit: {0}".format(e), + ResourceWarning, + __file__, + lineno, + ), + file=sys.stderr, + ) + + # initialize global state + taskContext = None + if isBarrier: + taskContext = BarrierTaskContext._getOrCreate() + BarrierTaskContext._initialize(boundPort, secret) + # Set the task context instance here, so we can get it by TaskContext.get for + # both TaskContext and BarrierTaskContext + TaskContext._setTaskContext(taskContext) + else: + taskContext = TaskContext._getOrCreate() + # read inputs for TaskContext info + taskContext._stageId = read_int(infile) + taskContext._partitionId = read_int(infile) + taskContext._attemptNumber = read_int(infile) + taskContext._taskAttemptId = read_long(infile) + taskContext._cpus = read_int(infile) + taskContext._resources = {} + for r in range(read_int(infile)): + key = utf8_deserializer.loads(infile) + name = utf8_deserializer.loads(infile) + addresses = [] + taskContext._resources = {} + for a in range(read_int(infile)): + addresses.append(utf8_deserializer.loads(infile)) + taskContext._resources[key] = ResourceInformation(name, addresses) + + taskContext._localProperties = dict() + for i in range(read_int(infile)): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + taskContext._localProperties[k] = v + + shuffle.MemoryBytesSpilled = 0 + shuffle.DiskBytesSpilled = 0 + _accumulatorRegistry.clear() + + # fetch name of workdir + spark_files_dir = utf8_deserializer.loads(infile) + SparkFiles._root_directory = spark_files_dir + 
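+        # Mark this process as a worker so SparkFiles.get() resolves paths against the
+        # root directory received from the JVM.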
SparkFiles._is_running_on_worker = True + + # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH + add_path(spark_files_dir) # *.py files that were added will be copied here + num_python_includes = read_int(infile) + for _ in range(num_python_includes): + filename = utf8_deserializer.loads(infile) + add_path(os.path.join(spark_files_dir, filename)) + + importlib.invalidate_caches() + + # fetch names and values of broadcast variables + needs_broadcast_decryption_server = read_bool(infile) + num_broadcast_variables = read_int(infile) + if needs_broadcast_decryption_server: + # read the decrypted data from a server in the jvm + port = read_int(infile) + auth_secret = utf8_deserializer.loads(infile) + (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret) + + for _ in range(num_broadcast_variables): + bid = read_long(infile) + if bid >= 0: + if needs_broadcast_decryption_server: + read_bid = read_long(broadcast_sock_file) + assert read_bid == bid + _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file) + else: + path = utf8_deserializer.loads(infile) + _broadcastRegistry[bid] = Broadcast(path=path) + + else: + bid = -bid - 1 + _broadcastRegistry.pop(bid) + + if needs_broadcast_decryption_server: + broadcast_sock_file.write(b"1") + broadcast_sock_file.close() + + _accumulatorRegistry.clear() + eval_type = read_int(infile) + if eval_type == PythonEvalType.NON_UDF: + func, profiler, deserializer, serializer = read_command(pickleSer, infile) + else: + func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type) + + init_time = time.time() + + def process(): + iterator = deserializer.load_stream(infile) + out_iter = func(split_index, iterator) + try: + serializer.dump_stream(out_iter, outfile) + finally: + if hasattr(out_iter, "close"): + out_iter.close() + + if profiler: + profiler.profile(process) + else: + process() + + # Reset task context to None. This is a guard code to avoid residual context when worker + # reuse. 
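+        # (Workers can be reused across tasks when spark.python.worker.reuse is enabled,
+        # so any per-task context must not leak into the next task.)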
+ TaskContext._setTaskContext(None) + BarrierTaskContext._setTaskContext(None) + except BaseException as e: + try: + exc_info = None + if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False): + tb = try_simplify_traceback(sys.exc_info()[-1]) + if tb is not None: + e.__cause__ = None + exc_info = "".join(traceback.format_exception(type(e), e, tb)) + if exc_info is None: + exc_info = traceback.format_exc() + + write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) + write_with_length(exc_info.encode("utf-8"), outfile) + except IOError: + # JVM close the socket + pass + except BaseException: + # Write the error to stderr if it happened while serializing + print("PySpark worker failed with exception:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + sys.exit(-1) + finally: + if faulthandler_log_path: + faulthandler.disable() + faulthandler_log_file.close() + os.remove(faulthandler_log_path) + finish_time = time.time() + report_times(outfile, boot_time, init_time, finish_time) + write_long(shuffle.MemoryBytesSpilled, outfile) + write_long(shuffle.DiskBytesSpilled, outfile) + + # Mark the beginning of the accumulators section of the output + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + write_int(len(_accumulatorRegistry), outfile) + for (aid, accum) in _accumulatorRegistry.items(): + pickleSer._write_with_length((aid, accum._value), outfile) + + # check end of stream + if read_int(infile) == SpecialLengths.END_OF_STREAM: + write_int(SpecialLengths.END_OF_STREAM, outfile) + else: + # write a different value to tell JVM to not reuse this worker + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + sys.exit(-1) + + +if __name__ == "__main__": + # Read information about how to connect back to the JVM from the environment. + java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8. + write_int(os.getpid(), sock_file) + sock_file.flush() + main(sock_file, sock_file) diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala new file mode 100644 index 0000000000..9449a291f5 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala @@ -0,0 +1,495 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import java.io.File +import java.net.Socket +import java.util.Locale + +import scala.collection.JavaConverters._ +import scala.collection.concurrent +import scala.collection.mutable +import scala.util.Properties + +import com.google.common.cache.CacheBuilder +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.python.PythonWorkerFactory +import org.apache.spark.broadcast.BroadcastManager +import org.apache.spark.executor.ExecutorBackend +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.config._ +import org.apache.spark.memory.{MemoryManager, UnifiedMemoryManager} +import org.apache.spark.metrics.{MetricsSystem, MetricsSystemInstances} +import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} +import org.apache.spark.network.shuffle.ExternalBlockStoreClient +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} +import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint +import org.apache.spark.security.CryptoStreamUtils +import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} +import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.storage._ +import org.apache.spark.util.{RpcUtils, Utils} + +/** + * :: DeveloperApi :: + * Holds all the runtime environment objects for a running Spark instance (either master or worker), + * including the serializer, RpcEnv, block manager, map output tracker, etc. Currently + * Spark code finds the SparkEnv through a global variable, so all the threads can access the same + * SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext). + */ +@DeveloperApi +class SedonaSparkEnv ( + val executorId: String, + private[spark] val rpcEnv: RpcEnv, + val serializer: Serializer, + val closureSerializer: Serializer, + val serializerManager: SerializerManager, + val mapOutputTracker: MapOutputTracker, + val shuffleManager: ShuffleManager, + val broadcastManager: BroadcastManager, + val blockManager: BlockManager, + val securityManager: SecurityManager, + val metricsSystem: MetricsSystem, + val memoryManager: MemoryManager, + val outputCommitCoordinator: OutputCommitCoordinator, + val conf: SparkConf) extends Logging { + + @volatile private[spark] var isStopped = false + private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() + + // A general, soft-reference map for metadata needed during HadoopRDD split computation + // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). + private[spark] val hadoopJobMetadata = + CacheBuilder.newBuilder().maximumSize(1000).softValues().build[String, AnyRef]().asMap() + + private[spark] var driverTmpDir: Option[String] = None + + private[spark] var executorBackend: Option[ExecutorBackend] = None + + private[spark] def stop(): Unit = { + + if (!isStopped) { + isStopped = true + pythonWorkers.values.foreach(_.stop()) + mapOutputTracker.stop() + shuffleManager.stop() + broadcastManager.stop() + blockManager.stop() + blockManager.master.stop() + metricsSystem.stop() + outputCommitCoordinator.stop() + rpcEnv.shutdown() + rpcEnv.awaitTermination() + + // If we only stop sc, but the driver process still run as a services then we need to delete + // the tmp dir, if not, it will create too many tmp dirs. 
+ // We only need to delete the tmp dir create by driver + driverTmpDir match { + case Some(path) => + try { + Utils.deleteRecursively(new File(path)) + } catch { + case e: Exception => + logWarning(s"Exception while deleting Spark temp dir: $path", e) + } + case None => // We just need to delete tmp dir created by driver, so do nothing on executor + } + } + } + + private[spark] + def createPythonWorker( + pythonExec: String, + envVars: Map[String, String]): (java.net.Socket, Option[Int]) = { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create() + } + } + + private[spark] + def destroyPythonWorker(pythonExec: String, + envVars: Map[String, String], worker: Socket): Unit = { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.get(key).foreach(_.stopWorker(worker)) + } + } + + private[spark] + def releasePythonWorker(pythonExec: String, + envVars: Map[String, String], worker: Socket): Unit = { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.get(key).foreach(_.releaseWorker(worker)) + } + } +} + +object SedonaSparkEnv extends Logging { + @volatile private var env: SedonaSparkEnv = _ + + private[spark] val driverSystemName = "sparkDriver" + private[spark] val executorSystemName = "sparkExecutor" + + def set(e: SedonaSparkEnv): Unit = { + env = e + } + + /** + * Returns the SparkEnv. + */ + def get: SedonaSparkEnv = { + env + } + + /** + * Create a SparkEnv for the driver. + */ + private[spark] def createDriverEnv( + conf: SparkConf, + isLocal: Boolean, + listenerBus: LiveListenerBus, + numCores: Int, + sparkContext: SparkContext, + mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { + assert(conf.contains(DRIVER_HOST_ADDRESS), + s"${DRIVER_HOST_ADDRESS.key} is not set on the driver!") + assert(conf.contains(DRIVER_PORT), s"${DRIVER_PORT.key} is not set on the driver!") + val bindAddress = conf.get(DRIVER_BIND_ADDRESS) + val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS) + val port = conf.get(DRIVER_PORT) + val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) { + Some(CryptoStreamUtils.createKey(conf)) + } else { + None + } + create( + conf, + SparkContext.DRIVER_IDENTIFIER, + bindAddress, + advertiseAddress, + Option(port), + isLocal, + numCores, + ioEncryptionKey, + listenerBus = listenerBus, + Option(sparkContext), + mockOutputCommitCoordinator = mockOutputCommitCoordinator + ) + } + + /** + * Create a SparkEnv for an executor. + * In coarse-grained mode, the executor provides an RpcEnv that is already instantiated. + */ + private[spark] def createExecutorEnv( + conf: SparkConf, + executorId: String, + bindAddress: String, + hostname: String, + numCores: Int, + ioEncryptionKey: Option[Array[Byte]], + isLocal: Boolean): SparkEnv = { + val env = create( + conf, + executorId, + bindAddress, + hostname, + None, + isLocal, + numCores, + ioEncryptionKey + ) + SparkEnv.set(env) + env + } + + private[spark] def createExecutorEnv( + conf: SparkConf, + executorId: String, + hostname: String, + numCores: Int, + ioEncryptionKey: Option[Array[Byte]], + isLocal: Boolean): SparkEnv = { + createExecutorEnv(conf, executorId, hostname, + hostname, numCores, ioEncryptionKey, isLocal) + } + + /** + * Helper method to create a SparkEnv for a driver or an executor. 
+ */ + // scalastyle:off argcount + private def create( + conf: SparkConf, + executorId: String, + bindAddress: String, + advertiseAddress: String, + port: Option[Int], + isLocal: Boolean, + numUsableCores: Int, + ioEncryptionKey: Option[Array[Byte]], + listenerBus: LiveListenerBus = null, + sc: Option[SparkContext] = None, + mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { + // scalastyle:on argcount + + val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER + + // Listener bus is only used on the driver + if (isDriver) { + assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!") + } + val authSecretFileConf = if (isDriver) AUTH_SECRET_FILE_DRIVER else AUTH_SECRET_FILE_EXECUTOR + val securityManager = new SecurityManager(conf, ioEncryptionKey, authSecretFileConf) + if (isDriver) { + securityManager.initializeAuth() + } + + ioEncryptionKey.foreach { _ => + if (!securityManager.isEncryptionEnabled()) { + logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " + + "wire.") + } + } + + val systemName = if (isDriver) driverSystemName else executorSystemName + val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf, + securityManager, numUsableCores, !isDriver) + + // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied. + if (isDriver) { + conf.set(DRIVER_PORT, rpcEnv.address.port) + } + + val serializer = Utils.instantiateSerializerFromConf[Serializer](SERIALIZER, conf, isDriver) + logDebug(s"Using serializer: ${serializer.getClass}") + + val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey) + + val closureSerializer = new JavaSerializer(conf) + + def registerOrLookupEndpoint( + name: String, endpointCreator: => RpcEndpoint): + RpcEndpointRef = { + if (isDriver) { + logInfo("Registering " + name) + rpcEnv.setupEndpoint(name, endpointCreator) + } else { + RpcUtils.makeDriverRef(name, conf, rpcEnv) + } + } + + val broadcastManager = new BroadcastManager(isDriver, conf) + + val mapOutputTracker = if (isDriver) { + new MapOutputTrackerMaster(conf, broadcastManager, isLocal) + } else { + new MapOutputTrackerWorker(conf) + } + + // Have to assign trackerEndpoint after initialization as MapOutputTrackerEndpoint + // requires the MapOutputTracker itself + mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint( + rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + + // Let the user specify short names for shuffle managers + val shortShuffleMgrNames = Map( + "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName, + "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName) + val shuffleMgrName = conf.get(config.SHUFFLE_MANAGER) + val shuffleMgrClass = + shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName) + val shuffleManager = Utils.instantiateSerializerOrShuffleManager[ShuffleManager]( + shuffleMgrClass, conf, isDriver) + + val memoryManager: MemoryManager = UnifiedMemoryManager(conf, numUsableCores) + + val blockManagerPort = if (isDriver) { + conf.get(DRIVER_BLOCK_MANAGER_PORT) + } else { + conf.get(BLOCK_MANAGER_PORT) + } + + val externalShuffleClient = if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) + Some(new 
ExternalBlockStoreClient(transConf, securityManager, + securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT))) + } else { + None + } + + // Mapping from block manager id to the block manager's information. + val blockManagerInfo = new concurrent.TrieMap[BlockManagerId, BlockManagerInfo]() + val blockManagerMaster = new BlockManagerMaster( + registerOrLookupEndpoint( + BlockManagerMaster.DRIVER_ENDPOINT_NAME, + new BlockManagerMasterEndpoint( + rpcEnv, + isLocal, + conf, + listenerBus, + if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { + externalShuffleClient + } else { + None + }, blockManagerInfo, + mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], + shuffleManager, + isDriver)), + registerOrLookupEndpoint( + BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME, + new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)), + conf, + isDriver) + + val blockTransferService = + new NettyBlockTransferService(conf, securityManager, serializerManager, bindAddress, + advertiseAddress, blockManagerPort, numUsableCores, blockManagerMaster.driverEndpoint) + + // NB: blockManager is not valid until initialize() is called later. + val blockManager = new BlockManager( + executorId, + rpcEnv, + blockManagerMaster, + serializerManager, + conf, + memoryManager, + mapOutputTracker, + shuffleManager, + blockTransferService, + securityManager, + externalShuffleClient) + + val metricsSystem = if (isDriver) { + // Don't start metrics system right now for Driver. + // We need to wait for the task scheduler to give us an app ID. + // Then we can start the metrics system. + MetricsSystem.createMetricsSystem(MetricsSystemInstances.DRIVER, conf) + } else { + // We need to set the executor ID before the MetricsSystem is created because sources and + // sinks specified in the metrics configuration file will want to incorporate this executor's + // ID into the metrics they report. + conf.set(EXECUTOR_ID, executorId) + val ms = MetricsSystem.createMetricsSystem(MetricsSystemInstances.EXECUTOR, conf) + ms.start(conf.get(METRICS_STATIC_SOURCES_ENABLED)) + ms + } + + val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { + if (isDriver) { + new OutputCommitCoordinator(conf, isDriver, sc) + } else { + new OutputCommitCoordinator(conf, isDriver) + } + + } + val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator", + new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) + outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) + + val envInstance = new SparkEnv( + executorId, + rpcEnv, + serializer, + closureSerializer, + serializerManager, + mapOutputTracker, + shuffleManager, + broadcastManager, + blockManager, + securityManager, + metricsSystem, + memoryManager, + outputCommitCoordinator, + conf) + + // Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is + // called, and we only need to do it for driver. Because driver may run as a service, and if we + // don't delete this tmp dir when sc is stopped, then will create too many tmp dirs. + if (isDriver) { + val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath + envInstance.driverTmpDir = Some(sparkFilesDir) + } + + envInstance + } + + /** + * Return a map representation of jvm information, Spark properties, system properties, and + * class paths. Map keys define the category, and map values represent the corresponding + * attributes as a sequence of KV pairs. 
This is used mainly for SparkListenerEnvironmentUpdate. + */ + private[spark] def environmentDetails( + conf: SparkConf, + hadoopConf: Configuration, + schedulingMode: String, + addedJars: Seq[String], + addedFiles: Seq[String], + addedArchives: Seq[String], + metricsProperties: Map[String, String]): Map[String, Seq[(String, String)]] = { + + import Properties._ + val jvmInformation = Seq( + ("Java Version", s"$javaVersion ($javaVendor)"), + ("Java Home", javaHome), + ("Scala Version", versionString) + ).sorted + + // Spark properties + // This includes the scheduling mode whether or not it is configured (used by SparkUI) + val schedulerMode = + if (!conf.contains(SCHEDULER_MODE)) { + Seq((SCHEDULER_MODE.key, schedulingMode)) + } else { + Seq.empty[(String, String)] + } + val sparkProperties = (conf.getAll ++ schedulerMode).sorted + + // System properties that are not java classpaths + val systemProperties = Utils.getSystemProperties.toSeq + val otherProperties = systemProperties.filter { case (k, _) => + k != "java.class.path" && !k.startsWith("spark.") + }.sorted + + // Class paths including all added jars and files + val classPathEntries = javaClassPath + .split(File.pathSeparator) + .filterNot(_.isEmpty) + .map((_, "System Classpath")) + val addedJarsAndFiles = (addedJars ++ addedFiles ++ addedArchives).map((_, "Added By User")) + val classPaths = (addedJarsAndFiles ++ classPathEntries).sorted + + // Add Hadoop properties, it will not ignore configs including in Spark. Some spark + // conf starting with "spark.hadoop" may overwrite it. + val hadoopProperties = hadoopConf.asScala + .map(entry => (entry.getKey, entry.getValue)).toSeq.sorted + Map[String, Seq[(String, String)]]( + "JVM Information" -> jvmInformation, + "Spark Properties" -> sparkProperties, + "Hadoop Properties" -> hadoopProperties, + "System Properties" -> otherProperties, + "Classpath Entries" -> classPaths, + "Metrics Properties" -> metricsProperties.toSeq.sorted) + } +} + diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala new file mode 100644 index 0000000000..026518272c --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala @@ -0,0 +1,811 @@ +package org.apache.spark.api.python + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.spark._ +import org.apache.spark.SedonaSparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Python._ +import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} +import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} +import org.apache.spark.security.SocketAuthHelper +import org.apache.spark.util._ + +import java.io._ +import java.net._ +import java.nio.charset.StandardCharsets +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.{Path, Files => JavaFiles} +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + + +/** + * Enumerate the type of command that will be sent to the Python worker + */ +private[spark] object PythonEvalType { + val NON_UDF = 0 + + val SQL_BATCHED_UDF = 100 + val SQL_ARROW_BATCHED_UDF = 101 + + val SQL_SCALAR_PANDAS_UDF = 200 + val SQL_GROUPED_MAP_PANDAS_UDF = 201 + val SQL_GROUPED_AGG_PANDAS_UDF = 202 + val SQL_WINDOW_AGG_PANDAS_UDF = 203 + val SQL_SCALAR_PANDAS_ITER_UDF = 204 + val SQL_MAP_PANDAS_ITER_UDF = 205 + val SQL_COGROUPED_MAP_PANDAS_UDF = 206 + val SQL_MAP_ARROW_ITER_UDF = 207 + val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208 + + val SQL_TABLE_UDF = 300 + val SQL_ARROW_TABLE_UDF = 301 + + def toString(pythonEvalType: Int): String = pythonEvalType match { + case NON_UDF => "NON_UDF" + case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" + case SQL_ARROW_BATCHED_UDF => "SQL_ARROW_BATCHED_UDF" + case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF" + case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" + case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" + case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF" + case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF" + case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF" + case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" + case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF" + case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE" + case SQL_TABLE_UDF => "SQL_TABLE_UDF" + case SQL_ARROW_TABLE_UDF => "SQL_ARROW_TABLE_UDF" + } +} + +private object SedonaBasePythonRunner { + + private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler") + + private def faultHandlerLogPath(pid: Int): Path = { + new File(faultHandlerLogDir, pid.toString).toPath + } +} + +/** + * A helper class to run Python mapPartition/UDFs in Spark. + * + * funcs is a list of independent Python functions, each one of them is a list of chained Python + * functions (from bottom to top). 
+ */ +private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( + protected val funcs: Seq[ChainedPythonFunctions], + protected val evalType: Int, + protected val argOffsets: Array[Array[Int]], + protected val jobArtifactUUID: Option[String]) + extends Logging { + + require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + + private val conf = SparkEnv.get.conf + protected val bufferSize: Int = conf.get(BUFFER_SIZE) + protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) + private val reuseWorker = conf.get(PYTHON_WORKER_REUSE) + private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED) + protected val simplifiedTraceback: Boolean = false + + // All the Python functions should have the same exec, version and envvars. + protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars + protected val pythonExec: String = funcs.head.funcs.head.pythonExec + protected val pythonVer: String = funcs.head.funcs.head.pythonVer + + // TODO: support accumulator in multiple UDF + protected val accumulator: PythonAccumulatorV2 = funcs.head.funcs.head.accumulator + + // Python accumulator is always set in production except in tests. See SPARK-27893 + private val maybeAccumulator: Option[PythonAccumulatorV2] = Option(accumulator) + + // Expose a ServerSocket to support method calls via socket from Python side. + private[spark] var serverSocket: Option[ServerSocket] = None + + // Authentication helper used when serving method calls via socket from Python side. + private lazy val authHelper = new SocketAuthHelper(conf) + + // each python worker gets an equal part of the allocation. the worker pool will grow to the + // number of concurrent tasks, which is determined by the number of cores in this executor. + private def getWorkerMemoryMb(mem: Option[Long], cores: Int): Option[Long] = { + mem.map(_ / cores) + } + + def compute( + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): Iterator[OUT] = { + val startTime = System.currentTimeMillis + val sedonaEnv = SedonaSparkEnv.get + val env = SparkEnv.get + + // Get the executor cores and pyspark memory, they are passed via the local properties when + // the user specified them in a ResourceProfile. + val execCoresProp = Option(context.getLocalProperty(EXECUTOR_CORES_LOCAL_PROPERTY)) + val memoryMb = Option(context.getLocalProperty(PYSPARK_MEMORY_LOCAL_PROPERTY)).map(_.toLong) + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") + // If OMP_NUM_THREADS is not explicitly set, override it with the number of task cpus. + // See SPARK-42613 for details. + if (conf.getOption("spark.executorEnv.OMP_NUM_THREADS").isEmpty) { + envVars.put("OMP_NUM_THREADS", conf.get("spark.task.cpus", "1")) + } + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread + if (reuseWorker) { + envVars.put("SPARK_REUSE_WORKER", "1") + } + if (simplifiedTraceback) { + envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1") + } + // SPARK-30299 this could be wrong with standalone mode when executor + // cores might not be correct because it defaults to all cores on the box. 
+ val execCores = execCoresProp.map(_.toInt).getOrElse(conf.get(EXECUTOR_CORES)) + val workerMemoryMb = getWorkerMemoryMb(memoryMb, execCores) + if (workerMemoryMb.isDefined) { + envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString) + } + envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) + envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) + if (faultHandlerEnabled) { + envVars.put("PYTHON_FAULTHANDLER_DIR", SedonaBasePythonRunner.faultHandlerLogDir.toString) + } + + envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) + + val (worker: Socket, pid: Option[Int]) = env.createPythonWorker( + pythonExec, envVars.asScala.toMap) + // Whether is the worker released into idle pool or closed. When any codes try to release or + // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make + // sure there is only one winner that is going to release or close the worker. + val releasedOrClosed = new AtomicBoolean(false) + + // Start a thread to feed the process input from our parent's iterator + val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) + + context.addTaskCompletionListener[Unit] { _ => + writerThread.shutdownOnTaskCompletion() + if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + + writerThread.start() + new WriterMonitorThread(SparkEnv.get, worker, writerThread, context).start() + if (reuseWorker) { + val key = (worker, context.taskAttemptId) + // SPARK-35009: avoid creating multiple monitor threads for the same python worker + // and task context + if (PythonRunner.runningMonitorThreads.add(key)) { + new MonitorThread(SparkEnv.get, worker, context).start() + } + } else { + new MonitorThread(SparkEnv.get, worker, context).start() + } + + // Return an iterator that read lines from the process's stdout + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + + val stdoutIterator = newReaderIterator( + stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) + new InterruptibleIterator(context, stdoutIterator) + } + + protected def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): WriterThread + + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[OUT] + + /** + * The thread responsible for writing the data from the PythonRDD's parent iterator to the + * Python process. + */ + abstract class WriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext) + extends Thread(s"stdout writer for $pythonExec") { + + @volatile private var _exception: Throwable = null + + private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet + private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) + + setDaemon(true) + + /** Contains the throwable thrown while writing the parent iterator to the Python process. */ + def exception: Option[Throwable] = Option(_exception) + + /** + * Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur + * due to cleanup. 
+ */ + def shutdownOnTaskCompletion(): Unit = { + assert(context.isCompleted) + this.interrupt() + // Task completion listeners that run after this method returns may invalidate + // `inputIterator`. For example, when `inputIterator` was generated by the off-heap vectorized + // reader, a task completion listener will free the underlying off-heap buffers. If the writer + // thread is still running when `inputIterator` is invalidated, it can cause a use-after-free + // bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer + // thread to exit before returning. + this.join() + } + + /** + * Writes a command section to the stream connected to the Python worker. + */ + protected def writeCommand(dataOut: DataOutputStream): Unit + + /** + * Writes input data to the stream connected to the Python worker. + */ + protected def writeIteratorToStream(dataOut: DataOutputStream): Unit + + override def run(): Unit = Utils.logUncaughtExceptions { + try { + TaskContext.setTaskContext(context) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + // Partition index + dataOut.writeInt(partitionIndex) + // Python version of driver + PythonRDD.writeUTF(pythonVer, dataOut) + // Init a ServerSocket to accept method calls from Python side. + val isBarrier = context.isInstanceOf[BarrierTaskContext] + if (isBarrier) { + serverSocket = Some(new ServerSocket(/* port */ 0, + /* backlog */ 1, + InetAddress.getByName("localhost"))) + // A call to accept() for ServerSocket shall block infinitely. + serverSocket.foreach(_.setSoTimeout(0)) + new Thread("accept-connections") { + setDaemon(true) + + override def run(): Unit = { + while (!serverSocket.get.isClosed()) { + var sock: Socket = null + try { + sock = serverSocket.get.accept() + // Wait for function call from python side. + sock.setSoTimeout(10000) + authHelper.authClient(sock) + val input = new DataInputStream(sock.getInputStream()) + val requestMethod = input.readInt() + // The BarrierTaskContext function may wait infinitely, socket shall not timeout + // before the function finishes. + sock.setSoTimeout(0) + requestMethod match { + case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => + barrierAndServe(requestMethod, sock) + case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => + val length = input.readInt() + val message = new Array[Byte](length) + input.readFully(message) + barrierAndServe(requestMethod, sock, new String(message, UTF_8)) + case _ => + val out = new DataOutputStream(new BufferedOutputStream( + sock.getOutputStream)) + writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out) + } + } catch { + case e: SocketException if e.getMessage.contains("Socket closed") => + // It is possible that the ServerSocket is not closed, but the native socket + // has already been closed, we shall catch and silently ignore this case. + } finally { + if (sock != null) { + sock.close() + } + } + } + } + }.start() + } + val secret = if (isBarrier) { + authHelper.secret + } else { + "" + } + // Close ServerSocket on task completion. + serverSocket.foreach { server => + context.addTaskCompletionListener[Unit](_ => server.close()) + } + val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) + if (boundPort == -1) { + val message = "ServerSocket failed to bind to Java side." 
+ logError(message) + throw new SparkException(message) + } else if (isBarrier) { + logDebug(s"Started ServerSocket on port $boundPort.") + } + // Write out the TaskContextInfo + dataOut.writeBoolean(isBarrier) + dataOut.writeInt(boundPort) + val secretBytes = secret.getBytes(UTF_8) + dataOut.writeInt(secretBytes.length) + dataOut.write(secretBytes, 0, secretBytes.length) + dataOut.writeInt(context.stageId()) + dataOut.writeInt(context.partitionId()) + dataOut.writeInt(context.attemptNumber()) + dataOut.writeLong(context.taskAttemptId()) + dataOut.writeInt(context.cpus()) + val resources = context.resources() + dataOut.writeInt(resources.size) + resources.foreach { case (k, v) => + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v.name, dataOut) + dataOut.writeInt(v.addresses.size) + v.addresses.foreach { case addr => + PythonRDD.writeUTF(addr, dataOut) + } + } + val localProps = context.getLocalProperties.asScala + dataOut.writeInt(localProps.size) + localProps.foreach { case (k, v) => + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + + // sparkFilesDir + val root = jobArtifactUUID.map { uuid => + new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath + }.getOrElse(SparkFiles.getRootDirectory()) + PythonRDD.writeUTF(root, dataOut) + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.size) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } + // Broadcast variables + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val toRemove = oldBids.diff(newBids) + val addedBids = newBids.diff(oldBids) + val cnt = toRemove.size + addedBids.size + val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty + dataOut.writeBoolean(needsDecryptionServer) + dataOut.writeInt(cnt) + def sendBidsToRemove(): Unit = { + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(-bid - 1) // bid >= 0 + oldBids.remove(bid) + } + } + if (needsDecryptionServer) { + // if there is encryption, we setup a server which reads the encrypted files, and sends + // the decrypted data to python + val idsAndFiles = broadcastVars.flatMap { broadcast => + if (!oldBids.contains(broadcast.id)) { + oldBids.add(broadcast.id) + Some((broadcast.id, broadcast.value.path)) + } else { + None + } + } + val server = new EncryptedPythonBroadcastServer(env, idsAndFiles) + dataOut.writeInt(server.port) + logTrace(s"broadcast decryption server setup on ${server.port}") + PythonRDD.writeUTF(server.secret, dataOut) + sendBidsToRemove() + idsAndFiles.foreach { case (id, _) => + // send new broadcast + dataOut.writeLong(id) + } + dataOut.flush() + logTrace("waiting for python to read decrypted broadcast data from server") + server.waitTillBroadcastDataSent() + logTrace("done sending decrypted data to python") + } else { + sendBidsToRemove() + for (broadcast <- broadcastVars) { + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + PythonRDD.writeUTF(broadcast.value.path, dataOut) + oldBids.add(broadcast.id) + } + } + } + dataOut.flush() + + dataOut.writeInt(evalType) + writeCommand(dataOut) + writeIteratorToStream(dataOut) + + dataOut.writeInt(SpecialLengths.END_OF_STREAM) + dataOut.flush() + } catch { + case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) => + if (context.isCompleted || context.isInterrupted) { + logDebug("Exception/NonFatal Error thrown after task 
completion (likely due to " + + "cleanup)", t) + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + } else { + // We must avoid throwing exceptions/NonFatals here, because the thread uncaught + // exception handler will kill the whole executor (see + // org.apache.spark.executor.Executor). + _exception = t + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + } + } + } + + /** + * Gateway to call BarrierTaskContext methods. + */ + def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = { + require( + serverSocket.isDefined, + "No available ServerSocket to redirect the BarrierTaskContext method call." + ) + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + try { + val messages = requestMethod match { + case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => + context.asInstanceOf[BarrierTaskContext].barrier() + Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS) + case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => + context.asInstanceOf[BarrierTaskContext].allGather(message) + } + out.writeInt(messages.length) + messages.foreach(writeUTF(_, out)) + } catch { + case e: SparkException => + writeUTF(e.getMessage, out) + } finally { + out.close() + } + } + + def writeUTF(str: String, dataOut: DataOutputStream): Unit = { + val bytes = str.getBytes(UTF_8) + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + } + + abstract class ReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext) + extends Iterator[OUT] { + + private var nextObj: OUT = _ + private var eos = false + + override def hasNext: Boolean = nextObj != null || { + if (!eos) { + nextObj = read() + hasNext + } else { + false + } + } + + override def next(): OUT = { + if (hasNext) { + val obj = nextObj + nextObj = null.asInstanceOf[OUT] + obj + } else { + Iterator.empty.next() + } + } + + /** + * Reads next object from the stream. + * When the stream reaches end of data, needs to process the following sections, + * and then returns null. 
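+   * The "following sections" are the negative control codes defined in [[SpecialLengths]]:
+   * TIMING_DATA, PYTHON_EXCEPTION_THROWN and END_OF_DATA_SECTION, with END_OF_STREAM read
+   * once the accumulator updates have been drained.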
+ */ + protected def read(): OUT + + protected def handleTimingData(): Unit = { + // Timing data from worker + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, + init, finish)) + val memoryBytesSpilled = stream.readLong() + val diskBytesSpilled = stream.readLong() + context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + } + + protected def handlePythonException(): PythonException = { + // Signals that an exception has been thrown in python + val exLength = stream.readInt() + val obj = new Array[Byte](exLength) + stream.readFully(obj) + new PythonException(new String(obj, StandardCharsets.UTF_8), + writerThread.exception.orNull) + } + + protected def handleEndOfDataSection(): Unit = { + // We've finished the data section of the output, but we can still + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) + stream.readFully(update) + maybeAccumulator.foreach(_.add(update)) + } + // Check whether the worker is ready to be re-used. + if (stream.readInt() == SpecialLengths.END_OF_STREAM) { + if (reuseWorker && releasedOrClosed.compareAndSet(false, true)) { + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) + } + } + eos = true + } + + protected val handleException: PartialFunction[Throwable, OUT] = { + case e: Exception if context.isInterrupted => + logDebug("Exception thrown after task interruption", e) + throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) + + case e: Exception if writerThread.exception.isDefined => + logError("Python worker exited unexpectedly (crashed)", e) + logError("This may have been caused by a prior exception:", writerThread.exception.get) + throw writerThread.exception.get + + case eof: EOFException if faultHandlerEnabled && pid.isDefined && + JavaFiles.exists(SedonaBasePythonRunner.faultHandlerLogPath(pid.get)) => + val path = SedonaBasePythonRunner.faultHandlerLogPath(pid.get) + val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n" + JavaFiles.deleteIfExists(path) + throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", eof) + + case eof: EOFException => + throw new SparkException("Python worker exited unexpectedly (crashed)", eof) + } + } + + /** + * It is necessary to have a monitor thread for python workers if the user cancels with + * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the + * threads can block indefinitely. + */ + class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) + extends Thread(s"Worker Monitor for $pythonExec") { + + /** How long to wait before killing the python worker if a task cannot be interrupted. */ + private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT) + + setDaemon(true) + + private def monitorWorker(): Unit = { + // Kill the worker if it is interrupted, checking until task completion. + // TODO: This has a race condition if interruption occurs, as completed may still become true. 
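+      // Poll roughly every two seconds until the task is interrupted or completed. After an
+      // interruption, the task still gets PYTHON_TASK_KILL_TIMEOUT to finish before the
+      // Python worker is destroyed.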
+ while (!context.isInterrupted && !context.isCompleted) { + Thread.sleep(2000) + } + if (!context.isCompleted) { + Thread.sleep(taskKillTimeout) + if (!context.isCompleted) { + try { + // Mimic the task name used in `Executor` to help the user find out the task to blame. + val taskName = s"${context.partitionId}.${context.attemptNumber} " + + s"in stage ${context.stageId} (TID ${context.taskAttemptId})" + logWarning(s"Incomplete task $taskName interrupted: Attempting to kill Python Worker") + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) + } catch { + case e: Exception => + logError("Exception when trying to kill worker", e) + } + } + } + } + + override def run(): Unit = { + try { + monitorWorker() + } finally { + if (reuseWorker) { + val key = (worker, context.taskAttemptId) + PythonRunner.runningMonitorThreads.remove(key) + } + } + } + } + + /** + * This thread monitors the WriterThread and kills it in case of deadlock. + * + * A deadlock can arise if the task completes while the writer thread is sending input to the + * Python process (e.g. due to the use of `take()`), and the Python process is still producing + * output. When the inputs are sufficiently large, this can result in a deadlock due to the use of + * blocking I/O (SPARK-38677). To resolve the deadlock, we need to close the socket. + */ + class WriterMonitorThread( + env: SparkEnv, worker: Socket, writerThread: WriterThread, context: TaskContext) + extends Thread(s"Writer Monitor for $pythonExec (writer thread id ${writerThread.getId})") { + + /** + * How long to wait before closing the socket if the writer thread has not exited after the task + * ends. + */ + private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT) + + setDaemon(true) + + override def run(): Unit = { + // Wait until the task is completed (or the writer thread exits, in which case this thread has + // nothing to do). + while (!context.isCompleted && writerThread.isAlive) { + Thread.sleep(2000) + } + if (writerThread.isAlive) { + Thread.sleep(taskKillTimeout) + // If the writer thread continues running, this indicates a deadlock. Kill the worker to + // resolve the deadlock. + if (writerThread.isAlive) { + try { + // Mimic the task name used in `Executor` to help the user find out the task to blame. + val taskName = s"${context.partitionId}.${context.attemptNumber} " + + s"in stage ${context.stageId} (TID ${context.taskAttemptId})" + logWarning( + s"Detected deadlock while completing task $taskName: " + + "Attempting to kill Python Worker") + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) + } catch { + case e: Exception => + logError("Exception when trying to kill worker", e) + } + } + } + } + } +} + +private[spark] object PythonRunner { + + // already running worker monitor threads for worker and task attempts ID pairs + val runningMonitorThreads = ConcurrentHashMap.newKeySet[(Socket, Long)]() + + private var printPythonInfo: AtomicBoolean = new AtomicBoolean(true) + + def apply(func: PythonFunction, jobArtifactUUID: Option[String]): PythonRunner = { + if (printPythonInfo.compareAndSet(true, false)) { + PythonUtils.logPythonInfo(func.pythonExec) + } + new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), jobArtifactUUID) + } +} + +/** + * A helper class to run Python mapPartition in Spark. 
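+ *
+ * A minimal sketch of how such a runner is typically driven (the `pythonFunc`, `inputIter`
+ * and `context` names below are illustrative, not defined in this file):
+ * {{{
+ *   val runner = PythonRunner(pythonFunc, jobArtifactUUID = None)
+ *   val output: Iterator[Array[Byte]] = runner.compute(inputIter, context.partitionId(), context)
+ * }}}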
+ */ +private[spark] class PythonRunner( + funcs: Seq[ChainedPythonFunctions], jobArtifactUUID: Option[String]) + extends BasePythonRunner[Array[Byte], Array[Byte]]( + funcs, PythonEvalType.NON_UDF, Array(Array(0)), jobArtifactUUID) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[Array[Byte]], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + val command = funcs.head.funcs.head.command + dataOut.writeInt(command.length) + dataOut.write(command.toArray) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + PythonRDD.writeIteratorToStream(inputIterator, dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + } + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[Array[Byte]] = { + new ReaderIterator( + stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + + protected override def read(): Array[Byte] = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + case 0 => Array.emptyByteArray + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } catch handleException + } + } + } +} + +private[spark] object SpecialLengths { + val END_OF_DATA_SECTION = -1 + val PYTHON_EXCEPTION_THROWN = -2 + val TIMING_DATA = -3 + val END_OF_STREAM = -4 + val NULL = -5 + val START_ARROW_STREAM = -6 + val END_OF_MICRO_BATCH = -7 +} + +private[spark] object BarrierTaskContextMessageProtocol { + val BARRIER_FUNCTION = 1 + val ALL_GATHER_FUNCTION = 2 + val BARRIER_RESULT_SUCCESS = "success" + val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side." +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala new file mode 100644 index 0000000000..27e4b851ee --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala @@ -0,0 +1,70 @@ +package org.apache.spark.sql.execution.python + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.spark.api.python._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. + */ +class SedonaArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + protected override val schema: StructType, + protected override val timeZoneId: String, + protected override val largeVarTypes: Boolean, + protected override val workerConf: Map[String, String], + val pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends SedonaBasePythonRunner[Iterator[InternalRow], ColumnarBatch]( + funcs, evalType, argOffsets, jobArtifactUUID) + with SedonaBasicPythonArrowInput + with SedonaBasicPythonArrowOutput { + + override val pythonExec: String = + SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( + funcs.head.funcs.head.pythonExec) + + override val errorOnDuplicatedFieldNames: Boolean = true + + override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + + override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize + require( + bufferSize >= 4, + "Pandas execution requires more than 4 bytes. Please set higher buffer. " + + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") +} + +object SedonaArrowPythonRunner { + /** Return Map with conf settings to be used in ArrowPythonRunner */ + def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = { + val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) + val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key -> + conf.pandasGroupedMapAssignColumnsByName.toString) + val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key -> + conf.arrowSafeTypeConversion.toString) + Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*) + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala new file mode 100644 index 0000000000..d2c390282c --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala @@ -0,0 +1,148 @@ +package org.apache.spark.sql.execution.python + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter +import org.apache.spark.api.python +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, SedonaBasePythonRunner} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.Utils +import org.apache.spark.{SparkEnv, TaskContext} + +import java.io.DataOutputStream +import java.net.Socket + +/** + * A trait that can be mixed-in with [[python.BasePythonRunner]]. It implements the logic from + * JVM (an iterator of internal rows + additional data if required) to Python (Arrow). + */ +private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[IN, _] => + protected val workerConf: Map[String, String] + + protected val schema: StructType + + protected val timeZoneId: String + + protected val errorOnDuplicatedFieldNames: Boolean + + protected val largeVarTypes: Boolean + + protected def pythonMetrics: Map[String, SQLMetric] + + protected def writeIteratorToArrowStream( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[IN]): Unit + + protected def writeUDF( + dataOut: DataOutputStream, + funcs: Seq[ChainedPythonFunctions], + argOffsets: Array[Array[Int]]): Unit = + SedonaPythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + + protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { + // Write config for the worker as a number of key -> value pairs of strings + stream.writeInt(workerConf.size) + for ((k, v) <- workerConf) { + PythonRDD.writeUTF(k, stream) + PythonRDD.writeUTF(v, stream) + } + } + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + handleMetadataBeforeExec(dataOut) + writeUDF(dataOut, funcs, argOffsets) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val arrowSchema = ArrowUtils.toArrowSchema( + schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for $pythonExec", 0, Long.MaxValue) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + + Utils.tryWithSafeFinally { + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + + writeIteratorToArrowStream(root, writer, dataOut, inputIterator) + + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. + writer.end() + } { + // If we close root and allocator in TaskCompletionListener, there could be a race + // condition where the writer thread keeps writing to the VectorSchemaRoot while + // it's being closed by the TaskCompletion listener. + // Closing root and allocator here is cleaner because root and allocator is owned + // by the writer thread and is only visible to the writer thread. 
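+          // (The reader side, by contrast, closes its allocator from a TaskCompletionListener;
+          // see SedonaPythonArrowOutput.)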
+ // + // If the writer thread is interrupted by TaskCompletionListener, it should either + // (1) in the try block, in which case it will get an InterruptedException when + // performing io, and goes into the finally block or (2) in the finally block, + // in which case it will ignore the interruption and close the resources. + root.close() + allocator.close() + } + } + } + } +} + + +private[python] trait SedonaBasicPythonArrowInput extends SedonaPythonArrowInput[Iterator[InternalRow]] { + self: SedonaBasePythonRunner[Iterator[InternalRow], _] => + + protected def writeIteratorToArrowStream( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[Iterator[InternalRow]]): Unit = { + val arrowWriter = ArrowWriter.create(root) + + while (inputIterator.hasNext) { + val startData = dataOut.size() + val nextBatch = inputIterator.next() + + while (nextBatch.hasNext) { + arrowWriter.write(nextBatch.next()) + } + + arrowWriter.finish() + writer.writeBatch() + arrowWriter.reset() + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData + } + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala new file mode 100644 index 0000000000..91e840da58 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala @@ -0,0 +1,135 @@ +package org.apache.spark.sql.execution.python + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamReader +import org.apache.spark.api.python.{BasePythonRunner, SedonaBasePythonRunner, SpecialLengths} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} +import org.apache.spark.{SparkEnv, TaskContext} + +import java.io.DataInputStream +import java.net.Socket +import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.JavaConverters._ + +/** + * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from + * Python (Arrow) to JVM (output type being deserialized from ColumnarBatch). 
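+ *
+ * The reader first waits for a START_ARROW_STREAM marker, then loads Arrow record batches
+ * from the worker and exposes each batch as a [[ColumnarBatch]] backed by
+ * [[ArrowColumnVector]]s, updating the pythonDataReceived and pythonNumRowsReceived metrics.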
+ */ +private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: SedonaBasePythonRunner[_, OUT] => + + protected def pythonMetrics: Map[String, SQLMetric] + + protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { } + + protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT + + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[OUT] = { + + new ReaderIterator( + stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ + + context.addTaskCompletionListener[Unit] { _ => + if (reader != null) { + reader.close(false) + } + allocator.close() + } + + private var batchLoaded = true + + protected override def handleEndOfDataSection(): Unit = { + handleMetadataAfterExec(stream) + super.handleEndOfDataSection() + } + + protected override def read(): OUT = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && batchLoaded) { + val bytesReadStart = reader.bytesRead() + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(vectors) + val rowCount = root.getRowCount + batch.setNumRows(root.getRowCount) + val bytesReadEnd = reader.bytesRead() + pythonMetrics("pythonNumRowsReceived") += rowCount + pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart + deserializeColumnarBatch(batch, schema) + } else { + reader.close(false) + allocator.close() + // Reach end of stream. Call `read()` again to read control data. + read() + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + reader = new ArrowStreamReader(stream, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null.asInstanceOf[OUT] + } + } + } catch handleException + } + } + } +} + +private[python] trait SedonaBasicPythonArrowOutput extends SedonaPythonArrowOutput[ColumnarBatch] { + self: SedonaBasePythonRunner[_, ColumnarBatch] => + + protected def deserializeColumnarBatch( + batch: ColumnarBatch, + schema: StructType): ColumnarBatch = batch +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala new file mode 100644 index 0000000000..ced32cf801 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala @@ -0,0 +1,147 @@ +package org.apache.spark.sql.execution.python + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io._ +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean +import org.apache.spark._ +import org.apache.spark.api.python._ +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf + +/** + * A helper class to run Python UDFs in Spark. + */ +abstract class SedonaBasePythonUDFRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends SedonaBasePythonRunner[Array[Byte], Array[Byte]]( + funcs, evalType, argOffsets, jobArtifactUUID) { + + override val pythonExec: String = + SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( + funcs.head.funcs.head.pythonExec) + + override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + + abstract class SedonaPythonUDFWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[Array[Byte]], + partitionIndex: Int, + context: TaskContext) + extends WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val startData = dataOut.size() + + PythonRDD.writeIteratorToStream(inputIterator, dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[Array[Byte]] = { + new ReaderIterator( + stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + + protected override def read(): Array[Byte] = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + pythonMetrics("pythonDataReceived") += length + obj + case 0 => Array.emptyByteArray + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } catch handleException + } + } + } +} + +class SedonaPythonUDFRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends SedonaBasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, jobArtifactUUID) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[Array[Byte]], + partitionIndex: Int, 
+ context: TaskContext): WriterThread = { + new SedonaPythonUDFWriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + SedonaPythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + } + } +} + +object SedonaPythonUDFRunner { + + def writeUDFs( + dataOut: DataOutputStream, + funcs: Seq[ChainedPythonFunctions], + argOffsets: Array[Array[Int]]): Unit = { + dataOut.writeInt(funcs.length) + funcs.zip(argOffsets).foreach { case (chained, offsets) => + dataOut.writeInt(offsets.length) + offsets.foreach { offset => + dataOut.writeInt(offset) + } + dataOut.writeInt(chained.funcs.length) + chained.funcs.foreach { f => + dataOut.writeInt(f.command.length) + dataOut.write(f.command.toArray) + } + } + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala index a403fa6b9e..5883fd905d 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator, EvalPythonExec, PythonSQLMetrics} +import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator, EvalPythonExec, PythonSQLMetrics, SedonaArrowPythonRunner} import org.apache.spark.sql.types.StructType import scala.collection.JavaConverters.asScalaIteratorConverter @@ -68,7 +68,7 @@ case class SedonaArrowEvalPythonExec( val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) - val columnarBatchIter = new ArrowPythonRunner( + val columnarBatchIter = new SedonaArrowPythonRunner( funcs, evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT, argOffsets, diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala index c0a2d8f260..80a4c64106 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.udf import org.apache.sedona.sql.UDF -import org.apache.spark.TestUtils +import org.apache.spark.{SparkEnv, TestUtils} import org.apache.spark.api.python._ import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.config.Python.{PYTHON_DAEMON_MODULE, PYTHON_USE_DAEMON, PYTHON_WORKER_MODULE} import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT import org.apache.spark.util.Utils @@ -70,10 +71,11 @@ object ScalarUDF { finally Utils.deleteRecursively(path) } + val additionalModule = "spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf" + val pandasFunc: Array[Byte] = { var binaryPandasFunc: Array[Byte] = null withTempPath { path => - println(path) Process( Seq( pythonExec, @@ -85,6 +87,17 @@ object ScalarUDF { |from pyspark.serializers import CloudPickleSerializer |from sedona.utils import geometry_serde |from shapely import box + |import logging + 
|logging.basicConfig(level=logging.INFO) + |logger = logging.getLogger(__name__) + |logger.info("Loading Sedona Python UDF") + |import os + |logger.info(os.getcwd()) + |import sys + |sys.path.append('$additionalModule') + |logger.info(sys.path) |f = open('$path', 'wb'); |def w(x): | def apply_function(w): @@ -104,7 +117,9 @@ } private val workerEnv = new java.util.HashMap[String, String]() - workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") + workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") + SparkEnv.get.conf.set(PYTHON_WORKER_MODULE, "sedonaworker.worker") + SparkEnv.get.conf.set(PYTHON_USE_DAEMON, false) val geoPandasScalaFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( name = "geospatial_udf",
