WeichenXu123 commented on code in PR #37734:
URL: https://github.com/apache/spark/pull/37734#discussion_r1015423592


##########
python/pyspark/ml/functions.py:
##########
@@ -106,6 +117,543 @@ def array_to_vector(col: Column) -> Column:
     return Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
 
 
+def _batched(
+    data: pd.Series | pd.DataFrame | Tuple[pd.Series], batch_size: int
+) -> Iterator[pd.DataFrame]:
+    """Generator that splits a pandas dataframe/series into batches."""
+    if isinstance(data, pd.DataFrame):
+        df = data
+    elif isinstance(data, pd.Series):
+        df = pd.concat((data,), axis=1)
+    else:  # isinstance(data, Tuple[pd.Series]):
+        df = pd.concat(data, axis=1)
+
+    index = 0
+    data_size = len(df)
+    while index < data_size:
+        yield df.iloc[index : index + batch_size]
+        index += batch_size
+
+
+def _is_tensor_col(data: pd.Series | pd.DataFrame) -> bool:
+    if isinstance(data, pd.Series):
+        return data.dtype == np.object_ and isinstance(data.iloc[0], (np.ndarray, list))
+    elif isinstance(data, pd.DataFrame):
+        return any(data.dtypes == np.object_) and any(
+            [isinstance(d, (np.ndarray, list)) for d in data.iloc[0]]
+        )
+    else:
+        raise ValueError(
+            "Unexpected data type: {}, expected pd.Series or pd.DataFrame.".format(type(data))
+        )
+
+
+def _has_tensor_cols(data: pd.Series | pd.DataFrame | Tuple[pd.Series]) -> bool:
+    """Check if input Series/DataFrame/Tuple contains any tensor-valued columns."""
+    if isinstance(data, (pd.Series, pd.DataFrame)):
+        return _is_tensor_col(data)
+    else:  # isinstance(data, Tuple):
+        return any(_is_tensor_col(elem) for elem in data)
+
+
+def _validate_and_transform(
+    preds: np.ndarray | Mapping[str, np.ndarray] | List[Mapping[str, Any]],
+    num_input_rows: int,
+    return_type: DataType,
+) -> pd.DataFrame | pd.Series:
+    """Validate numpy-based model predictions against the expected pandas_udf return_type and
+    transforms the predictions into an equivalent pandas DataFrame or Series."""
+    if isinstance(return_type, StructType):
+        struct_rtype: StructType = return_type
+        fieldNames = struct_rtype.names
+        if isinstance(preds, dict):
+            # dictionary of columns
+            predNames = list(preds.keys())
+            for field in struct_rtype.fields:
+                if len(preds[field.name]) != num_input_rows:
+                    raise ValueError("Prediction results must have same length as input data.")
+                if field.dataType == ArrayType and preds[field.name].shape != 2:
+                    raise ValueError("Prediction results for ArrayType must be two-dimensional.")
+
+        elif isinstance(preds, list) and isinstance(preds[0], dict):
+            # rows of dictionaries
+            predNames = list(preds[0].keys())
+            if len(preds) != num_input_rows:
+                raise ValueError("Prediction results must have same length as input data.")
+        else:
+            raise ValueError(
+                "Prediction results for StructType must be a dictionary or "
+                "a list of dictionary, got: {}".format(type(preds))
+            )
+
+        # check column names
+        if set(predNames) != set(fieldNames):
+            raise ValueError(
+                "Prediction result columns did not match expected return_type "
+                "columns: expected {}, got: {}".format(fieldNames, predNames)
+            )
+
+        return pd.DataFrame(preds)
+    elif isinstance(return_type, ArrayType):
+        if isinstance(preds, np.ndarray):
+            if len(preds) != num_input_rows:
+                raise ValueError("Prediction results must have same length as input data.")
+            if len(preds.shape) != 2:
+                raise ValueError("Prediction results for ArrayType must be two-dimensional.")
+        else:
+            raise ValueError("Prediction results for ArrayType must be an ndarray.")
+
+        return pd.Series(list(preds))
+    else:  # scalar
+        if len(preds) != num_input_rows:
+            raise ValueError("Prediction results must have same length as input data.")
+
+        return pd.Series(np.squeeze(preds))  # type: ignore
+
+
+def predict_batch_udf(
+    predict_batch_fn: Callable[
+        [],
+        Callable[
+            [np.ndarray | List[np.ndarray]],
+            np.ndarray | Mapping[str, np.ndarray] | List[Mapping[str, np.dtype]],
+        ],
+    ],
+    *,
+    return_type: DataType,
+    batch_size: int,
+    input_tensor_shapes: List[List[int] | None] | Mapping[int, List[int]] | None = None,
+) -> UserDefinedFunctionLike:
+    """Given a function which loads a model, returns a pandas_udf for inferencing over that model.
+
+    This will handle:
+    - conversion of the Spark DataFrame to numpy arrays.
+    - batching of the inputs sent to the model predict() function.
+    - caching of the model and prediction function on the executors.
+
+    This assumes that the `predict_batch_fn` encapsulates all of the necessary dependencies for
+    running the model or the Spark executor environment already satisfies all runtime requirements.
+
+    For the conversion of Spark DataFrame to numpy, the following table describes the behavior,
+    where tensor columns in the Spark DataFrame must be represented as a flattened 1-D array/list.
+
+    | dataframe \\ model | single input | multiple inputs |
+    | :----------------- | :----------- | :-------------- |
+    | single-col scalar  | 1            | N/A             |
+    | single-col tensor  | 1,2          | N/A             |
+    | multi-col scalar   | 3            | 4               |
+    | multi-col tensor   | N/A          | 4,2             |
+
+    Notes:
+    1. pass thru dataframe column => model input as single numpy array.

Review Comment:
   Typo: should "pass thru" be "pass through"?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to