Posted to reviews@spark.apache.org by "WeichenXu123 (via GitHub)" <gi...@apache.org> on 2023/07/10 11:14:38 UTC

[GitHub] [spark] WeichenXu123 commented on a diff in pull request #41881: [WIP][SPARK-43983][PYTHON][ML][CONNECT] Implement cross validator estimator

WeichenXu123 commented on code in PR #41881:
URL: https://github.com/apache/spark/pull/41881#discussion_r1258093956


##########
python/pyspark/ml/connect/tuning.py:
##########
@@ -0,0 +1,596 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import itertools
+from multiprocessing.pool import ThreadPool
+
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+    cast,
+    overload,
+    TYPE_CHECKING,
+)
+
+import numpy as np
+import pandas as pd
+
+from pyspark import keyword_only, since, inheritable_thread_target
+from pyspark.ml.connect import Estimator, Transformer, Model
+from pyspark.ml.connect.base import Evaluator
+from pyspark.ml.connect.io_utils import (
+    MetaAlgorithmReadWrite,
+    CoreModelReadWrite,
+    ParamsReadWrite,
+)
+from pyspark.ml.param import Params, Param, TypeConverters
+from pyspark.ml.param.shared import HasParallelism, HasSeed
+from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
+from pyspark.sql.types import BooleanType
+from pyspark.sql.dataframe import DataFrame
+
+from pyspark.ml.tuning import ParamGridBuilder
+
+if TYPE_CHECKING:
+    from pyspark.ml._typing import ParamMap
+
+
+class _ValidatorParams(HasSeed):
+    """
+    Common params for TrainValidationSplit and CrossValidator.
+    """
+
+    estimator: Param[Estimator] = Param(
+        Params._dummy(), "estimator", "estimator to be cross-validated"
+    )
+    estimatorParamMaps: Param[List["ParamMap"]] = Param(
+        Params._dummy(), "estimatorParamMaps", "estimator param maps"
+    )
+    evaluator: Param[Evaluator] = Param(
+        Params._dummy(),
+        "evaluator",
+        "evaluator used to select hyper-parameters that maximize the validator metric",
+    )
+
+    @since("2.0.0")
+    def getEstimator(self) -> Estimator:
+        """
+        Gets the value of estimator or its default value.
+        """
+        return self.getOrDefault(self.estimator)
+
+    @since("2.0.0")
+    def getEstimatorParamMaps(self) -> List["ParamMap"]:
+        """
+        Gets the value of estimatorParamMaps or its default value.
+        """
+        return self.getOrDefault(self.estimatorParamMaps)
+
+    @since("2.0.0")
+    def getEvaluator(self) -> Evaluator:
+        """
+        Gets the value of evaluator or its default value.
+        """
+        return self.getOrDefault(self.evaluator)
+
+
+class _CrossValidatorParams(_ValidatorParams):
+    """
+    Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`.
+
+    .. versionadded:: 3.5.0
+    """
+
+    numFolds: Param[int] = Param(
+        Params._dummy(),
+        "numFolds",
+        "number of folds for cross validation",
+        typeConverter=TypeConverters.toInt,
+    )
+
+    foldCol: Param[str] = Param(
+        Params._dummy(),
+        "foldCol",
+        "Param for the column name of user "
+        + "specified fold number. Once this is specified, :py:class:`CrossValidator` "
+        + "won't do random k-fold split. Note that this column should be integer type "
+        + "with range [0, numFolds) and Spark will throw exception on out-of-range "
+        + "fold numbers.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    def __init__(self, *args: Any):
+        super(_CrossValidatorParams, self).__init__(*args)
+        self._setDefault(numFolds=3, foldCol="")
+
+    @since("1.4.0")
+    def getNumFolds(self) -> int:
+        """
+        Gets the value of numFolds or its default value.
+        """
+        return self.getOrDefault(self.numFolds)
+
+    @since("3.1.0")
+    def getFoldCol(self) -> str:
+        """
+        Gets the value of foldCol or its default value.
+        """
+        return self.getOrDefault(self.foldCol)
+
+
+def _parallelFitTasks(
+    estimator: Estimator,
+    train: DataFrame,
+    evaluator: Evaluator,
+    validation: DataFrame,
+    epm: Sequence["ParamMap"],
+) -> List[Callable[[], Tuple[int, float]]]:
+    """
+    Creates a list of callables which can be called from different threads to fit and evaluate
+    an estimator in parallel. Each callable returns an `(index, metric)` pair.
+
+    Parameters
+    ----------
+    estimator : :py:class:`pyspark.ml.connect.Estimator`
+        The estimator to be fit.
+    train : :py:class:`pyspark.sql.DataFrame`
+        DataFrame, training data set, used for fitting.
+    evaluator : :py:class:`pyspark.ml.connect.base.Evaluator`
+        Evaluator used to compute `metric`.
+    validation : :py:class:`pyspark.sql.DataFrame`
+        DataFrame, validation data set, used for evaluation.
+    epm : :py:class:`collections.abc.Sequence`
+        Sequence of ParamMap, param maps to be used during fitting & evaluation.
+
+    Returns
+    -------
+    list
+        A list of callables; each callable returns an `(index, metric)` pair, where `index` is an
+        index into `epm` and `metric` is the associated metric value.
+    """
+
+    def get_single_task(index: int, param_map: "ParamMap") -> Callable[[], Tuple[int, float]]:
+        def single_task() -> Tuple[int, float]:
+            model = estimator.fit(train, param_map)
+            metric = evaluator.evaluate(model.transform(validation, param_map))
+            return index, metric
+
+        return single_task
+
+    return [get_single_task(index, param_map) for index, param_map in enumerate(epm)]
+
+
+class _CrossValidatorReadWrite(MetaAlgorithmReadWrite):
+
+    def _get_skip_saving_params(self) -> List[str]:
+        """
+        Returns params to be skipped when saving metadata.
+        """
+        return ["estimator", "estimatorParamMaps", "evaluator"]
+
+    def _save_meta_algorithm(self, root_path: str, node_path: List[str]) -> Dict[str, Any]:
+        metadata = self._get_metadata_to_save()
+        metadata["estimator"] = self.getEstimator()._save_to_node_path(
+            root_path, node_path + ["crossvalidator_estimator"]
+        )
+        metadata["evaluator"] = self.getEvaluator()._save_to_node_path(
+            root_path, node_path + ["crossvalidator_evaluator"]
+        )
+        metadata["estimator_param_maps"] = [
+            [
+                {"parent": param.parent, "name": param.name, "value": value}
+                for param, value in param_map.items()
+            ]
+            for param_map in self.getEstimatorParamMaps()
+        ]
+
+        if isinstance(self, CrossValidatorModel):
+            metadata["avg_metrics"] = self.avgMetrics
+            metadata["std_metrics"] = self.stdMetrics
+
+            metadata["best_model"] = self.bestModel._save_to_node_path(
+                root_path, node_path + ["crossvalidator_best_model"]
+            )
+        return metadata
+
+    def _load_meta_algorithm(self, root_path: str, node_metadata: Dict[str, Any]) -> None:
+        estimator = ParamsReadWrite._load_instance_from_metadata(node_metadata["estimator"])
+        self.set(self.estimator, estimator)
+
+        evaluator = ParamsReadWrite._load_instance_from_metadata(node_metadata["evaluator"])
+        self.set(self.evaluator, evaluator)
+
+        json_epm = node_metadata["estimator_param_maps"]
+
+        uid_to_instances = MetaAlgorithmReadWrite.get_uid_map(estimator)
+
+        epm = []
+        for json_param_map in json_epm:
+            param_map = {}
+            for json_param in json_param_map:
+                est = uid_to_instances[json_param["parent"]]
+                param = getattr(est, json_param["name"])
+                value = json_param["value"]
+                param_map[param] = value
+            epm.append(param_map)
+
+        self.set(self.estimatorParamMaps, epm)
+
+        if isinstance(self, CrossValidatorModel):
+            self.avgMetrics = node_metadata["avg_metrics"]
+            self.stdMetrics = node_metadata["std_metrics"]
+
+            self.bestModel = ParamsReadWrite._load_instance_from_metadata(
+                node_metadata["best_model"]
+            )
+
+
+class CrossValidator(
+    Estimator["CrossValidatorModel"],
+    _CrossValidatorParams,
+    HasParallelism,
+    _CrossValidatorReadWrite,
+):
+    """
+    K-fold cross validation performs model selection by splitting the dataset into a set of
+    non-overlapping, randomly partitioned folds which are used as separate training and test
+    datasets; e.g., with k=3 folds, K-fold cross validation will generate 3 (training, test)
+    dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. Each
+    fold is used as the test set exactly once.
+
+    .. versionadded:: 3.5.0
+
+    Examples
+    --------
+    >>> from pyspark.ml.classification import LogisticRegression

Review Comment:
   removed !
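
For anyone following along outside GitHub: the callables built by _parallelFitTasks above are meant to be driven by a thread pool inside the cross-validator's _fit. The sketch below shows only that consumption pattern; it follows the classic pyspark.ml.tuning loop, and the helper name _run_fold_tasks, the pool sizing, and the metric collection are illustrative assumptions rather than code quoted from this PR.

from multiprocessing.pool import ThreadPool
from typing import Callable, List, Tuple

from pyspark import inheritable_thread_target


def _run_fold_tasks(
    tasks: List[Callable[[], Tuple[int, float]]], parallelism: int
) -> List[float]:
    """Hypothetical helper: run one fold's fit/evaluate callables and collect
    the metrics in the same order as the param maps in `epm`."""
    metrics: List[float] = [float("nan")] * len(tasks)
    pool = ThreadPool(processes=min(parallelism, len(tasks)))
    try:
        # inheritable_thread_target propagates the active session / local
        # properties from the calling thread to the pool threads.
        wrapped = [inheritable_thread_target(task) for task in tasks]
        for index, metric in pool.imap_unordered(lambda f: f(), wrapped):
            metrics[index] = metric
    finally:
        pool.terminate()
    return metrics

Run once per fold, the resulting per-fold metric lists would then be reduced into the avgMetrics / stdMetrics that CrossValidatorModel exposes.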

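A note on the estimator_param_maps metadata that _save_meta_algorithm writes and _load_meta_algorithm reads back: each ParamMap is flattened into JSON-friendly records keyed by the owning stage's uid, which is what MetaAlgorithmReadWrite.get_uid_map(estimator) resolves on load. Roughly like this (the uid and the param names/values below are invented for illustration):

# Shape of metadata["estimator_param_maps"] -- purely illustrative values.
estimator_param_maps = [
    [  # one inner list per ParamMap in getEstimatorParamMaps()
        {"parent": "LogisticRegression_4c7a9b", "name": "maxIter", "value": 10},
        {"parent": "LogisticRegression_4c7a9b", "name": "tol", "value": 1e-6},
    ],
    [
        {"parent": "LogisticRegression_4c7a9b", "name": "maxIter", "value": 100},
        {"parent": "LogisticRegression_4c7a9b", "name": "tol", "value": 1e-4},
    ],
]

On load, "parent" is looked up in the uid map to find the stage instance and getattr(stage, name) recovers the Param object, so the reconstructed maps point at the freshly deserialized estimator rather than at stale Param instances.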

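Since the diff above is cut off right at the Examples section, here is a hedged sketch of how the new estimator is presumably meant to be driven, based only on the params declared in this file; the keyword-argument constructor, the connect LogisticRegression / BinaryClassificationEvaluator imports, and the toy DataFrame schema are assumptions, not the example shipped in the PR.

from pyspark.ml.connect.classification import LogisticRegression
from pyspark.ml.connect.evaluation import BinaryClassificationEvaluator
from pyspark.ml.connect.tuning import CrossValidator
from pyspark.ml.tuning import ParamGridBuilder
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Toy data: array<double> features plus an integer label (assumed column names).
train_df = spark.createDataFrame(
    [([1.0, 2.0], 0), ([2.0, 1.0], 1), ([3.0, 0.5], 0), ([0.5, 3.0], 1)] * 25,
    ["features", "label"],
)

lr = LogisticRegression(maxIter=20)
grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 20]).build()

cv = CrossValidator(
    estimator=lr,                              # assumed keyword args mirroring
    estimatorParamMaps=grid,                   # the params declared above
    evaluator=BinaryClassificationEvaluator(),
    numFolds=3,                                # see _CrossValidatorParams
    parallelism=2,                             # from HasParallelism
)
cv_model = cv.fit(train_df)
print(cv_model.avgMetrics, cv_model.stdMetrics)  # populated per the save/load code above
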

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

