Yicong-Huang commented on code in PR #52716:
URL: https://github.com/apache/spark/pull/52716#discussion_r2479144600
##########
python/pyspark/sql/pandas/typehints.py:
##########
@@ -394,6 +394,84 @@ def infer_group_arrow_eval_type_from_func(
return None
+def infer_group_pandas_eval_type(
+ sig: Signature,
+ type_hints: Dict[str, Any],
+) -> Optional[Union["PandasGroupedMapUDFType", "PandasGroupedMapIterUDFType"]]:
+ from pyspark.sql.pandas.functions import PythonEvalType
+
+ require_minimum_pandas_version()
+
+ import pandas as pd
+
+ annotations = {}
+ for param in sig.parameters.values():
+ if param.annotation is not param.empty:
+ annotations[param.name] = type_hints.get(param.name, param.annotation)
+
+ # Check if all arguments have type hints
+ parameters_sig = [
+ annotations[parameter] for parameter in sig.parameters if parameter in annotations
+ ]
+ if len(parameters_sig) != len(sig.parameters):
+ raise PySparkValueError(
+ errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
+ messageParameters={"target": "all parameters", "sig": str(sig)},
+ )
+
+ # Check if the return has a type hint
+ return_annotation = type_hints.get("return", sig.return_annotation)
+ if sig.empty is return_annotation:
+ raise PySparkValueError(
+ errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
+ messageParameters={"target": "the return type", "sig": str(sig)},
+ )
+
+ # Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
+ is_iterator_dataframe = (
+ len(parameters_sig) == 1
+ and check_iterator_annotation( # Iterator
+ parameters_sig[0],
+ parameter_check_func=lambda t: t == pd.DataFrame,
+ )
+ and check_iterator_annotation(
+ return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
+ )
+ )
+ # Tuple[Any, ...], Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
+ is_iterator_dataframe_with_keys = (
+ len(parameters_sig) == 2
+ and check_iterator_annotation( # Iterator
+ parameters_sig[1],
+ parameter_check_func=lambda t: t == pd.DataFrame,
+ )
+ and check_iterator_annotation(
+ return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
+ )
+ )
+
+ if is_iterator_dataframe or is_iterator_dataframe_with_keys:
+ return PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
+
+ # Default to non-iterator (standard grouped map)
+ return PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
Review Comment:
Let's discuss this and address it in a follow-up if needed.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]