zhengruifeng commented on code in PR #52716:
URL: https://github.com/apache/spark/pull/52716#discussion_r2476958936


##########
python/pyspark/sql/pandas/typehints.py:
##########
@@ -394,6 +394,84 @@ def infer_group_arrow_eval_type_from_func(
         return None
 
 
+def infer_group_pandas_eval_type(
+    sig: Signature,
+    type_hints: Dict[str, Any],
+) -> Optional[Union["PandasGroupedMapUDFType", "PandasGroupedMapIterUDFType"]]:
+    from pyspark.sql.pandas.functions import PythonEvalType
+
+    require_minimum_pandas_version()
+
+    import pandas as pd
+
+    annotations = {}
+    for param in sig.parameters.values():
+        if param.annotation is not param.empty:
+            annotations[param.name] = type_hints.get(param.name, 
param.annotation)
+
+    # Check if all arguments have type hints
+    parameters_sig = [
+        annotations[parameter] for parameter in sig.parameters if parameter in 
annotations
+    ]
+    if len(parameters_sig) != len(sig.parameters):
+        raise PySparkValueError(
+            errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
+            messageParameters={"target": "all parameters", "sig": str(sig)},
+        )
+
+    # Check if the return has a type hint
+    return_annotation = type_hints.get("return", sig.return_annotation)
+    if sig.empty is return_annotation:
+        raise PySparkValueError(
+            errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
+            messageParameters={"target": "the return type", "sig": str(sig)},
+        )
+
+    # Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
+    is_iterator_dataframe = (
+        len(parameters_sig) == 1
+        and check_iterator_annotation(  # Iterator
+            parameters_sig[0],
+            parameter_check_func=lambda t: t == pd.DataFrame,
+        )
+        and check_iterator_annotation(
+            return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
+        )
+    )
+    # Tuple[Any, ...], Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
+    is_iterator_dataframe_with_keys = (
+        len(parameters_sig) == 2
+        and check_iterator_annotation(  # Iterator
+            parameters_sig[1],
+            parameter_check_func=lambda t: t == pd.DataFrame,
+        )
+        and check_iterator_annotation(
+            return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
+        )
+    )
+
+    if is_iterator_dataframe or is_iterator_dataframe_with_keys:
+        return PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
+
+    # Default to non-iterator (standard grouped map)
+    return PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF

Review Comment:
   ```suggestion
       if is_iterator_dataframe or is_iterator_dataframe_with_keys:
           return PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
   
       # Default to non-iterator (standard grouped map)
       return PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
   ```
   
   This part should match the existing logic at
   https://github.com/apache/spark/blob/9e122017df62b5588693eaaeb7d59225370ed1ec/python/pyspark/sql/pandas/typehints.py#L368-L379
   
   We can align it in a follow-up.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to