GitHub user ueshin commented on a diff in the pull request:
https://github.com/apache/spark/pull/20531#discussion_r166831572
--- Diff: python/pyspark/sql/udf.py ---
@@ -112,15 +112,31 @@ def returnType(self):
             else:
                 self._returnType_placeholder = _parse_datatype_string(self._returnType)
-        if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
-                and not isinstance(self._returnType_placeholder, StructType):
-            raise ValueError("Invalid returnType: returnType must be a StructType for "
-                             "pandas_udf with function type GROUPED_MAP")
-        elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF \
-                and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)):
-            raise NotImplementedError(
-                "ArrayType, StructType and MapType are not supported with "
-                "PandasUDFType.GROUPED_AGG")
+        if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
--- End diff --
nit: Unless you have a specific reason, I'd prefer to keep the checks in the same order as the definitions in `PythonEvalType`.
E.g.,
```
if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
    ...
elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
    ...
elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
    ...
```
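A self-contained sketch of the suggested ordering, filled in with the two checks removed in the diff above. It does not import PySpark: `PythonEvalType`, `StructType`, `ArrayType`, `MapType` and `LongType` are minimal stand-ins defined locally (their integer values are illustrative, chosen only to follow the declaration order), and `check_return_type` is a hypothetical helper, not part of the actual patch. The scalar branch is left as a placeholder because this sketch assumes no scalar-specific check.

```python
# Hypothetical stand-ins for pyspark.sql.types classes so the sketch runs on its own.
class DataType(object):
    pass


class StructType(DataType):
    pass


class ArrayType(DataType):
    pass


class MapType(DataType):
    pass


class LongType(DataType):
    pass


class PythonEvalType(object):
    # Illustrative values; only the declaration order matters for this sketch.
    SQL_SCALAR_PANDAS_UDF = 200
    SQL_GROUPED_MAP_PANDAS_UDF = 201
    SQL_GROUPED_AGG_PANDAS_UDF = 202


def check_return_type(eval_type, return_type):
    """Hypothetical helper: validate return_type for eval_type, testing the
    branches in the same order as the PythonEvalType definitions."""
    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
        # Placeholder: no scalar-specific check is assumed here.
        pass
    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # Grouped map UDFs must declare a StructType (one column per field).
        if not isinstance(return_type, StructType):
            raise ValueError("Invalid returnType: returnType must be a StructType for "
                             "pandas_udf with function type GROUPED_MAP")
    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
        # Grouped aggregate UDFs reject nested/complex return types.
        if isinstance(return_type, (StructType, ArrayType, MapType)):
            raise NotImplementedError(
                "ArrayType, StructType and MapType are not supported with "
                "PandasUDFType.GROUPED_AGG")
    return return_type
```

Reading the branches top to bottom then matches reading `PythonEvalType` top to bottom, so a new eval type only needs a new `elif` appended in the matching position.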