GitHub user BryanCutler commented on a diff in the pull request:
https://github.com/apache/spark/pull/22610#discussion_r222501309
--- Diff: python/pyspark/worker.py ---
@@ -84,13 +84,36 @@ def wrap_scalar_pandas_udf(f, return_type):
     arrow_return_type = to_arrow_type(return_type)

     def verify_result_length(*a):
+        import pyarrow as pa
         result = f(*a)
         if not hasattr(result, "__len__"):
             raise TypeError("Return type of the user-defined function should be "
                             "Pandas.Series, but is {}".format(type(result)))
         if len(result) != len(a[0]):
             raise RuntimeError("Result vector from pandas_udf was not the required length: "
                                "expected %d, got %d" % (len(a[0]), len(result)))
+
+        # Ensure return type of Pandas.Series matches the arrow return type of the user-defined
+        # function. Otherwise, we may produce incorrect serialized data.
+        # Note: for timestamp type, we only need to ensure both types are timestamp because the
+        # serializer will do conversion.
+        try:
+            arrow_type_of_result = pa.from_numpy_dtype(result.dtype)
+            both_are_timestamp = pa.types.is_timestamp(arrow_type_of_result) and \
+                pa.types.is_timestamp(arrow_return_type)
+            if not both_are_timestamp and arrow_return_type != arrow_type_of_result:
+                print("WARN: Arrow type %s of return Pandas.Series of the user-defined function's "
--- End diff --
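
For context, here is a minimal standalone sketch of the type check being added above, assuming only pandas and pyarrow are installed; check_result_type is an illustrative name, not part of worker.py:

import pandas as pd
import pyarrow as pa

def check_result_type(result, arrow_return_type):
    # Map the Series' numpy dtype to its Arrow equivalent.
    arrow_type_of_result = pa.from_numpy_dtype(result.dtype)
    # Timestamps only need to match in kind; the serializer handles conversion.
    both_are_timestamp = (pa.types.is_timestamp(arrow_type_of_result) and
                          pa.types.is_timestamp(arrow_return_type))
    if not both_are_timestamp and arrow_return_type != arrow_type_of_result:
        print("WARN: Arrow type %s of the result does not match the declared "
              "return type %s" % (arrow_type_of_result, arrow_return_type))

# A UDF declared to return double but actually producing int64 would warn:
check_result_type(pd.Series([1, 2, 3]), pa.float64())
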
Yeah, it might be useful to see the warning when doing some local tests, etc.
My only concern is that users might be confused about why they see a warning
locally but it doesn't appear in the logs. Man, it would be nice to have some
proper Python logging for this!
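
For what it's worth, a rough sketch of what that could look like with the standard logging module; the logger name is a placeholder, not what worker.py actually uses:

import logging

logger = logging.getLogger("pyspark.worker")  # placeholder name

def warn_type_mismatch(arrow_type_of_result, arrow_return_type):
    # Goes through the logging framework, so the worker's log configuration
    # decides whether the warning shows up locally, in executor logs, or both.
    logger.warning(
        "Arrow type %s of the returned Pandas.Series does not match the "
        "declared return type %s of the pandas_udf",
        arrow_type_of_result, arrow_return_type)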