You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@spark.apache.org by we...@apache.org on 2023/06/14 04:42:41 UTC

[spark] branch master updated: [SPARK-43981][PYTHON][ML] Basic saving / loading implementation for ML on spark connect

This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a5d3bea04eb [SPARK-43981][PYTHON][ML] Basic saving / loading implementation for ML on spark connect
a5d3bea04eb is described below

commit a5d3bea04eb8430fd747905633b8164e21c95190
Author: Weichen Xu <we...@databricks.com>
AuthorDate: Wed Jun 14 12:42:22 2023 +0800

    [SPARK-43981][PYTHON][ML] Basic saving / loading implementation for ML on spark connect
    
    ### What changes were proposed in this pull request?
    
    * Base classes and helper functions for saving/loading estimators, transformers, evaluators, and models.
    * Add a saving/loading implementation for feature transformers.
    * Add a saving/loading implementation for the logistic regression estimator.
    
    Design goals:
    
    * The model format is decoupled from Spark, i.e. model inference can run without a Spark service.
    * Models can be saved to either the local file system or a cloud storage file system.
    
    ### Why are the changes needed?
    
    We need to support saving and loading estimators, transformers, evaluators, and models.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    
    ### How was this patch tested?
    
    Unit tests.
    
    Closes #41478 from WeichenXu123/mlv2-read-write.
    
    Authored-by: Weichen Xu <we...@databricks.com>
    Signed-off-by: Weichen Xu <we...@databricks.com>
---
 .../scala/org/apache/spark/ml/python/MLUtil.scala  |  43 ++++
 python/pyspark/ml/torch/distributor.py             |  14 +-
 python/pyspark/ml/util.py                          |  13 ++
 python/pyspark/mlv2/classification.py              | 118 ++++++----
 python/pyspark/mlv2/evaluation.py                  |   3 +-
 python/pyspark/mlv2/feature.py                     | 133 ++++++++---
 python/pyspark/mlv2/io_utils.py                    | 242 +++++++++++++++++++++
 python/pyspark/mlv2/summarizer.py                  |   4 +-
 .../tests/connect/test_parity_classification.py    |   6 +-
 python/pyspark/mlv2/tests/test_classification.py   |  84 ++++++-
 python/pyspark/mlv2/tests/test_feature.py          |  58 ++++-
 11 files changed, 634 insertions(+), 84 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/python/MLUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/python/MLUtil.scala
new file mode 100644
index 00000000000..5e2b8943ed8
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/python/MLUtil.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.python
+
+import java.nio.file.Paths
+
+import org.apache.hadoop.fs.{Path => FSPath}
+
+import org.apache.spark.sql.SparkSession
+
+
+object MLUtil {
+
+  def copyFileFromLocalToFs(localPath: String, destPath: String): Unit = {
+    val sparkContext = SparkSession.getActiveSession.get.sparkContext
+
+    val hadoopConf = sparkContext.hadoopConfiguration
+    assert(
+      Paths.get(destPath).isAbsolute,
+      "Destination path must be an absolute path on cloud storage."
+    )
+    val destFSPath = new FSPath(destPath)
+    val fs = destFSPath.getFileSystem(hadoopConf)
+
+    fs.copyFromLocalFile(false, true, new FSPath(localPath.toString), destFSPath)
+  }
+
+}
diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index be49dc147c0..2ed70854cc6 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -48,19 +48,7 @@ from pyspark.ml.torch.log_communication import (  # type: ignore
     LogStreamingClient,
     LogStreamingServer,
 )
-
-
-def _get_active_session(is_remote: bool) -> SparkSession:
-    if not is_remote:
-        spark = SparkSession.getActiveSession()
-    else:
-        import pyspark.sql.connect.session
-
-        spark = pyspark.sql.connect.session._active_spark_session  # type: ignore[assignment]
-
-    if spark is None:
-        raise RuntimeError("An active SparkSession is required for the distributor.")
-    return spark
+from pyspark.ml.util import _get_active_session
 
 
 def _get_resources(session: SparkSession) -> Dict[str, ResourceInformation]:
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 5d1f89cbc13..74ce8162d18 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -747,3 +747,16 @@ def try_remote_functions(f: FuncT) -> FuncT:
             return f(*args, **kwargs)
 
     return cast(FuncT, wrapped)
+
+
+def _get_active_session(is_remote: bool) -> SparkSession:
+    if not is_remote:
+        spark = SparkSession.getActiveSession()
+    else:
+        import pyspark.sql.connect.session
+
+        spark = pyspark.sql.connect.session._active_spark_session  # type: ignore[assignment]
+
+    if spark is None:
+        raise RuntimeError("An active SparkSession is required for the distributor.")
+    return spark
diff --git a/python/pyspark/mlv2/classification.py b/python/pyspark/mlv2/classification.py
index a72fe89c01b..0fcded0d769 100644
--- a/python/pyspark/mlv2/classification.py
+++ b/python/pyspark/mlv2/classification.py
@@ -15,11 +15,12 @@
 # limitations under the License.
 #
 
+from pyspark import keyword_only
 from pyspark.mlv2.base import _PredictorParams
 
 from pyspark.ml.param.shared import HasProbabilityCol
 
