Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/22610#discussion_r222173904 --- Diff: python/pyspark/worker.py --- @@ -84,13 +84,36 @@ def wrap_scalar_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) def verify_result_length(*a): + import pyarrow as pa result = f(*a) if not hasattr(result, "__len__"): raise TypeError("Return type of the user-defined function should be " "Pandas.Series, but is {}".format(type(result))) if len(result) != len(a[0]): raise RuntimeError("Result vector from pandas_udf was not the required length: " "expected %d, got %d" % (len(a[0]), len(result))) + + # Ensure return type of Pandas.Series matches the arrow return type of the user-defined + # function. Otherwise, we may produce incorrect serialized data. + # Note: for timestamp type, we only need to ensure both types are timestamp because the + # serializer will do conversion. + try: + arrow_type_of_result = pa.from_numpy_dtype(result.dtype) + both_are_timestamp = pa.types.is_timestamp(arrow_type_of_result) and \ + pa.types.is_timestamp(arrow_return_type) + if not both_are_timestamp and arrow_return_type != arrow_type_of_result: + print("WARN: Arrow type %s of return Pandas.Series of the user-defined function's " + "dtype %s doesn't match the arrow type %s " + "of defined return type %s" % (arrow_type_of_result, result.dtype, + arrow_return_type, return_type), + file=sys.stderr) + except: + print("WARN: Can't infer arrow type of Pandas.Series's dtype: %s, which might not " + "match the arrow type %s of defined return type %s" % (result.dtype, + arrow_return_type, + return_type), --- End diff -- ok. thanks. :-)
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org