This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 39d43e0ac3b5 [SPARK-45434][ML][CONNECT] LogisticRegression checks the
training labels
39d43e0ac3b5 is described below
commit 39d43e0ac3b58fb7e804362bb07665e8d6536250
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Oct 6 17:48:03 2023 +0800
[SPARK-45434][ML][CONNECT] LogisticRegression checks the training labels
### What changes were proposed in this pull request?
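- Check that the training labels are integers in [0, numClasses), resolving the existing TODO.
- Compute `num_features` in the same query as `num_rows` (and the distinct labels), instead of issuing a separate `head()` on the features column.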
### Why are the changes needed?
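Training labels must be integers in [0, numClasses), where numClasses is the number of distinct labels in the training dataset; previously this was only noted as a TODO and never validated.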
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Existing CI.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #43246 from zhengruifeng/ml_lr_nit.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/ml/connect/classification.py | 20 +++++++++++---------
1 file changed, 11 insertions(+), 9 deletions(-)
diff --git a/python/pyspark/ml/connect/classification.py
b/python/pyspark/ml/connect/classification.py
index f8b525db8edd..ca6e01e9577c 100644
--- a/python/pyspark/ml/connect/classification.py
+++ b/python/pyspark/ml/connect/classification.py
@@ -41,7 +41,7 @@ from pyspark.ml.param.shared import (
)
from pyspark.ml.connect.base import Predictor, PredictionModel
from pyspark.ml.connect.io_utils import ParamsReadWrite, CoreModelReadWrite
-from pyspark.sql.functions import lit, count, countDistinct
+from pyspark.sql import functions as sf
import torch
import torch.nn as torch_nn
@@ -232,18 +232,20 @@ class LogisticRegression(
num_train_workers
)
- # TODO: check label values are in range of [0, num_classes)
- num_rows, num_classes = dataset.agg(
- count(lit(1)), countDistinct(self.getLabelCol())
+ num_rows, num_features, classes = dataset.select(
+ sf.count(sf.lit(1)),
+ sf.first(sf.array_size(self.getFeaturesCol())),
+ sf.collect_set(self.getLabelCol()),
).head() # type: ignore[misc]
- num_batches_per_worker = math.ceil(num_rows / num_train_workers /
batch_size)
- num_samples_per_worker = num_batches_per_worker * batch_size
-
- num_features = len(dataset.select(self.getFeaturesCol()).head()[0]) #
type: ignore[index]
-
+ num_classes = len(classes)
if num_classes < 2:
raise ValueError("Training dataset distinct labels must >= 2.")
+ if any(c not in range(0, num_classes) for c in classes):
+ raise ValueError("Training labels must be integers in [0,
numClasses).")
+
+ num_batches_per_worker = math.ceil(num_rows / num_train_workers /
batch_size)
+ num_samples_per_worker = num_batches_per_worker * batch_size
# TODO: support GPU.
distributor = TorchDistributor(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]