You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/10/06 09:48:24 UTC
[spark] branch master updated: [SPARK-45434][ML][CONNECT] LogisticRegression checks the training labels
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 39d43e0ac3b5 [SPARK-45434][ML][CONNECT] LogisticRegression checks the training labels
39d43e0ac3b5 is described below
commit 39d43e0ac3b58fb7e804362bb07665e8d6536250
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Fri Oct 6 17:48:03 2023 +0800
[SPARK-45434][ML][CONNECT] LogisticRegression checks the training labels
### What changes were proposed in this pull request?
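- Check that the training labels are integers in [0, numClasses), replacing the previous TODO.
- Compute `num_features` in the same query as `num_rows` and the set of distinct labels, instead of issuing a separate `select(...).head()` call.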
### Why are the changes needed?
Training labels should be in [0, numClasses); previously this was not checked (it was only a TODO in the code).
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Existing CI.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #43246 from zhengruifeng/ml_lr_nit.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
python/pyspark/ml/connect/classification.py | 20 +++++++++++---------
1 file changed, 11 insertions(+), 9 deletions(-)
diff --git a/python/pyspark/ml/connect/classification.py b/python/pyspark/ml/connect/classification.py
index f8b525db8edd..ca6e01e9577c 100644
--- a/python/pyspark/ml/connect/classification.py
+++ b/python/pyspark/ml/connect/classification.py
@@ -41,7 +41,7 @@ from pyspark.ml.param.shared import (
)
from pyspark.ml.connect.base import Predictor, PredictionModel
from pyspark.ml.connect.io_utils import ParamsReadWrite, CoreModelReadWrite
-from pyspark.sql.functions import lit, count, countDistinct
+from pyspark.sql import functions as sf
import torch
import torch.nn as torch_nn
@@ -232,18 +232,20 @@ class LogisticRegression(
num_train_workers
)
- # TODO: check label values are in range of [0, num_classes)
- num_rows, num_classes = dataset.agg(
- count(lit(1)), countDistinct(self.getLabelCol())
+ num_rows, num_features, classes = dataset.select(
+ sf.count(sf.lit(1)),
+ sf.first(sf.array_size(self.getFeaturesCol())),
+ sf.collect_set(self.getLabelCol()),
).head() # type: ignore[misc]
- num_batches_per_worker = math.ceil(num_rows / num_train_workers / batch_size)
- num_samples_per_worker = num_batches_per_worker * batch_size
-
- num_features = len(dataset.select(self.getFeaturesCol()).head()[0]) # type: ignore[index]
-
+ num_classes = len(classes)
if num_classes < 2:
raise ValueError("Training dataset distinct labels must >= 2.")
+ if any(c not in range(0, num_classes) for c in classes):
+ raise ValueError("Training labels must be integers in [0, numClasses).")
+
+ num_batches_per_worker = math.ceil(num_rows / num_train_workers / batch_size)
+ num_samples_per_worker = num_batches_per_worker * batch_size
# TODO: support GPU.
distributor = TorchDistributor(
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org