GitHub user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/20531#discussion_r166655962
--- Diff: python/pyspark/sql/udf.py ---
@@ -112,15 +112,31 @@ def returnType(self):
else:
self._returnType_placeholder =
_parse_datatype_string(self._returnType)
- if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
- and not isinstance(self._returnType_placeholder,
StructType):
- raise ValueError("Invalid returnType: returnType must be a
StructType for "
- "pandas_udf with function type GROUPED_MAP")
- elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF \
- and isinstance(self._returnType_placeholder, (StructType,
ArrayType, MapType)):
- raise NotImplementedError(
- "ArrayType, StructType and MapType are not supported with "
- "PandasUDFType.GROUPED_AGG")
+ if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+ if isinstance(self._returnType_placeholder, StructType):
+ try:
+ to_arrow_schema(self._returnType_placeholder)
+ except TypeError:
+ raise NotImplementedError(
+ "Invalid returnType with a grouped map Pandas UDF:
"
+ "%s is not supported" %
str(self._returnType_placeholder))
+ else:
+ raise TypeError("Invalid returnType for a grouped map
Pandas "
+ "UDF: returnType must be a StructType.")
+ elif self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
+ try:
+ to_arrow_type(self._returnType_placeholder)
+ except TypeError:
+ raise NotImplementedError(
+ "Invalid returnType with a scalar Pandas UDF: %s is "
+ "not supported" % str(self._returnType_placeholder))
+ elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
+ try:
+ to_arrow_type(self._returnType_placeholder)
+ except TypeError:
+ raise NotImplementedError(
+ "Invalid returnType with a grouped aggregate Pandas
UDF: "
--- End diff --
ditto
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]