Github user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/18732#discussion_r142498841
--- Diff: python/pyspark/worker.py ---
@@ -74,17 +75,33 @@ def wrap_udf(f, return_type):
def wrap_pandas_udf(f, return_type):
-    arrow_return_type = toArrowType(return_type)
-
-    def verify_result_length(*a):
-        result = f(*a)
-        if not hasattr(result, "__len__"):
-            raise TypeError("Return type of pandas_udf should be a Pandas.Series")
-        if len(result) != len(a[0]):
-            raise RuntimeError("Result vector from pandas_udf was not the required length: "
-                               "expected %d, got %d" % (len(a[0]), len(result)))
-        return result
-    return lambda *a: (verify_result_length(*a), arrow_return_type)
+    if isinstance(return_type, StructType):
+        arrow_return_types = list(to_arrow_type(field.dataType) for field in return_type)
+
+        def fn(*a):
+            import pandas as pd
+            out = f(*a)
+            assert isinstance(out, pd.DataFrame), 'Must return a pd.DataFrame'
+            assert len(out.columns) == len(arrow_return_types), \
+                'Columns of pd.DataFrame don\'t match return schema'
--- End diff ---
The result df doesn't actually have a required length here: it can legitimately be a different length than the input, which is why this branch has no row-count check.
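
To make the point concrete, here is a minimal sketch of the kind of group-map UDF this branch handles (the schema, column names, and the groupby().apply() call are illustrative assumptions based on the API this PR targets, not code from the diff):

    import pandas as pd
    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import StructType, StructField, LongType, DoubleType

    # Illustrative schema (assumption): a StructType return type is what
    # routes a UDF into the new isinstance(return_type, StructType) branch.
    schema = StructType([
        StructField("id", LongType()),
        StructField("v", DoubleType()),
    ])

    @pandas_udf(schema)
    def above_mean(pdf):
        # pdf is the pd.DataFrame for one group; filtering rows means
        # len(pdf) and len(result) can legitimately differ, so the wrapper
        # can only assert the column count against the return schema.
        return pdf[pdf.v > pdf.v.mean()]

    # Usage sketch: df.groupby("id").apply(above_mean)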