dtenedor commented on code in PR #42420:
URL: https://github.com/apache/spark/pull/42420#discussion_r1293778050
##########
python/pyspark/worker.py:
##########
@@ -621,22 +663,30 @@ def verify_result(result):
return lambda *a: map(lambda res: (res, arrow_return_type),
map(verify_result, f(*a)))
- eval = wrap_arrow_udtf(getattr(udtf, "eval"), return_type)
-
- if hasattr(udtf, "terminate"):
- terminate = wrap_arrow_udtf(getattr(udtf, "terminate"),
return_type)
- else:
- terminate = None
+ udtf_state.eval = wrap_arrow_udtf(getattr(udtf_state.udtf, "eval"),
return_type)
+ udtf_state.set_terminate(wrap_arrow_udtf, return_type)
def mapper(_, it):
try:
for a in it:
# The eval function yields an iterator. Each element
produced by this
# iterator is a tuple in the form of (pandas.DataFrame,
arrow_return_type).
- yield from eval(*[a[o] for o in arg_offsets])
+ arguments = [a[o] for o in arg_offsets]
+ changed_partitions = check_partition_boundaries(arguments)
+ if changed_partitions:
+ # Call 'terminate' on the UDTF class instance, if
applicable.
+ # Then destroy the UDTF class instance and create a
new one.
+ if udtf_state.terminate is not None:
+ yield from udtf_state.terminate()
+ create_udtf_classs_instance(return_type)
Review Comment:
Good catch, fixed!
##########
python/pyspark/worker.py:
##########
@@ -564,23 +566,66 @@ def read_udtf(pickleSer, infile, eval_type):
f"The return type of a UDTF must be a struct type, but got
{type(return_type)}."
)
+ # This class holds the UDTF class instance and associated state as we
evaluate rows.
+ # We keep this state in a class to simplify updating it as needed within
nested
+ # function calls, such as when we destroy the UDTF class instance when
partition
+ # boundaries change and then create a new one.
+ class UdtfState:
+ def __init__(self):
+ self.udtf = None
+ self.prev_arguments = None
+ self.eval = None
+ self.terminate = None
+
+ def set_terminate(self, wrap_udtf, return_type):
+ if hasattr(self.udtf, "terminate"):
+ self.terminate = wrap_udtf(getattr(self.udtf, "terminate"),
return_type)
+ else:
+ self.terminate = None
+
+ udtf_state = UdtfState()
+
# Instantiate the UDTF class.
- try:
- udtf = handler()
- except Exception as e:
- raise PySparkRuntimeError(
- error_class="UDTF_EXEC_ERROR",
- message_parameters={"method_name": "__init__", "error": str(e)},
- )
+ def create_udtf_class_instance(return_type):
+ try:
+ if udtf_state.udtf is not None:
+ del udtf_state.udtf
+ udtf_state.udtf = handler()
+ udtf_state.prev_arguments = None
+ except Exception as e:
+ raise PySparkRuntimeError(
+ error_class="UDTF_EXEC_ERROR",
+ message_parameters={"method_name": "__init__", "error":
str(e)},
+ )
+
+ create_udtf_class_instance(return_type)
# Validate the UDTF
- if not hasattr(udtf, "eval"):
+ if not hasattr(udtf_state.udtf, "eval"):
raise PySparkRuntimeError(
"Failed to execute the user defined table function because it has
not "
"implemented the 'eval' method. Please add the 'eval' method and
try "
"the query again."
)
+ # Inspects the values of the projected PARTITION BY expressions, if any.
Returns true when
+ # these values change, in which case the caller should invoke the
'terminate' method on the
+ # UDTF class instance and then destroy it and creates a new one to
implement the desired
+ # partitioning semantics.
+ def check_partition_boundaries(arguments):
+ if num_partition_child_indexes == 0 or udtf_state.prev_arguments is
None:
+ udtf_state.prev_arguments = arguments
+ return False
+ cur_table_arg = [arg for arg in arguments if type(arg) is Row][0]
+ prev_table_arg = [arg for arg in udtf_state.prev_arguments if
type(arg) is Row][0]
+ cur_partition_values = [cur_table_arg[i] for i in
partition_child_indexes]
+ prev_partition_values = [prev_table_arg[i] for i in
partition_child_indexes]
+ udtf_state.prev_arguments = arguments
Review Comment:
Reviewing the code again, you are right. We should set
`udtf_state.prev_arguments = None` manually, just once at the start, rather than
resetting it in `create_udtf_class_instance`. Resetting it there means the row
immediately after a partition boundary is never compared against the previous
row, so the boundary following a single-row partition goes undetected. I made
this change.
I tried to think of a unit test that would have caught this bug, but
`test_udtf_with_table_argument_and_partition_by_and_order_by` should already
have caught it; I am not sure why it didn't. I am leaving that test as-is
for now.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]