You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/06/19 13:36:40 UTC

[spark] branch master updated: [SPARK-43982][ML][PYTHON][CONNECT] Implement pipeline estimator for ML on spark connect

This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6c0c226d901 [SPARK-43982][ML][PYTHON][CONNECT] Implement pipeline estimator for ML on spark connect
6c0c226d901 is described below

commit 6c0c226d90192e54a4965b6d69905936619e20d6
Author: Weichen Xu <we...@databricks.com>
AuthorDate: Mon Jun 19 21:36:21 2023 +0800

    [SPARK-43982][ML][PYTHON][CONNECT] Implement pipeline estimator for ML on spark connect
    
    ### What changes were proposed in this pull request?
    
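    Implement a pipeline estimator (`Pipeline`, fitting to `PipelineModel`) for ML on Spark Connect.
    To support saving and loading nested stages, the save/load helpers in `pyspark.mlv2.io_utils`
    are refactored: `ModelReadWrite` is replaced by `CoreModelReadWrite`, and a new
    `MetaAlgorithmReadWrite` interface is added for meta-algorithms such as pipelines.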
    
    ### Why are the changes needed?
    
    See the design doc of the "Distributed ML <> Spark Connect" project:
    https://docs.google.com/document/d/1LHzwCjm2SluHkta_08cM3jxFSgfF-niaCZbtIThG-H8/edit#heading=h.x8uc4xogrzbk
    
    ### Does this PR introduce _any_ user-facing change?
    
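    Yes. A new estimator, `pyspark.mlv2.pipeline.Pipeline`, and its fitted model,
    `pyspark.mlv2.pipeline.PipelineModel`, are added; both are also exported from `pyspark.mlv2`.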
    
    ### How was this patch tested?
    
    Unit tests, in the new `python/pyspark/mlv2/tests/test_pipeline.py`.
    
    Closes #41479 from WeichenXu123/mlv2-pipeline.
    
    Authored-by: Weichen Xu <we...@databricks.com>
    Signed-off-by: Weichen Xu <we...@databricks.com>
---
 python/pyspark/mlv2/__init__.py            |   4 +
 python/pyspark/mlv2/classification.py      |   6 +-
 python/pyspark/mlv2/feature.py             |   6 +-
 python/pyspark/mlv2/io_utils.py            | 187 ++++++++++++++--------
 python/pyspark/mlv2/pipeline.py            | 241 +++++++++++++++++++++++++++++
 python/pyspark/mlv2/tests/test_pipeline.py | 184 ++++++++++++++++++++++
 6 files changed, 561 insertions(+), 67 deletions(-)

diff --git a/python/pyspark/mlv2/__init__.py b/python/pyspark/mlv2/__init__.py
index 990b4fa9da8..352d24baabe 100644
--- a/python/pyspark/mlv2/__init__.py
+++ b/python/pyspark/mlv2/__init__.py
@@ -26,6 +26,8 @@ from pyspark.mlv2 import (
     evaluation,
 )
 
+from pyspark.mlv2.pipeline import Pipeline, PipelineModel
+
 __all__ = [
     "Estimator",
     "Transformer",
@@ -33,4 +35,6 @@ __all__ = [
     "Model",
     "feature",
     "evaluation",
+    "Pipeline",
+    "PipelineModel",
 ]
diff --git a/python/pyspark/mlv2/classification.py b/python/pyspark/mlv2/classification.py
index fe0d76837f9..522c54b5289 100644
--- a/python/pyspark/mlv2/classification.py
+++ b/python/pyspark/mlv2/classification.py
@@ -40,7 +40,7 @@ from pyspark.ml.param.shared import (
     HasMomentum,
 )
 from pyspark.mlv2.base import Predictor, PredictionModel
-from pyspark.mlv2.io_utils import ParamsReadWrite, ModelReadWrite
+from pyspark.mlv2.io_utils import ParamsReadWrite, CoreModelReadWrite
 from pyspark.sql.functions import lit, count, countDistinct
 
 import torch
@@ -253,7 +253,9 @@ class LogisticRegression(
 
 
 @inherit_doc
-class LogisticRegressionModel(PredictionModel, _LogisticRegressionParams, ModelReadWrite):
+class LogisticRegressionModel(
+    PredictionModel, _LogisticRegressionParams, ParamsReadWrite, CoreModelReadWrite
+):
     """
     Model fitted by LogisticRegression.
 
diff --git a/python/pyspark/mlv2/feature.py b/python/pyspark/mlv2/feature.py
index 57c6213d2bb..a58f214711c 100644
--- a/python/pyspark/mlv2/feature.py
+++ b/python/pyspark/mlv2/feature.py
@@ -24,7 +24,7 @@ from pyspark import keyword_only
 from pyspark.sql import DataFrame
 from pyspark.ml.param.shared import HasInputCol, HasOutputCol
 from pyspark.mlv2.base import Estimator, Model
-from pyspark.mlv2.io_utils import ParamsReadWrite, ModelReadWrite
+from pyspark.mlv2.io_utils import ParamsReadWrite, CoreModelReadWrite
 from pyspark.mlv2.summarizer import summarize_dataframe
 
 
@@ -61,7 +61,7 @@ class MaxAbsScaler(Estimator, HasInputCol, HasOutputCol, ParamsReadWrite):
         return self._copyValues(model)
 
 
-class MaxAbsScalerModel(Model, HasInputCol, HasOutputCol, ModelReadWrite):
+class MaxAbsScalerModel(Model, HasInputCol, HasOutputCol, ParamsReadWrite, CoreModelReadWrite):
     def __init__(
         self, max_abs_values: Optional["np.ndarray"] = None, n_samples_seen: Optional[int] = None
     ) -> None:
@@ -143,7 +143,7 @@ class StandardScaler(Estimator, HasInputCol, HasOutputCol, ParamsReadWrite):
         return self._copyValues(model)
 
 
-class StandardScalerModel(Model, HasInputCol, HasOutputCol, ModelReadWrite):
+class StandardScalerModel(Model, HasInputCol, HasOutputCol, ParamsReadWrite, CoreModelReadWrite):
     def __init__(
         self,
         mean_values: Optional["np.ndarray"] = None,
diff --git a/python/pyspark/mlv2/io_utils.py b/python/pyspark/mlv2/io_utils.py
index 8f7263206a7..c701736712f 100644
--- a/python/pyspark/mlv2/io_utils.py
+++ b/python/pyspark/mlv2/io_utils.py
@@ -21,7 +21,8 @@ import os
 import tempfile
 import time
 from urllib.parse import urlparse
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List
+from pyspark.ml.base import Params
 from pyspark.ml.util import _get_active_session
 from pyspark.sql.utils import is_remote
 
@@ -56,43 +57,6 @@ def _copy_dir_from_local_to_fs(local_path: str, dest_path: str) -> None:
         _copy_file_from_local_to_fs(file_path, dest_file_path)
 
 
-def _get_metadata_to_save(
-    instance: Any,
-    extra_metadata: Optional[Dict[str, Any]] = None,
-) -> Dict[str, Any]:
-    """
-    Extract metadata of Estimator / Transformer / Model / Evaluator instance.
-    """
-    uid = instance.uid
-    cls = instance.__module__ + "." + instance.__class__.__name__
-
-    # User-supplied param values
-    params = instance._paramMap
-    json_params = {}
-    for p in params:
-        json_params[p.name] = params[p]
-
-    # Default param values
-    json_default_params = {}
-    for p in instance._defaultParamMap:
-        json_default_params[p.name] = instance._defaultParamMap[p]
-
-    metadata = {
-        "class": cls,
-        "timestamp": int(round(time.time() * 1000)),
-        "sparkVersion": pyspark_version,
-        "uid": uid,
-        "paramMap": json_params,
-        "defaultParamMap": json_default_params,
-        "type": "spark_connect",
-    }
-    if extra_metadata is not None:
-        assert isinstance(extra_metadata, dict)
-        metadata["extra"] = extra_metadata
-
-    return metadata
-
-
 def _get_class(clazz: str) -> Any:
     """
     Loads Python class from its name.
@@ -103,7 +67,7 @@ def _get_class(clazz: str) -> Any:
     return getattr(m, parts[-1])
 
 
-class ParamsReadWrite:
+class ParamsReadWrite(Params):
     """
     The base interface Estimator / Transformer / Model / Evaluator needs to inherit
     for supporting saving and loading.
@@ -115,18 +79,72 @@ class ParamsReadWrite:
         """
         return None
 
+    def _get_skip_saving_params(self) -> List[str]:
+        """
+        Returns params to be skipped when saving metadata.
+        """
+        return []
+
+    def _get_metadata_to_save(self) -> Dict[str, Any]:
+        """
+        Extract metadata of Estimator / Transformer / Model / Evaluator instance.
+        """
+        extra_metadata = self._get_extra_metadata()
+        skipped_params = self._get_skip_saving_params()
+
+        uid = self.uid
+        cls = self.__module__ + "." + self.__class__.__name__
+
+        # User-supplied param values
+        params = self._paramMap
+        json_params = {}
+        skipped_params = skipped_params or []
+        for p in params:
+            if p.name not in skipped_params:
+                json_params[p.name] = params[p]
+
+        # Default param values
+        json_default_params = {}
+        for p in self._defaultParamMap:
+            json_default_params[p.name] = self._defaultParamMap[p]
+
+        metadata = {
+            "class": cls,
+            "timestamp": int(round(time.time() * 1000)),
+            "sparkVersion": pyspark_version,
+            "uid": uid,
+            "paramMap": json_params,
+            "defaultParamMap": json_default_params,
+            "type": "spark_connect",
+        }
+        if extra_metadata is not None:
+            assert isinstance(extra_metadata, dict)
+            metadata["extra"] = extra_metadata
+
+        return metadata
+
     def _load_extra_metadata(self, metadata: Dict[str, Any]) -> None:
         """
         Load extra metadata attribute from metadata json object.
         """
         pass
 
+    def _save_to_local(self, path: str) -> None:
+        metadata = self._get_metadata_to_save()
+        if isinstance(self, CoreModelReadWrite):
+            core_model_path = self._get_core_model_filename()
+            self._save_core_model(os.path.join(path, core_model_path))
+            metadata["core_model_path"] = core_model_path
+
+        with open(os.path.join(path, _META_DATA_FILE_NAME), "w") as fp:
+            json.dump(metadata, fp)
+
     def saveToLocal(self, path: str, *, overwrite: bool = False) -> None:
         """
-        Save model to provided local path.
-        """
-        metadata = _get_metadata_to_save(self, extra_metadata=self._get_extra_metadata())
+        Save Estimator / Transformer / Model / Evaluator to provided local path.
 
+        .. versionadded:: 3.5.0
+        """
         if os.path.exists(path):
             if overwrite:
                 if os.path.isdir(path):
@@ -137,19 +155,15 @@ class ParamsReadWrite:
                 raise ValueError(f"The path {path} already exists.")
 
         os.makedirs(path)
-        with open(os.path.join(path, _META_DATA_FILE_NAME), "w") as fp:
-            json.dump(metadata, fp)
+        self._save_to_local(path)
 
     @classmethod
-    def loadFromLocal(cls, path: str) -> Any:
-        """
-        Load model from provided local path.
-        """
-        with open(os.path.join(path, _META_DATA_FILE_NAME), "r") as fp:
-            metadata = json.load(fp)
-
+    def _load_from_metadata(cls, metadata: Dict[str, Any]) -> "Params":
         if "type" not in metadata or metadata["type"] != "spark_connect":
-            raise RuntimeError("The model is not saved by spark ML under spark connect mode.")
+            raise RuntimeError(
+                "The saved data is not saved by ML algorithm implemented in 'pyspark.ml.connect' "
+                "module."
+            )
 
         class_name = metadata["class"]
         instance = _get_class(class_name)()
@@ -169,7 +183,34 @@ class ParamsReadWrite:
             instance._load_extra_metadata(metadata["extra"])
         return instance
 
+    @classmethod
+    def _load_from_local(cls, path: str) -> "Params":
+        with open(os.path.join(path, _META_DATA_FILE_NAME), "r") as fp:
+            metadata = json.load(fp)
+
+        instance = cls._load_from_metadata(metadata)
+
+        if isinstance(instance, CoreModelReadWrite):
+            core_model_path = metadata["core_model_path"]
+            instance._load_core_model(os.path.join(path, core_model_path))
+
+        return instance
+
+    @classmethod
+    def loadFromLocal(cls, path: str) -> "Params":
+        """
+        Load Estimator / Transformer / Model / Evaluator from provided local path.
+
+        .. versionadded:: 3.5.0
+        """
+        return cls._load_from_local(path)
+
     def save(self, path: str, *, overwrite: bool = False) -> None:
+        """
+        Save Estimator / Transformer / Model / Evaluator to provided cloud storage path.
+
+        .. versionadded:: 3.5.0
+        """
         session = _get_active_session(is_remote())
         path_exist = True
         try:
@@ -186,13 +227,18 @@ class ParamsReadWrite:
 
         tmp_local_dir = tempfile.mkdtemp(prefix="pyspark_ml_model_")
         try:
-            self.saveToLocal(tmp_local_dir, overwrite=True)
+            self._save_to_local(tmp_local_dir)
             _copy_dir_from_local_to_fs(tmp_local_dir, path)
         finally:
             shutil.rmtree(tmp_local_dir, ignore_errors=True)
 
     @classmethod
-    def load(cls, path: str) -> Any:
+    def load(cls, path: str) -> "Params":
+        """
+        Load Estimator / Transformer / Model / Evaluator from provided cloud storage path.
+
+        .. versionadded:: 3.5.0
+        """
         session = _get_active_session(is_remote())
 
         tmp_local_dir = tempfile.mkdtemp(prefix="pyspark_ml_model_")
@@ -205,12 +251,12 @@ class ParamsReadWrite:
                 with open(os.path.join(tmp_local_dir, file_name), "wb") as f:
                     f.write(file_content)
 
-            return cls.loadFromLocal(tmp_local_dir)
+            return cls._load_from_local(tmp_local_dir)
         finally:
             shutil.rmtree(tmp_local_dir, ignore_errors=True)
 
 
-class ModelReadWrite(ParamsReadWrite):
+class CoreModelReadWrite:
     def _get_core_model_filename(self) -> str:
         """
         Returns the name of the file for saving the core model.
@@ -231,12 +277,29 @@ class ModelReadWrite(ParamsReadWrite):
         """
         raise NotImplementedError()
 
-    def saveToLocal(self, path: str, *, overwrite: bool = False) -> None:
-        super(ModelReadWrite, self).saveToLocal(path, overwrite=overwrite)
-        self._save_core_model(os.path.join(path, self._get_core_model_filename()))
+
+class MetaAlgorithmReadWrite(ParamsReadWrite):
+    """
+    Meta-algorithm such as pipeline and cross validator must implement this interface.
+    """
+
+    def _save_meta_algorithm(self, root_path: str, node_path: List[str]) -> Dict[str, Any]:
+        raise NotImplementedError()
+
+    def _load_meta_algorithm(self, root_path: str, node_metadata: Dict[str, Any]) -> None:
+        raise NotImplementedError()
+
+    def _save_to_local(self, path: str) -> None:
+        metadata = self._save_meta_algorithm(path, [])
+        with open(os.path.join(path, _META_DATA_FILE_NAME), "w") as fp:
+            json.dump(metadata, fp)
 
     @classmethod
-    def loadFromLocal(cls, path: str) -> Any:
-        instance = super(ModelReadWrite, cls).loadFromLocal(path)
-        instance._load_core_model(os.path.join(path, instance._get_core_model_filename()))
+    def _load_from_local(cls, path: str) -> Any:
+        with open(os.path.join(path, _META_DATA_FILE_NAME), "r") as fp:
+            metadata = json.load(fp)
+
+        instance = cls._load_from_metadata(metadata)
+        instance._load_meta_algorithm(path, metadata)  # type: ignore[attr-defined]
+
         return instance
diff --git a/python/pyspark/mlv2/pipeline.py b/python/pyspark/mlv2/pipeline.py
new file mode 100644
index 00000000000..81e0651f178
--- /dev/null
+++ b/python/pyspark/mlv2/pipeline.py
@@ -0,0 +1,241 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+
+import pandas as pd
+from typing import Any, Dict, List, Optional, Union, cast, TYPE_CHECKING
+
+from pyspark import keyword_only, since
+from pyspark.mlv2.base import Estimator, Model, Transformer
+from pyspark.mlv2.io_utils import ParamsReadWrite, MetaAlgorithmReadWrite, CoreModelReadWrite
+from pyspark.ml.param import Param, Params
+from pyspark.ml.common import inherit_doc
+from pyspark.sql.dataframe import DataFrame
+
+
+if TYPE_CHECKING:
+    from pyspark.ml._typing import ParamMap
+
+
+class _PipelineReadWrite(MetaAlgorithmReadWrite):
+    def _get_skip_saving_params(self) -> List[str]:
+        """
+        Returns params to be skipped when saving metadata.
+        """
+        return ["stages"]
+
+    def _save_meta_algorithm(self, root_path: str, node_path: List[str]) -> Dict[str, Any]:
+        metadata = self._get_metadata_to_save()
+        metadata["stages"] = []
+
+        if isinstance(self, Pipeline):
+            stages = self.getStages()
+        elif isinstance(self, PipelineModel):
+            stages = self.stages
+        else:
+            raise ValueError()
+
+        for stage_index, stage in enumerate(stages):
+            stage_name = f"pipeline_stage_{stage_index}"
+            node_path.append(stage_name)
+            if isinstance(stage, MetaAlgorithmReadWrite):
+                stage_metadata = stage._save_meta_algorithm(root_path, node_path)
+            else:
+                stage_metadata = stage._get_metadata_to_save()  # type: ignore[attr-defined]
+                if isinstance(stage, CoreModelReadWrite):
+                    core_model_path = ".".join(node_path + [stage._get_core_model_filename()])
+                    stage._save_core_model(os.path.join(root_path, core_model_path))
+                    stage_metadata["core_model_path"] = core_model_path
+
+            metadata["stages"].append(stage_metadata)
+            node_path.pop()
+        return metadata
+
+    def _load_meta_algorithm(self, root_path: str, node_metadata: Dict[str, Any]) -> None:
+        stages = []
+        for stage_meta in node_metadata["stages"]:
+            stage = ParamsReadWrite._load_from_metadata(stage_meta)
+
+            if isinstance(stage, MetaAlgorithmReadWrite):
+                stage._load_meta_algorithm(root_path, stage_meta)
+
+            if isinstance(stage, CoreModelReadWrite):
+                core_model_path = stage_meta["core_model_path"]
+                stage._load_core_model(os.path.join(root_path, core_model_path))
+
+            stages.append(stage)
+
+        if isinstance(self, Pipeline):
+            self.setStages(stages)
+        elif isinstance(self, PipelineModel):
+            self.stages = stages
+        else:
+            raise ValueError()
+
+
+@inherit_doc
+class Pipeline(Estimator["PipelineModel"], _PipelineReadWrite):
+    """
+    A simple pipeline, which acts as an estimator. A Pipeline consists
+    of a sequence of stages, each of which is either an
+    :py:class:`Estimator` or a :py:class:`Transformer`. When
+    :py:meth:`Pipeline.fit` is called, the stages are executed in
+    order. If a stage is an :py:class:`Estimator`, its
+    :py:meth:`Estimator.fit` method will be called on the input
+    dataset to fit a model. Then the model, which is a transformer,
+    will be used to transform the dataset as the input to the next
+    stage. If a stage is a :py:class:`Transformer`, its
+    :py:meth:`Transformer.transform` method will be called to produce
+    the dataset for the next stage. The fitted model from a
+    :py:class:`Pipeline` is a :py:class:`PipelineModel`, which
+    consists of fitted models and transformers, corresponding to the
+    pipeline stages. If stages is an empty list, the pipeline acts as an
+    identity transformer.
+
+    .. versionadded:: 3.5.0
+    """
+
+    stages: Param[List[Params]] = Param(
+        Params._dummy(), "stages", "a list of pipeline stages"
+    )  # type: ignore[assignment]
+
+    _input_kwargs: Dict[str, Any]
+
+    @keyword_only
+    def __init__(self, *, stages: Optional[List[Params]] = None):
+        """
+        __init__(self, \\*, stages=None)
+        """
+        super(Pipeline, self).__init__()
+        kwargs = self._input_kwargs
+        self.setParams(**kwargs)
+
+    def setStages(self, value: List[Params]) -> "Pipeline":
+        """
+        Set pipeline stages.
+
+        .. versionadded:: 3.5.0
+
+        Parameters
+        ----------
+        value : list
+            of :py:class:`pyspark.mlv2.Transformer`
+            or :py:class:`pyspark.mlv2.Estimator`
+
+        Returns
+        -------
+        :py:class:`Pipeline`
+            the pipeline instance
+        """
+        return self._set(stages=value)
+
+    @since("3.5.0")
+    def getStages(self) -> List[Params]:
+        """
+        Get pipeline stages.
+        """
+        return self.getOrDefault(self.stages)
+
+    @keyword_only
+    @since("3.5.0")
+    def setParams(self, *, stages: Optional[List[Params]] = None) -> "Pipeline":
+        """
+        setParams(self, \\*, stages=None)
+        Sets params for Pipeline.
+        """
+        kwargs = self._input_kwargs
+        return self._set(**kwargs)
+
+    def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "PipelineModel":
+        stages = self.getStages()
+        for stage in stages:
+            if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
+                raise TypeError("Cannot recognize a pipeline stage of type %s." % type(stage))
+        indexOfLastEstimator = -1
+        for i, stage in enumerate(stages):
+            if isinstance(stage, Estimator):
+                indexOfLastEstimator = i
+        transformers: List[Transformer] = []
+        for i, stage in enumerate(stages):
+            if i <= indexOfLastEstimator:
+                if isinstance(stage, Transformer):
+                    transformers.append(stage)
+                    dataset = stage.transform(dataset)
+                else:  # must be an Estimator
+                    model = stage.fit(dataset)  # type: ignore[attr-defined]
+                    transformers.append(model)
+                    if i < indexOfLastEstimator:
+                        dataset = model.transform(dataset)
+            else:
+                transformers.append(cast(Transformer, stage))
+        pipeline_model = PipelineModel(transformers)  # type: ignore[arg-type]
+        pipeline_model._resetUid(self.uid)
+        return pipeline_model
+
+    def copy(self, extra: Optional["ParamMap"] = None) -> "Pipeline":
+        """
+        Creates a copy of this instance.
+
+        .. versionadded:: 3.5.0
+
+        Parameters
+        ----------
+        extra : dict, optional
+            extra parameters
+
+        Returns
+        -------
+        :py:class:`Pipeline`
+            new instance
+        """
+        if extra is None:
+            extra = dict()
+        that = Params.copy(self, extra)
+        stages = [stage.copy(extra) for stage in that.getStages()]
+        return that.setStages(stages)
+
+
+@inherit_doc
+class PipelineModel(Model, _PipelineReadWrite):
+    """
+    Represents a compiled pipeline with transformers and fitted models.
+
+    .. versionadded:: 3.5.0
+    """
+
+    def __init__(self, stages: Optional[List[Params]] = None):
+        super(PipelineModel, self).__init__()
+        self.stages = stages  # type: ignore[assignment]
+
+    def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
+        for t in self.stages:
+            dataset = t.transform(dataset)  # type: ignore[attr-defined]
+        return dataset
+
+    def copy(self, extra: Optional["ParamMap"] = None) -> "PipelineModel":
+        """
+        Creates a copy of this instance.
+
+        .. versionadded:: 3.5.0
+
+        :param extra: extra parameters
+        :returns: new instance
+        """
+        if extra is None:
+            extra = dict()
+        stages = [stage.copy(extra) for stage in self.stages]
+        return PipelineModel(stages)
diff --git a/python/pyspark/mlv2/tests/test_pipeline.py b/python/pyspark/mlv2/tests/test_pipeline.py
new file mode 100644
index 00000000000..cec421b7ee0
--- /dev/null
+++ b/python/pyspark/mlv2/tests/test_pipeline.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import tempfile
+import unittest
+import numpy as np
+from pyspark.mlv2.feature import StandardScaler
+from pyspark.mlv2.classification import LogisticRegression as LORV2
+from pyspark.mlv2.pipeline import Pipeline
+from pyspark.sql import SparkSession
+
+
+have_torch = True
+try:
+    import torch  # noqa: F401
+except ImportError:
+    have_torch = False
+
+
+class PipelineTestsMixin:
+    @staticmethod
+    def _check_result(result_dataframe, expected_predictions, expected_probabilities=None):
+        np.testing.assert_array_equal(list(result_dataframe.prediction), expected_predictions)
+        if "probability" in result_dataframe.columns:
+            np.testing.assert_allclose(
+                list(result_dataframe.probability),
+                expected_probabilities,
+                rtol=1e-2,
+            )
+
+    def test_pipeline(self):
+        train_dataset = self.spark.createDataFrame(
+            [
+                (1.0, [0.0, 5.0]),
+                (0.0, [1.0, 2.0]),
+                (1.0, [2.0, 1.0]),
+                (0.0, [3.0, 3.0]),
+            ]
+            * 100,
+            ["label", "features"],
+        )
+        eval_dataset = self.spark.createDataFrame(
+            [
+                ([0.0, 2.0],),
+                ([3.5, 3.0],),
+            ],
+            ["features"],
+        )
+        scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
+        lorv2 = LORV2(
+            maxIter=200, numTrainWorkers=2, learningRate=0.001, featuresCol="scaled_features"
+        )
+
+        pipeline = Pipeline(stages=[scaler, lorv2])
+        model = pipeline.fit(train_dataset)
+        assert model.uid == pipeline.uid
+
+        expected_predictions = [1, 0]
+        expected_probabilities = [
+            [0.117658, 0.882342],
+            [0.878738, 0.121262],
+        ]
+
+        result = model.transform(eval_dataset).toPandas()
+        self._check_result(result, expected_predictions, expected_probabilities)
+        local_transform_result = model.transform(eval_dataset.toPandas())
+        self._check_result(local_transform_result, expected_predictions, expected_probabilities)
+
+        pipeline2 = Pipeline(stages=[pipeline])
+        model2 = pipeline2.fit(train_dataset)
+        result2 = model2.transform(eval_dataset).toPandas()
+        self._check_result(result2, expected_predictions, expected_probabilities)
+        local_transform_result2 = model2.transform(eval_dataset.toPandas())
+        self._check_result(local_transform_result2, expected_predictions, expected_probabilities)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            pipeline_local_path = os.path.join(tmp_dir, "pipeline")
+            pipeline.saveToLocal(pipeline_local_path)
+            loaded_pipeline = Pipeline.loadFromLocal(pipeline_local_path)
+
+            assert pipeline.uid == loaded_pipeline.uid
+            assert loaded_pipeline.getStages()[1].getMaxIter() == 200
+
+            pipeline_model_local_path = os.path.join(tmp_dir, "pipeline_model")
+            model.saveToLocal(pipeline_model_local_path)
+            loaded_model = Pipeline.loadFromLocal(pipeline_model_local_path)
+
+            assert model.uid == loaded_model.uid
+            assert loaded_model.stages[1].getMaxIter() == 200
+
+            loaded_model_transform_result = loaded_model.transform(eval_dataset).toPandas()
+            self._check_result(
+                loaded_model_transform_result, expected_predictions, expected_probabilities
+            )
+
+            pipeline2_local_path = os.path.join(tmp_dir, "pipeline2")
+            pipeline2.saveToLocal(pipeline2_local_path)
+            loaded_pipeline2 = Pipeline.loadFromLocal(pipeline2_local_path)
+
+            assert pipeline2.uid == loaded_pipeline2.uid
+            assert loaded_pipeline2.getStages()[0].getStages()[1].getMaxIter() == 200
+
+            pipeline2_model_local_path = os.path.join(tmp_dir, "pipeline2_model")
+            model2.saveToLocal(pipeline2_model_local_path)
+            loaded_model2 = Pipeline.loadFromLocal(pipeline2_model_local_path)
+
+            assert model2.uid == loaded_model2.uid
+            assert loaded_model2.stages[0].stages[1].getMaxIter() == 200
+
+            loaded_model2_transform_result = loaded_model2.transform(eval_dataset).toPandas()
+            self._check_result(
+                loaded_model2_transform_result, expected_predictions, expected_probabilities
+            )
+
+    @staticmethod
+    def test_pipeline_copy():
+        scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
+        lorv2 = LORV2(
+            maxIter=200, numTrainWorkers=2, learningRate=0.001, featuresCol="scaled_features"
+        )
+
+        pipeline = Pipeline(stages=[scaler, lorv2])
+
+        copied_pipeline = pipeline.copy(
+            {scaler.inputCol: "f1", lorv2.maxIter: 10, lorv2.numTrainWorkers: 1}
+        )
+
+        stages = copied_pipeline.getStages()
+
+        assert stages[0].getInputCol() == "f1"
+        assert stages[1].getOrDefault(stages[1].maxIter) == 10
+        assert stages[1].getOrDefault(stages[1].numTrainWorkers) == 1
+        assert stages[1].getOrDefault(stages[1].featuresCol) == "scaled_features"
+
+        pipeline2 = Pipeline(stages=[pipeline])
+        copied_pipeline2 = pipeline2.copy(
+            {scaler.inputCol: "f2", lorv2.maxIter: 20, lorv2.numTrainWorkers: 20}
+        )
+
+        stages = copied_pipeline2.getStages()[0].getStages()
+
+        assert stages[0].getInputCol() == "f2"
+        assert stages[1].getOrDefault(stages[1].maxIter) == 20
+        assert stages[1].getOrDefault(stages[1].numTrainWorkers) == 20
+        assert stages[1].getOrDefault(stages[1].featuresCol) == "scaled_features"
+
+        # test original stage instance params are not modified after pipeline copying.
+        assert scaler.getInputCol() == "features"
+        assert lorv2.getOrDefault(lorv2.maxIter) == 200
+
+
+class PipelineTests(PipelineTestsMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.master("local[2]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+
+if __name__ == "__main__":
+    from pyspark.mlv2.tests.test_pipeline import *  # noqa: F401,F403
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org