You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/10/06 09:48:24 UTC

[spark] branch master updated: [SPARK-45434][ML][CONNECT] LogisticRegression checks the training labels

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 39d43e0ac3b5 [SPARK-45434][ML][CONNECT] LogisticRegression checks the training labels
39d43e0ac3b5 is described below

commit 39d43e0ac3b58fb7e804362bb07665e8d6536250
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Fri Oct 6 17:48:03 2023 +0800

    [SPARK-45434][ML][CONNECT] LogisticRegression checks the training labels
    
    ### What changes were proposed in this pull request?
    
    - check that the training labels are integers in [0, numClasses)
    - compute `num_features` in the same aggregation query as `num_rows` (one query instead of two)
    
    ### Why are the changes needed?
    Training labels must be in [0, numClasses); previously this was not validated (there was only a TODO).
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #43246 from zhengruifeng/ml_lr_nit.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/ml/connect/classification.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/ml/connect/classification.py b/python/pyspark/ml/connect/classification.py
index f8b525db8edd..ca6e01e9577c 100644
--- a/python/pyspark/ml/connect/classification.py
+++ b/python/pyspark/ml/connect/classification.py
@@ -41,7 +41,7 @@ from pyspark.ml.param.shared import (
 )
 from pyspark.ml.connect.base import Predictor, PredictionModel
 from pyspark.ml.connect.io_utils import ParamsReadWrite, CoreModelReadWrite
-from pyspark.sql.functions import lit, count, countDistinct
+from pyspark.sql import functions as sf
 
 import torch
 import torch.nn as torch_nn
@@ -232,18 +232,20 @@ class LogisticRegression(
             num_train_workers
         )
 
-        # TODO: check label values are in range of [0, num_classes)
-        num_rows, num_classes = dataset.agg(
-            count(lit(1)), countDistinct(self.getLabelCol())
+        num_rows, num_features, classes = dataset.select(
+            sf.count(sf.lit(1)),
+            sf.first(sf.array_size(self.getFeaturesCol())),
+            sf.collect_set(self.getLabelCol()),
         ).head()  # type: ignore[misc]
 
-        num_batches_per_worker = math.ceil(num_rows / num_train_workers / batch_size)
-        num_samples_per_worker = num_batches_per_worker * batch_size
-
-        num_features = len(dataset.select(self.getFeaturesCol()).head()[0])  # type: ignore[index]
-
+        num_classes = len(classes)
         if num_classes < 2:
             raise ValueError("Training dataset distinct labels must >= 2.")
+        if any(c not in range(0, num_classes) for c in classes):
+            raise ValueError("Training labels must be integers in [0, numClasses).")
+
+        num_batches_per_worker = math.ceil(num_rows / num_train_workers / batch_size)
+        num_samples_per_worker = num_batches_per_worker * batch_size
 
         # TODO: support GPU.
         distributor = TorchDistributor(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org