-from typing import Any, Union, List, Tuple, Callable
+from typing import Any, Dict, Union, List, Tuple, Callable, Optional
 import numpy as np
 import pandas as pd
 import math
@@ -39,6 +40,7 @@ from pyspark.ml.param.shared import (
     HasMomentum,
 )
 from pyspark.mlv2.base import Predictor, PredictionModel
+from pyspark.mlv2.io_utils import ParamsReadWrite, ModelReadWrite
 from pyspark.sql.functions import lit, count, countDistinct
 
 import torch
@@ -64,18 +66,16 @@ class _LogisticRegressionParams(
     .. versionadded:: 3.0.0
     """
 
-    pass
-
-
-class _LinearNet(torch_nn.Module):
-    def __init__(self, num_features: int, num_classes: int, bias: bool) -> None:
-        super(_LinearNet, self).__init__()
-        output_dim = num_classes
-        self.fc = torch_nn.Linear(num_features, output_dim, bias=bias, dtype=torch.float32)
-
-    def forward(self, x: Any) -> Any:
-        output = self.fc(x)
-        return output
+    def __init__(self, *args: Any):
+        super(_LogisticRegressionParams, self).__init__(*args)
+        self._setDefault(
+            maxIter=100,
+            tol=1e-6,
+            batchSize=32,
+            learningRate=0.001,
+            momentum=0.9,
+            seed=0,
+        )
 
 
 def _train_logistic_regression_model_worker_fn(
@@ -101,9 +101,10 @@ def _train_logistic_regression_model_worker_fn(
     # TODO: support L1 / L2 regularization
     torch.distributed.init_process_group("gloo")
 
-    ddp_model = DDP(
-        _LinearNet(num_features=num_features, num_classes=num_classes, bias=fit_intercept)
+    linear_model = torch_nn.Linear(
+        num_features, num_classes, bias=fit_intercept, dtype=torch.float32
     )
+    ddp_model = DDP(linear_model)
 
     loss_fn = torch_nn.CrossEntropyLoss()
 
@@ -144,13 +145,18 @@ def _train_logistic_regression_model_worker_fn(
 
 
 @inherit_doc
-class LogisticRegression(Predictor["LogisticRegressionModel"], _LogisticRegressionParams):
+class LogisticRegression(
+    Predictor["LogisticRegressionModel"], _LogisticRegressionParams, ParamsReadWrite
+):
     """
     Logistic regression estimator.
 
     .. versionadded:: 3.5.0
     """
 
+    _input_kwargs: Dict[str, Any]
+
+    @keyword_only
     def __init__(
         self,
         *,
@@ -166,20 +172,26 @@ class LogisticRegression(Predictor["LogisticRegressionModel"], _LogisticRegressi
         momentum: float = 0.9,
         seed: int = 0,
     ):
-        super(_LogisticRegressionParams, self).__init__()
-        self._set(
-            featuresCol=featuresCol,
-            labelCol=labelCol,
-            predictionCol=predictionCol,
-            probabilityCol=probabilityCol,
-            maxIter=maxIter,
-            tol=tol,
-            numTrainWorkers=numTrainWorkers,
-            batchSize=batchSize,
-            learningRate=learningRate,
-            momentum=momentum,
-            seed=seed,
+        """
+        __init__(
+            self,
+            *,
+            featuresCol: str = "features",
+            labelCol: str = "label",
+            predictionCol: str = "prediction",
+            probabilityCol: str = "probability",
+            maxIter: int = 100,
+            tol: float = 1e-6,
+            numTrainWorkers: int = 1,
+            batchSize: int = 32,
+            learningRate: float = 0.001,
+            momentum: float = 0.9,
+            seed: int = 0,
         )
+        """
+        super(LogisticRegression, self).__init__()
+        kwargs = self._input_kwargs
+        self._set(**kwargs)
 
     def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "LogisticRegressionModel":
         if isinstance(dataset, pd.DataFrame):
@@ -228,8 +240,8 @@ class LogisticRegression(Predictor["LogisticRegressionModel"], _LogisticRegressi
 
         dataset.unpersist()
 
-        torch_model = _LinearNet(
-            num_features=num_features, num_classes=num_classes, bias=self.getFitIntercept()
+        torch_model = torch_nn.Linear(
+            num_features, num_classes, bias=self.getFitIntercept(), dtype=torch.float32
         )
         torch_model.load_state_dict(model_state_dict)
 
@@ -241,24 +253,31 @@ class LogisticRegression(Predictor["LogisticRegressionModel"], _LogisticRegressi
 
 
 @inherit_doc
-class LogisticRegressionModel(PredictionModel, _LogisticRegressionParams):
+class LogisticRegressionModel(PredictionModel, _LogisticRegressionParams, ModelReadWrite):
     """
     Model fitted by LogisticRegression.
 
     .. versionadded:: 3.5.0
     """
 
-    def __init__(self, torch_model: Any, num_features: int, num_classes: int):
+    def __init__(
+        self,
+        torch_model: Any = None,
+        num_features: Optional[int] = None,
+        num_classes: Optional[int] = None,
+    ):
         super().__init__()
         self.torch_model = torch_model
         self.num_features = num_features
         self.num_classes = num_classes
 
+    @property
     def numFeatures(self) -> int:
-        return self.num_features
+        return self.num_features  # type: ignore[return-value]
 
+    @property
     def numClasses(self) -> int:
-        return self.num_classes
+        return self.num_classes  # type: ignore[return-value]
 
     def _input_columns(self) -> List[str]:
         return [self.getOrDefault(self.featuresCol)]
@@ -277,8 +296,11 @@ class LogisticRegressionModel(PredictionModel, _LogisticRegressionParams):
         fit_intercept = self.getFitIntercept()
 
         def transform_fn(input_series: Any) -> Any:
-            torch_model = _LinearNet(
-                num_features=num_features, num_classes=num_classes, bias=fit_intercept
+            torch_model = torch_nn.Linear(
+                num_features,  # type: ignore[arg-type]
+                num_classes,  # type: ignore[arg-type]
+                bias=fit_intercept,
+                dtype=torch.float32,
             )
             # TODO: Use spark broadast for `model_state_dict`,
             #  it can improve performance when model is large.
@@ -304,3 +326,25 @@ class LogisticRegressionModel(PredictionModel, _LogisticRegressionParams):
                 return pd.Series(data=list(predictions), index=input_series.index.copy())
 
         return transform_fn
+
+    def _get_core_model_filename(self) -> str:
+        return self.__class__.__name__ + ".torch"
+
+    def _save_core_model(self, path: str) -> None:
+        torch.save(self.torch_model, path)
+
+    def _load_core_model(self, path: str) -> None:
+        self.torch_model = torch.load(path)
+
+    def _get_extra_metadata(self) -> Dict[str, Any]:
+        return {
+            "num_features": self.num_features,
+            "num_classes": self.num_classes,
+        }
+
+    def _load_extra_metadata(self, extra_metadata: Dict[str, Any]) -> None:
+        """
+        Load extra metadata attribute from extra metadata json object.
+        """
+        self.num_features = extra_metadata["num_features"]
+        self.num_classes = extra_metadata["num_classes"]
diff --git a/python/pyspark/mlv2/evaluation.py b/python/pyspark/mlv2/evaluation.py
index 720179ed9b4..671819d29e8 100644
--- a/python/pyspark/mlv2/evaluation.py
+++ b/python/pyspark/mlv2/evaluation.py
@@ -21,6 +21,7 @@ from typing import Any, Union
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol
 from pyspark.mlv2.base import Evaluator
+from pyspark.mlv2.io_utils import ParamsReadWrite
 from pyspark.mlv2.util import aggregate_dataframe
 from pyspark.sql import DataFrame
 
@@ -28,7 +29,7 @@ import torch
 import torcheval.metrics as torchmetrics
 
 
-class RegressionEvaluator(Evaluator, HasLabelCol, HasPredictionCol):
+class RegressionEvaluator(Evaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite):
     """
     Evaluator for Regression, which expects input columns prediction and label.
     Supported metrics are 'mse' and 'r2'.
diff --git a/python/pyspark/mlv2/feature.py b/python/pyspark/mlv2/feature.py
index 43ecc0e17ea..57c6213d2bb 100644
--- a/python/pyspark/mlv2/feature.py
+++ b/python/pyspark/mlv2/feature.py
@@ -17,44 +17,60 @@
 
 import numpy as np
 import pandas as pd
-from typing import Any, Union, List, Tuple, Callable
+import pickle
+from typing import Any, Union, List, Tuple, Callable, Dict, Optional
 
+from pyspark import keyword_only
 from pyspark.sql import DataFrame
+from pyspark.ml.param.shared import HasInputCol, HasOutputCol
 from pyspark.mlv2.base import Estimator, Model
+from pyspark.mlv2.io_utils import ParamsReadWrite, ModelReadWrite
 from pyspark.mlv2.summarizer import summarize_dataframe
-from pyspark.ml.param.shared import HasInputCol, HasOutputCol
 
 
-class MaxAbsScaler(Estimator, HasInputCol, HasOutputCol):
+class MaxAbsScaler(Estimator, HasInputCol, HasOutputCol, ParamsReadWrite):
     """
     Rescale each feature individually to range [-1, 1] by dividing through the largest maximum
     absolute value in each feature. It does not shift/center the data, and thus does not destroy
     any sparsity.
     """
 
-    def __init__(self, inputCol: str, outputCol: str) -> None:
+    _input_kwargs: Dict[str, Any]
+
+    @keyword_only
+    def __init__(self, *, inputCol: Optional[str] = None, outputCol: Optional[str] = None) -> None:
+        """
+        __init__(self, \\*, inputCol=None, outputCol=None)
+        """
         super().__init__()
-        self.set(self.inputCol, inputCol)
-        self.set(self.outputCol, outputCol)
+        kwargs = self._input_kwargs
+        self._set(**kwargs)
 
     def _fit(self, dataset: Union["pd.DataFrame", "DataFrame"]) -> "MaxAbsScalerModel":
         input_col = self.getInputCol()
 
-        min_max_res = summarize_dataframe(dataset, input_col, ["min", "max"])
-        min_values = min_max_res["min"]
-        max_values = min_max_res["max"]
+        stat_res = summarize_dataframe(dataset, input_col, ["min", "max", "count"])
+        min_values = stat_res["min"]
+        max_values = stat_res["max"]
+        n_samples_seen = stat_res["count"]
 
         max_abs_values = np.maximum(np.abs(min_values), np.abs(max_values))
 
-        model = MaxAbsScalerModel(max_abs_values)
+        model = MaxAbsScalerModel(max_abs_values, n_samples_seen)
         model._resetUid(self.uid)
         return self._copyValues(model)
 
 
-class MaxAbsScalerModel(Model, HasInputCol, HasOutputCol):
-    def __init__(self, max_abs_values: "np.ndarray") -> None:
+class MaxAbsScalerModel(Model, HasInputCol, HasOutputCol, ModelReadWrite):
+    def __init__(
+        self, max_abs_values: Optional["np.ndarray"] = None, n_samples_seen: Optional[int] = None
+    ) -> None:
         super().__init__()
         self.max_abs_values = max_abs_values
+        if max_abs_values is not None:
+            # if scale value is zero, replace it with 1.0 (for preventing division by zero)
+            self.scale_values = np.where(max_abs_values == 0.0, 1.0, max_abs_values)
+        self.n_samples_seen = n_samples_seen
 
     def _input_columns(self) -> List[str]:
         return [self.getInputCol()]
@@ -63,46 +79,84 @@ class MaxAbsScalerModel(Model, HasInputCol, HasOutputCol):
         return [(self.getOutputCol(), "array<double>")]
 
     def _get_transform_fn(self) -> Callable[..., Any]:
-        max_abs_values = self.max_abs_values
-        max_abs_values_zero_cond = max_abs_values == 0.0
+        scale_values = self.scale_values
 
         def transform_fn(series: Any) -> Any:
             def map_value(x: "np.ndarray") -> "np.ndarray":
-                return np.where(max_abs_values_zero_cond, 0.0, x / max_abs_values)
+                return x / scale_values
 
             return series.apply(map_value)
 
         return transform_fn
 
+    def _get_core_model_filename(self) -> str:
+        return self.__class__.__name__ + ".sklearn.pkl"
+
+    def _save_core_model(self, path: str) -> None:
+        from sklearn.preprocessing import MaxAbsScaler as sk_MaxAbsScaler
+
+        sk_model = sk_MaxAbsScaler()
+        sk_model.scale_ = self.scale_values
+        sk_model.max_abs_ = self.max_abs_values
+        sk_model.n_features_in_ = len(self.max_abs_values)  # type: ignore[arg-type]
+        sk_model.n_samples_seen_ = self.n_samples_seen
+
+        with open(path, "wb") as fp:
+            pickle.dump(sk_model, fp)
+
+    def _load_core_model(self, path: str) -> None:
+        with open(path, "rb") as fp:
+            sk_model = pickle.load(fp)
+
+        self.max_abs_values = sk_model.max_abs_
+        self.scale_values = sk_model.scale_
+        self.n_samples_seen = sk_model.n_samples_seen_
 
-class StandardScaler(Estimator, HasInputCol, HasOutputCol):
+
+class StandardScaler(Estimator, HasInputCol, HasOutputCol, ParamsReadWrite):
     """
     Standardizes features by removing the mean and scaling to unit variance using column summary
     statistics on the samples in the training set.
     """
 
-    def __init__(self, inputCol: str, outputCol: str) -> None:
+    _input_kwargs: Dict[str, Any]
+
+    @keyword_only
+    def __init__(self, inputCol: Optional[str] = None, outputCol: Optional[str] = None) -> None:
+        """
+        __init__(self, \\*, inputCol=None, outputCol=None)
+        """
         super().__init__()
-        self.set(self.inputCol, inputCol)
-        self.set(self.outputCol, outputCol)
+        kwargs = self._input_kwargs
+        self._set(**kwargs)
 
     def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "StandardScalerModel":
         input_col = self.getInputCol()
 
-        min_max_res = summarize_dataframe(dataset, input_col, ["mean", "std"])
-        mean_values = min_max_res["mean"]
-        std_values = min_max_res["std"]
+        stat_result = summarize_dataframe(dataset, input_col, ["mean", "std", "count"])
+        mean_values = stat_result["mean"]
+        std_values = stat_result["std"]
+        n_samples_seen = stat_result["count"]
 
-        model = StandardScalerModel(mean_values, std_values)
+        model = StandardScalerModel(mean_values, std_values, n_samples_seen)
         model._resetUid(self.uid)
         return self._copyValues(model)
 
 
-class StandardScalerModel(Model, HasInputCol, HasOutputCol):
-    def __init__(self, mean_values: "np.ndarray", std_values: "np.ndarray") -> None:
+class StandardScalerModel(Model, HasInputCol, HasOutputCol, ModelReadWrite):
+    def __init__(
+        self,
+        mean_values: Optional["np.ndarray"] = None,
+        std_values: Optional["np.ndarray"] = None,
+        n_samples_seen: Optional[int] = None,
+    ) -> None:
         super().__init__()
         self.mean_values = mean_values
         self.std_values = std_values
+        if std_values is not None:
+            # if scale value is zero, replace it with 1.0 (for preventing division by zero)
+            self.scale_values = np.where(std_values == 0.0, 1.0, std_values)
+        self.n_samples_seen = n_samples_seen
 
     def _input_columns(self) -> List[str]:
         return [self.getInputCol()]
@@ -112,12 +166,37 @@ class StandardScalerModel(Model, HasInputCol, HasOutputCol):
 
     def _get_transform_fn(self) -> Callable[..., Any]:
         mean_values = self.mean_values
-        std_values = self.std_values
+        scale_values = self.scale_values
 
         def transform_fn(series: Any) -> Any:
             def map_value(x: "np.ndarray") -> "np.ndarray":
-                return (x - mean_values) / std_values
+                return (x - mean_values) / scale_values
 
             return series.apply(map_value)
 
         return transform_fn
+
+    def _get_core_model_filename(self) -> str:
+        return self.__class__.__name__ + ".sklearn.pkl"
+
+    def _save_core_model(self, path: str) -> None:
+        from sklearn.preprocessing import StandardScaler as sk_StandardScaler
+
+        sk_model = sk_StandardScaler(with_mean=True, with_std=True)
+        sk_model.scale_ = self.scale_values
+        sk_model.var_ = self.std_values * self.std_values  # type: ignore[operator]
+        sk_model.mean_ = self.mean_values
+        sk_model.n_features_in_ = len(self.std_values)  # type: ignore[arg-type]
+        sk_model.n_samples_seen_ = self.n_samples_seen
+
+        with open(path, "wb") as fp:
+            pickle.dump(sk_model, fp)
+
+    def _load_core_model(self, path: str) -> None:
+        with open(path, "rb") as fp:
+            sk_model = pickle.load(fp)
+
+        self.std_values = np.sqrt(sk_model.var_)
+        self.scale_values = sk_model.scale_
+        self.mean_values = sk_model.mean_
+        self.n_samples_seen = sk_model.n_samples_seen_
diff --git a/python/pyspark/mlv2/io_utils.py b/python/pyspark/mlv2/io_utils.py
new file mode 100644
index 00000000000..8f7263206a7
--- /dev/null
+++ b/python/pyspark/mlv2/io_utils.py
@@ -0,0 +1,242 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import shutil
+import os
+import tempfile
+import time
+from urllib.parse import urlparse
+from typing import Any, Dict, Optional
+from pyspark.ml.util import _get_active_session
+from pyspark.sql.utils import is_remote
+
+
+from pyspark import __version__ as pyspark_version
+
+
+_META_DATA_FILE_NAME = "metadata.json"
+
+
+def _copy_file_from_local_to_fs(local_path: str, dest_path: str) -> None:
+    session = _get_active_session(is_remote())
+    if is_remote():
+        session.copyFromLocalToFs(local_path, dest_path)  # type: ignore[attr-defined]
+    else:
+        jvm = session.sparkContext._gateway.jvm  # type: ignore[union-attr]
+        jvm.org.apache.spark.ml.python.MLUtil.copyFileFromLocalToFs(local_path, dest_path)
+
+
+def _copy_dir_from_local_to_fs(local_path: str, dest_path: str) -> None:
+    """
+    Copy directory from local path to cloud storage path.
+    Limitation: Currently only one level directory is supported.
+    """
+    assert os.path.isdir(local_path)
+
+    file_list = os.listdir(local_path)
+    for file_name in file_list:
+        file_path = os.path.join(local_path, file_name)
+        dest_file_path = os.path.join(dest_path, file_name)
+        assert os.path.isfile(file_path)
+        _copy_file_from_local_to_fs(file_path, dest_file_path)
+
+
+def _get_metadata_to_save(
+    instance: Any,
+    extra_metadata: Optional[Dict[str, Any]] = None,
+) -> Dict[str, Any]:
+    """
+    Extract metadata of Estimator / Transformer / Model / Evaluator instance.
+    """
+    uid = instance.uid
+    cls = instance.__module__ + "." + instance.__class__.__name__
+
+    # User-supplied param values
+    params = instance._paramMap
+    json_params = {}
+    for p in params:
+        json_params[p.name] = params[p]
+
+    # Default param values
+    json_default_params = {}
+    for p in instance._defaultParamMap:
+        json_default_params[p.name] = instance._defaultParamMap[p]
+
+    metadata = {
+        "class": cls,
+        "timestamp": int(round(time.time() * 1000)),
+        "sparkVersion": pyspark_version,
+        "uid": uid,
+        "paramMap": json_params,
+        "defaultParamMap": json_default_params,
+        "type": "spark_connect",
+    }
+    if extra_metadata is not None:
+        assert isinstance(extra_metadata, dict)
+        metadata["extra"] = extra_metadata
+
+    return metadata
+
+
+def _get_class(clazz: str) -> Any:
+    """
+    Loads Python class from its name.
+    """
+    parts = clazz.split(".")
+    module = ".".join(parts[:-1])
+    m = __import__(module, fromlist=[parts[-1]])
+    return getattr(m, parts[-1])
+
+
+class ParamsReadWrite:
+    """
+    The base interface Estimator / Transformer / Model / Evaluator needs to inherit
+    for supporting saving and loading.
+    """
+
+    def _get_extra_metadata(self) -> Any:
+        """
+        Returns exta metadata of the instance
+        """
+        return None
+
+    def _load_extra_metadata(self, metadata: Dict[str, Any]) -> None:
+        """
+        Load extra metadata attribute from metadata json object.
+        """
+        pass
+
+    def saveToLocal(self, path: str, *, overwrite: bool = False) -> None:
+        """
+        Save model to provided local path.
+        """
+        metadata = _get_metadata_to_save(self, extra_metadata=self._get_extra_metadata())
+
+        if os.path.exists(path):
+            if overwrite:
+                if os.path.isdir(path):
+                    shutil.rmtree(path)
+                else:
+                    os.remove(path)
+            else:
+                raise ValueError(f"The path {path} already exists.")
+
+        os.makedirs(path)
+        with open(os.path.join(path, _META_DATA_FILE_NAME), "w") as fp:
+            json.dump(metadata, fp)
+
+    @classmethod
+    def loadFromLocal(cls, path: str) -> Any:
+        """
+        Load model from provided local path.
+        """
+        with open(os.path.join(path, _META_DATA_FILE_NAME), "r") as fp:
+            metadata = json.load(fp)
+
+        if "type" not in metadata or metadata["type"] != "spark_connect":
+            raise RuntimeError("The model is not saved by spark ML under spark connect mode.")
+
+        class_name = metadata["class"]
+        instance = _get_class(class_name)()
+        instance._resetUid(metadata["uid"])
+
+        # Set user-supplied param values
+        for paramName in metadata["paramMap"]:
+            param = instance.getParam(paramName)
+            paramValue = metadata["paramMap"][paramName]
+            instance.set(param, paramValue)
+
+        for paramName in metadata["defaultParamMap"]:
+            paramValue = metadata["defaultParamMap"][paramName]
+            instance._setDefault(**{paramName: paramValue})
+
+        if "extra" in metadata:
+            instance._load_extra_metadata(metadata["extra"])
+        return instance
+
+    def save(self, path: str, *, overwrite: bool = False) -> None:
+        session = _get_active_session(is_remote())
+        path_exist = True
+        try:
+            session.read.format("binaryFile").load(path).head()
+        except Exception as e:
+            if "Path does not exist" in str(e):
+                path_exist = False
+            else:
+                # Unexpected error.
+                raise e
+
+        if path_exist and not overwrite:
+            raise ValueError(f"The path {path} already exists.")
+
+        tmp_local_dir = tempfile.mkdtemp(prefix="pyspark_ml_model_")
+        try:
+            self.saveToLocal(tmp_local_dir, overwrite=True)
+            _copy_dir_from_local_to_fs(tmp_local_dir, path)
+        finally:
+            shutil.rmtree(tmp_local_dir, ignore_errors=True)
+
+    @classmethod
+    def load(cls, path: str) -> Any:
+        session = _get_active_session(is_remote())
+
+        tmp_local_dir = tempfile.mkdtemp(prefix="pyspark_ml_model_")
+        try:
+            file_data_df = session.read.format("binaryFile").load(path)
+
+            for row in file_data_df.toLocalIterator():
+                file_name = os.path.basename(urlparse(row.path).path)
+                file_content = bytes(row.content)
+                with open(os.path.join(tmp_local_dir, file_name), "wb") as f:
+                    f.write(file_content)
+
+            return cls.loadFromLocal(tmp_local_dir)
+        finally:
+            shutil.rmtree(tmp_local_dir, ignore_errors=True)
+
+
+class ModelReadWrite(ParamsReadWrite):
+    def _get_core_model_filename(self) -> str:
+        """
+        Returns the name of the file for saving the core model.
+        """
+        raise NotImplementedError()
+
+    def _save_core_model(self, path: str) -> None:
+        """
+        Save the core model to provided local path.
+        Different pyspark models contain different type of core model,
+        e.g. for LogisticRegressionModel, its core model is a pytorch model.
+        """
+        raise NotImplementedError()
+
+    def _load_core_model(self, path: str) -> None:
+        """
+        Load the core model from provided local path.
+        """
+        raise NotImplementedError()
+
+    def saveToLocal(self, path: str, *, overwrite: bool = False) -> None:
+        super(ModelReadWrite, self).saveToLocal(path, overwrite=overwrite)
+        self._save_core_model(os.path.join(path, self._get_core_model_filename()))
+
+    @classmethod
+    def loadFromLocal(cls, path: str) -> Any:
+        instance = super(ModelReadWrite, cls).loadFromLocal(path)
+        instance._load_core_model(os.path.join(path, instance._get_core_model_filename()))
+        return instance
diff --git a/python/pyspark/mlv2/summarizer.py b/python/pyspark/mlv2/summarizer.py
index 4d5e1d8988f..d80e770ebe4 100644
--- a/python/pyspark/mlv2/summarizer.py
+++ b/python/pyspark/mlv2/summarizer.py
@@ -70,6 +70,8 @@ class SummarizerAggState:
                     )
                     * (self.count / (self.count - 1))
                 )
