leewyang commented on code in PR #37734:
URL: https://github.com/apache/spark/pull/37734#discussion_r1014476080


##########
python/pyspark/ml/functions.py:
##########
@@ -106,6 +117,542 @@ def array_to_vector(col: Column) -> Column:
     return Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
 
 
+def _batched(
+    data: pd.Series | pd.DataFrame | Tuple[pd.Series], batch_size: int
+) -> Iterator[pd.DataFrame]:
+    """Generator that splits a pandas dataframe/series into batches."""
+    if isinstance(data, pd.DataFrame):
+        index = 0
+        data_size = len(data)
+        while index < data_size:
+            yield data.iloc[index : index + batch_size]
+            index += batch_size
+    else:
+        # convert (tuple of) pd.Series into pd.DataFrame
+        if isinstance(data, pd.Series):
+            df = pd.concat((data,), axis=1)
+        else:  # isinstance(data, Tuple[pd.Series]):
+            df = pd.concat(data, axis=1)
+
+        index = 0
+        data_size = len(df)
+        while index < data_size:
+            yield df.iloc[index : index + batch_size]
+            index += batch_size
+
+
+def _is_tensor_col(data: pd.Series | pd.DataFrame) -> bool:
+    if isinstance(data, pd.Series):
+        return data.dtype == np.object_ and isinstance(data.iloc[0], (np.ndarray, list))
+    elif isinstance(data, pd.DataFrame):
+        return any(data.dtypes == np.object_) and any(
+            [isinstance(d, (np.ndarray, list)) for d in data.iloc[0]]
+        )
+    else:
+        raise ValueError(
+            "Unexpected data type: {}, expected pd.Series or 
pd.DataFrame.".format(type(data))
+        )
+
+
+def _has_tensor_cols(data: pd.Series | pd.DataFrame | Tuple[pd.Series]) -> bool:
+    """Check if input Series/DataFrame/Tuple contains any tensor-valued 
columns."""
+    if isinstance(data, (pd.Series, pd.DataFrame)):
+        return _is_tensor_col(data)
+    else:  # isinstance(data, Tuple):
+        return any(_is_tensor_col(elem) for elem in data)
+
+
+def _validate(
+    preds: np.ndarray | Mapping[str, np.ndarray] | List[Mapping[str, Any]],
+    num_input_rows: int,
+    return_type: DataType,
+) -> None:
+    """Validate model predictions against the expected pandas_udf 
return_type."""
+    if isinstance(return_type, StructType):
+        struct_rtype: StructType = return_type
+        fieldNames = struct_rtype.names
+        if isinstance(preds, dict):
+            # dictionary of columns
+            predNames = list(preds.keys())
+            if not all(v.shape == (num_input_rows,) for v in preds.values()):
+                raise ValueError("Prediction results for StructType fields 
must be scalars.")
+        elif isinstance(preds, list) and isinstance(preds[0], dict):
+            # rows of dictionaries
+            predNames = list(preds[0].keys())
+            if len(preds) != num_input_rows:
+                raise ValueError("Prediction results must have same length as 
input data.")

Review Comment:
   Yeah, @mengxr had the same question. Some models, e.g. the [Huggingface pipeline for sentiment analysis](https://huggingface.co/docs/transformers/quicktour#pipeline-usage), can produce results in this format, so he agreed to keep this case.
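   
   For context, here is a minimal sketch of that "rows of dictionaries" shape. The field names, values, and StructType fields below are illustrative only (not taken from the PR); they just show the case this branch of `_validate` accepts:
   
   ```python
   # Hypothetical output of a Huggingface sentiment-analysis pipeline for a
   # batch of 3 input rows: a list of dicts, one dict per row.
   preds = [
       {"label": "POSITIVE", "score": 0.99},
       {"label": "NEGATIVE", "score": 0.87},
       {"label": "POSITIVE", "score": 0.95},
   ]
   
   # With a StructType return_type such as (illustrative):
   #   StructType([StructField("label", StringType()), StructField("score", FloatType())])
   # this branch takes the field names from preds[0].keys() and only requires
   # that the number of result rows matches the number of input rows.
   num_input_rows = 3
   assert len(preds) == num_input_rows  # the length check performed above
   pred_names = list(preds[0].keys())   # -> ["label", "score"]
   ```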