+            if metric == "count":
+                result["count"] = self.count  # type: ignore[assignment]
 
         return result
 
@@ -90,7 +92,7 @@ def summarize_dataframe(
         and all values in the column must have the same length.
     metrics:
         The metrics to be summarized, available metrics are:
-        "min", "max",  "sum", "mean"
+        "min", "max", "sum", "mean", "count"
 
     Returns
     -------
diff --git a/python/pyspark/mlv2/tests/connect/test_parity_classification.py b/python/pyspark/mlv2/tests/connect/test_parity_classification.py
index 16bc1af1a64..8796556e0d8 100644
--- a/python/pyspark/mlv2/tests/connect/test_parity_classification.py
+++ b/python/pyspark/mlv2/tests/connect/test_parity_classification.py
@@ -23,7 +23,11 @@ from pyspark.mlv2.tests.test_classification import ClassificationTestsMixin
 
 class FeatureTestsOnConnect(ClassificationTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
-        self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
+        self.spark = (
+            SparkSession.builder.remote("local[2]")
+            .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
+            .getOrCreate()
+        )
 
     def tearDown(self) -> None:
         self.spark.stop()
diff --git a/python/pyspark/mlv2/tests/test_classification.py b/python/pyspark/mlv2/tests/test_classification.py
index b9c112ef094..159862ef5f6 100644
--- a/python/pyspark/mlv2/tests/test_classification.py
+++ b/python/pyspark/mlv2/tests/test_classification.py
@@ -15,10 +15,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
+import tempfile
 import unittest
 import numpy as np
-from pyspark.mlv2.classification import LogisticRegression as LORV2
+from pyspark.mlv2.classification import (
+    LogisticRegression as LORV2,
+    LogisticRegressionModel as LORV2Model,
+)
 from pyspark.sql import SparkSession
 
 
@@ -118,6 +122,82 @@ class ClassificationTestsMixin:
         local_transform_result = model.transform(eval_df1.toPandas())
         self._check_result(local_transform_result, expected_predictions, expected_probabilities)
 
+    def test_save_load(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            estimator = LORV2(maxIter=2, numTrainWorkers=2, learningRate=0.001)
+            local_path = os.path.join(tmp_dir, "estimator")
+            estimator.saveToLocal(local_path)
+            loaded_estimator = LORV2.loadFromLocal(local_path)
+            assert loaded_estimator.uid == estimator.uid
+            assert loaded_estimator.getOrDefault(loaded_estimator.maxIter) == 2
+            assert loaded_estimator.getOrDefault(loaded_estimator.numTrainWorkers) == 2
+            assert loaded_estimator.getOrDefault(loaded_estimator.learningRate) == 0.001
+
+            # Re-saving to an existing path with overwrite=True replaces the saved params.
+            estimator2 = estimator.copy()
+            estimator2.set(estimator2.maxIter, 10)
+            estimator2.saveToLocal(local_path, overwrite=True)
+            loaded_estimator2 = LORV2.loadFromLocal(local_path)
+            assert loaded_estimator2.getOrDefault(loaded_estimator2.maxIter) == 10
+
+            fs_path = os.path.join(tmp_dir, "fs", "estimator")
+            estimator.save(fs_path)  # same round trip, via save() / load()
+            loaded_estimator = LORV2.load(fs_path)
+            assert loaded_estimator.uid == estimator.uid
+            assert loaded_estimator.getOrDefault(loaded_estimator.maxIter) == 2
+            assert loaded_estimator.getOrDefault(loaded_estimator.numTrainWorkers) == 2
+            assert loaded_estimator.getOrDefault(loaded_estimator.learningRate) == 0.001
+
+            training_dataset = self.spark.createDataFrame(
+                [
+                    (1.0, [0.0, 5.0]),
+                    (0.0, [1.0, 2.0]),
+                    (1.0, [2.0, 1.0]),
+                    (0.0, [3.0, 3.0]),
+                ]
+                * 100,
+                ["label", "features"],
+            )
+            eval_df1 = self.spark.createDataFrame(
+                [
+                    ([0.0, 2.0],),
+                    ([3.5, 3.0],),
+                ],
+                ["features"],
+            )
+
+            model = estimator.fit(training_dataset)
+            assert model.uid == estimator.uid  # fitted model keeps the estimator's uid
+
+            local_model_path = os.path.join(tmp_dir, "model")
+            model.saveToLocal(local_model_path)
+            loaded_model = LORV2Model.loadFromLocal(local_model_path)
+            assert loaded_model.numFeatures == 2
+            assert loaded_model.numClasses == 2
+            assert loaded_model.getOrDefault(loaded_model.maxIter) == 2
+            assert loaded_model.torch_model is not None
+            np.testing.assert_allclose(
+                loaded_model.torch_model.weight.detach().numpy(),
+                model.torch_model.weight.detach().numpy(),
+            )
+            np.testing.assert_allclose(
+                loaded_model.torch_model.bias.detach().numpy(),
+                model.torch_model.bias.detach().numpy(),
+            )
+
+            # Smoke test only: loaded model can transform pandas input (output not checked).
+            loaded_model.transform(eval_df1.toPandas())
+
+            fs_model_path = os.path.join(tmp_dir, "fs", "model")
+            model.save(fs_model_path)
+            loaded_model = LORV2Model.load(fs_model_path)
+            assert loaded_model.numFeatures == 2
+            assert loaded_model.numClasses == 2
+            assert loaded_model.getOrDefault(loaded_model.maxIter) == 2
+            assert loaded_model.torch_model is not None
+            # Smoke test only: the model from load() can transform pandas input.
+            loaded_model.transform(eval_df1.toPandas())
+
 
 class ClassificationTests(ClassificationTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
diff --git a/python/pyspark/mlv2/tests/test_feature.py b/python/pyspark/mlv2/tests/test_feature.py
index bd58cccfbfa..65d35dc16f7 100644
--- a/python/pyspark/mlv2/tests/test_feature.py
+++ b/python/pyspark/mlv2/tests/test_feature.py
@@ -16,10 +16,18 @@
 # limitations under the License.
 #
 
-import unittest
+import os
+import pickle
 import numpy as np
+import tempfile
+import unittest
 
-from pyspark.mlv2.feature import MaxAbsScaler, StandardScaler
+from pyspark.mlv2.feature import (
+    MaxAbsScaler,
+    MaxAbsScalerModel,
+    StandardScaler,
+    StandardScalerModel,
+)
 from pyspark.sql import SparkSession
 
 
@@ -34,7 +42,9 @@ class FeatureTestsMixin:
         )
 
         scaler = MaxAbsScaler(inputCol="features", outputCol="scaled_features")
+
         model = scaler.fit(df1)
+        assert model.uid == scaler.uid
         result = model.transform(df1).toPandas()
         assert list(result.columns) == ["features", "scaled_features"]
 
@@ -50,6 +60,27 @@ class FeatureTestsMixin:
 
         np.testing.assert_allclose(list(local_transform_result.scaled_features), expected_result)
 
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            estimator_path = os.path.join(tmp_dir, "estimator")
+            scaler.saveToLocal(estimator_path)
+            loaded_scaler = MaxAbsScaler.loadFromLocal(estimator_path)
+            assert loaded_scaler.getInputCol() == "features"
+            assert loaded_scaler.getOutputCol() == "scaled_features"
+
+            model_path = os.path.join(tmp_dir, "model")
+            model.saveToLocal(model_path)
+            loaded_model = MaxAbsScalerModel.loadFromLocal(model_path)
+
+            np.testing.assert_allclose(model.scale_values, loaded_model.scale_values)
+            np.testing.assert_allclose(model.max_abs_values, loaded_model.max_abs_values)
+            assert model.n_samples_seen == loaded_model.n_samples_seen
+
+            # Test loading core model as scikit-learn model
+            with open(os.path.join(model_path, "MaxAbsScalerModel.sklearn.pkl"), "rb") as f:
+                sk_model = pickle.load(f)
+                sk_result = sk_model.transform(np.stack(list(local_df1.features)))
+                np.testing.assert_allclose(sk_result, expected_result)
+
     def test_standard_scaler(self):
         df1 = self.spark.createDataFrame(
             [
@@ -62,6 +93,7 @@ class FeatureTestsMixin:
 
         scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
         model = scaler.fit(df1)
+        assert model.uid == scaler.uid
         result = model.transform(df1).toPandas()
         assert list(result.columns) == ["features", "scaled_features"]
 
@@ -81,6 +113,28 @@ class FeatureTestsMixin:
 
         np.testing.assert_allclose(list(local_transform_result.scaled_features), expected_result)
 
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            estimator_path = os.path.join(tmp_dir, "estimator")
+            scaler.saveToLocal(estimator_path)
+            loaded_scaler = StandardScaler.loadFromLocal(estimator_path)
+            assert loaded_scaler.getInputCol() == "features"
+            assert loaded_scaler.getOutputCol() == "scaled_features"
+
+            model_path = os.path.join(tmp_dir, "model")
+            model.saveToLocal(model_path)
+            loaded_model = StandardScalerModel.loadFromLocal(model_path)
+
+            np.testing.assert_allclose(model.std_values, loaded_model.std_values)
+            np.testing.assert_allclose(model.mean_values, loaded_model.mean_values)
+            np.testing.assert_allclose(model.scale_values, loaded_model.scale_values)
+            assert model.n_samples_seen == loaded_model.n_samples_seen
+
+            # Test loading core model as scikit-learn model
+            with open(os.path.join(model_path, "StandardScalerModel.sklearn.pkl"), "rb") as f:
+                sk_model = pickle.load(f)
+                sk_result = sk_model.transform(np.stack(list(local_df1.features)))
+                np.testing.assert_allclose(sk_result, expected_result)
+
 
 class FeatureTests(FeatureTestsMixin, unittest.TestCase):
     def setUp(self) -> None:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org