Posted to commits@spark.apache.org by gu...@apache.org on 2021/11/10 09:40:30 UTC

[spark] branch master updated: [SPARK-37022][PYTHON] Use black as a formatter for PySpark

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2fe9af8  [SPARK-37022][PYTHON] Use black as a formatter for PySpark
2fe9af8 is described below

commit 2fe9af8b2b91d0a46782dd6fff57eca8609be105
Author: zero323 <ms...@gmail.com>
AuthorDate: Wed Nov 10 18:39:06 2021 +0900

    [SPARK-37022][PYTHON] Use black as a formatter for PySpark
    
    ### What changes were proposed in this pull request?
    
    This PR applies `black` (21.5b2) formatting to the whole `python/pyspark` source tree.
    
    Additionally, the following changes were made:
    
    - Disabled E501 (line too long) in pycodestyle config ‒ black allows lines to exceed `line-length` when the overflow comes from an inline comment. There are 15 such cases, all listed below (see the sketch after this list):
        ```
        pycodestyle checks failed:
        ./python/pyspark/sql/catalog.py:349:101: E501 line too long (103 > 100 characters)
        ./python/pyspark/sql/session.py:652:101: E501 line too long (106 > 100 characters)
        ./python/pyspark/sql/utils.py:50:101: E501 line too long (108 > 100 characters)
        ./python/pyspark/sql/streaming.py:1063:101: E501 line too long (128 > 100 characters)
        ./python/pyspark/sql/streaming.py:1071:101: E501 line too long (112 > 100 characters)
        ./python/pyspark/sql/streaming.py:1080:101: E501 line too long (124 > 100 characters)
        ./python/pyspark/sql/streaming.py:1259:101: E501 line too long (134 > 100 characters)
        ./python/pyspark/sql/pandas/conversion.py:136:101: E501 line too long (106 > 100 characters)
        ./python/pyspark/ml/param/_shared_params_code_gen.py:111:101: E501 line too long (103 > 100 characters)
        ./python/pyspark/ml/param/_shared_params_code_gen.py:136:101: E501 line too long (105 > 100 characters)
        ./python/pyspark/ml/param/_shared_params_code_gen.py:163:101: E501 line too long (101 > 100 characters)
        ./python/pyspark/ml/param/_shared_params_code_gen.py:233:101: E501 line too long (101 > 100 characters)
        ./python/pyspark/ml/param/_shared_params_code_gen.py:265:101: E501 line too long (101 > 100 characters)
        ./python/pyspark/tests/test_readwrite.py:235:101: E501 line too long (114 > 100 characters)
        ./python/pyspark/tests/test_readwrite.py:336:101: E501 line too long (114 > 100 characters)
        ```
    - After reformatting, minor typing changes were made:
       - Realign certain `type: ignore` comments with ignored code.
       - Apply explicit `cast` calls to `unittest.skipIf` messages. The following
          ```python
          unittest.skipIf(
              not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message
          )  # type: ignore[arg-type]
          ```
          was replaced with
    
          ```python
          unittest.skipIf(
              not have_pandas or not have_pyarrow,
              cast(str, pandas_requirement_message or pyarrow_requirement_message),
          )
          ```
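
    As a hypothetical illustration (not part of this commit), the sketch below shows why the
    E501 exemption is needed: black reflows code to the 100-character `line-length` configured
    in `dev/pyproject.toml`, but it never rewraps comments, so a long inline comment keeps the
    formatted line over the limit. The snippet is an assumption made for illustration only and
    presumes a black 21.x installation.

    ```python
    import black

    # Minimal sketch, not code from this PR: black reflows code but never rewraps comments,
    # so a comment that is itself longer than the limit leaves the formatted output over
    # 100 characters. This is why E501 is now ignored via dev/tox.ini.
    src = "x = 1  # " + "a" * 110 + "\n"

    # Roughly the [tool.black] settings added to dev/pyproject.toml.
    mode = black.Mode(line_length=100, target_versions={black.TargetVersion.PY37})

    formatted = black.format_str(src, mode=mode)
    print(max(len(line) for line in formatted.splitlines()))  # prints a number above 100
    ```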
    
    ### Why are the changes needed?
    
    Consistency and reduced maintenance overhead.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    
    ### How was this patch tested?
    
    Existing linters and tests.
    
    Closes #34297 from zero323/SPARK-37022.
    
    Authored-by: zero323 <ms...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 dev/lint-python                                    |    2 +-
 dev/pyproject.toml                                 |    6 +
 dev/reformat-python                                |    3 +-
 dev/tox.ini                                        |    2 +-
 python/pyspark/__init__.py                         |   41 +-
 python/pyspark/_globals.py                         |    7 +-
 python/pyspark/accumulators.py                     |   20 +-
 python/pyspark/accumulators.pyi                    |    4 +-
 python/pyspark/broadcast.py                        |   20 +-
 python/pyspark/conf.py                             |   22 +-
 python/pyspark/context.py                          |  309 ++--
 python/pyspark/daemon.py                           |   13 +-
 python/pyspark/files.py                            |    5 +-
 python/pyspark/find_spark_home.py                  |   15 +-
 python/pyspark/install.py                          |   38 +-
 python/pyspark/java_gateway.py                     |   31 +-
 python/pyspark/join.py                             |    5 +
 python/pyspark/ml/__init__.py                      |   52 +-
 python/pyspark/ml/_typing.pyi                      |    4 +-
 python/pyspark/ml/base.py                          |   34 +-
 python/pyspark/ml/base.pyi                         |    8 +-
 python/pyspark/ml/classification.py                | 1259 ++++++++++-----
 python/pyspark/ml/classification.pyi               |   67 +-
 python/pyspark/ml/clustering.py                    |  535 +++++--
 python/pyspark/ml/clustering.pyi                   |   28 +-
 python/pyspark/ml/common.py                        |   29 +-
 python/pyspark/ml/evaluation.py                    |  394 +++--
 python/pyspark/ml/evaluation.pyi                   |   36 +-
 python/pyspark/ml/feature.py                       | 1633 ++++++++++++++------
 python/pyspark/ml/feature.pyi                      |  228 +--
 python/pyspark/ml/fpm.py                           |  149 +-
 python/pyspark/ml/fpm.pyi                          |   12 +-
 python/pyspark/ml/functions.py                     |   22 +-
 python/pyspark/ml/functions.pyi                    |    1 -
 python/pyspark/ml/image.py                         |   34 +-
 python/pyspark/ml/linalg/__init__.py               |  240 +--
 python/pyspark/ml/linalg/__init__.pyi              |   16 +-
 python/pyspark/ml/param/__init__.py                |   32 +-
 python/pyspark/ml/param/_shared_params_code_gen.py |  242 ++-
 python/pyspark/ml/param/shared.py                  |  238 ++-
 python/pyspark/ml/pipeline.py                      |   52 +-
 python/pyspark/ml/pipeline.pyi                     |    4 +-
 python/pyspark/ml/recommendation.py                |  206 ++-
 python/pyspark/ml/recommendation.pyi               |   16 +-
 python/pyspark/ml/regression.py                    |  956 +++++++++---
 python/pyspark/ml/regression.pyi                   |   44 +-
 python/pyspark/ml/stat.py                          |   35 +-
 python/pyspark/ml/stat.pyi                         |    8 +-
 python/pyspark/ml/tests/test_algorithms.py         |  245 +--
 python/pyspark/ml/tests/test_base.py               |   26 +-
 python/pyspark/ml/tests/test_evaluation.py         |   20 +-
 python/pyspark/ml/tests/test_feature.py            |  254 ++-
 python/pyspark/ml/tests/test_image.py              |   32 +-
 python/pyspark/ml/tests/test_linalg.py             |  145 +-
 python/pyspark/ml/tests/test_param.py              |  129 +-
 python/pyspark/ml/tests/test_persistence.py        |  214 ++-
 python/pyspark/ml/tests/test_pipeline.py           |   10 +-
 python/pyspark/ml/tests/test_stat.py               |   16 +-
 python/pyspark/ml/tests/test_training_summary.py   |  152 +-
 python/pyspark/ml/tests/test_tuning.py             |  650 ++++----
 python/pyspark/ml/tests/test_util.py               |   43 +-
 python/pyspark/ml/tests/test_wrapper.py            |   33 +-
 python/pyspark/ml/tree.py                          |  229 ++-
 python/pyspark/ml/tree.pyi                         |    4 +-
 python/pyspark/ml/tuning.py                        |  475 +++---
 python/pyspark/ml/tuning.pyi                       |   12 +-
 python/pyspark/ml/util.py                          |   95 +-
 python/pyspark/ml/wrapper.py                       |   13 +-
 python/pyspark/ml/wrapper.pyi                      |    4 +-
 python/pyspark/mllib/__init__.py                   |   17 +-
 python/pyspark/mllib/classification.py             |  226 ++-
 python/pyspark/mllib/clustering.py                 |  233 ++-
 python/pyspark/mllib/clustering.pyi                |   20 +-
 python/pyspark/mllib/common.py                     |   36 +-
 python/pyspark/mllib/evaluation.py                 |  125 +-
 python/pyspark/mllib/evaluation.pyi                |    4 +-
 python/pyspark/mllib/feature.py                    |  137 +-
 python/pyspark/mllib/feature.pyi                   |    4 +-
 python/pyspark/mllib/fpm.py                        |   18 +-
 python/pyspark/mllib/fpm.pyi                       |    4 +-
 python/pyspark/mllib/linalg/__init__.py            |  289 ++--
 python/pyspark/mllib/linalg/__init__.pyi           |   16 +-
 python/pyspark/mllib/linalg/distributed.py         |  130 +-
 python/pyspark/mllib/linalg/distributed.pyi        |   12 +-
 python/pyspark/mllib/random.py                     |   57 +-
 python/pyspark/mllib/recommendation.py             |   65 +-
 python/pyspark/mllib/recommendation.pyi            |    8 +-
 python/pyspark/mllib/regression.py                 |  194 ++-
 python/pyspark/mllib/regression.pyi                |   12 +-
 python/pyspark/mllib/stat/KernelDensity.py         |    4 +-
 python/pyspark/mllib/stat/__init__.py              |   11 +-
 python/pyspark/mllib/stat/_statistics.py           |   18 +-
 python/pyspark/mllib/stat/_statistics.pyi          |   12 +-
 python/pyspark/mllib/stat/distribution.py          |    3 +-
 python/pyspark/mllib/tests/test_algorithms.py      |  144 +-
 python/pyspark/mllib/tests/test_feature.py         |   70 +-
 python/pyspark/mllib/tests/test_linalg.py          |  185 ++-
 python/pyspark/mllib/tests/test_stat.py            |   45 +-
 .../mllib/tests/test_streaming_algorithms.py       |  117 +-
 python/pyspark/mllib/tests/test_util.py            |   24 +-
 python/pyspark/mllib/tree.py                       |  275 +++-
 python/pyspark/mllib/tree.pyi                      |    4 +-
 python/pyspark/mllib/util.py                       |   42 +-
 python/pyspark/pandas/base.py                      |   14 +-
 python/pyspark/pandas/config.py                    |    2 +-
 python/pyspark/pandas/data_type_ops/base.py        |    4 +-
 .../pandas/data_type_ops/categorical_ops.py        |    2 +-
 python/pyspark/pandas/frame.py                     |   12 +-
 python/pyspark/pandas/generic.py                   |    6 +-
 python/pyspark/pandas/groupby.py                   |    2 +-
 python/pyspark/pandas/indexes/base.py              |    4 +-
 python/pyspark/pandas/namespace.py                 |   24 +-
 python/pyspark/pandas/plot/matplotlib.py           |    6 +-
 python/pyspark/pandas/series.py                    |    4 +-
 python/pyspark/pandas/sql_processor.py             |    2 +-
 python/pyspark/pandas/utils.py                     |    6 +-
 python/pyspark/profiler.py                         |   20 +-
 python/pyspark/profiler.pyi                        |    4 +-
 python/pyspark/rdd.py                              |  342 ++--
 python/pyspark/rdd.pyi                             |   40 +-
 python/pyspark/rddsampler.py                       |    4 -
 python/pyspark/resource/__init__.py                |   16 +-
 python/pyspark/resource/profile.py                 |   70 +-
 python/pyspark/resource/requests.py                |   65 +-
 python/pyspark/resource/tests/test_resources.py    |    7 +-
 python/pyspark/serializers.py                      |   65 +-
 python/pyspark/shell.py                            |   20 +-
 python/pyspark/shuffle.py                          |  118 +-
 python/pyspark/sql/__init__.py                     |   21 +-
 python/pyspark/sql/avro/__init__.py                |    2 +-
 python/pyspark/sql/avro/functions.py               |   26 +-
 python/pyspark/sql/catalog.py                      |  106 +-
 python/pyspark/sql/column.py                       |  122 +-
 python/pyspark/sql/conf.py                         |   15 +-
 python/pyspark/sql/context.py                      |  120 +-
 python/pyspark/sql/dataframe.py                    |  338 ++--
 python/pyspark/sql/functions.py                    |  235 ++-
 python/pyspark/sql/group.py                        |   58 +-
 python/pyspark/sql/observation.py                  |   14 +-
 python/pyspark/sql/pandas/_typing/__init__.pyi     |    8 +-
 .../pyspark/sql/pandas/_typing/protocols/frame.pyi |   38 +-
 .../sql/pandas/_typing/protocols/series.pyi        |   38 +-
 python/pyspark/sql/pandas/conversion.py            |  192 ++-
 python/pyspark/sql/pandas/functions.py             |   69 +-
 python/pyspark/sql/pandas/functions.pyi            |   16 +-
 python/pyspark/sql/pandas/group_ops.py             |   39 +-
 python/pyspark/sql/pandas/map_ops.py               |   19 +-
 python/pyspark/sql/pandas/serializers.py           |   81 +-
 python/pyspark/sql/pandas/typehints.py             |  123 +-
 python/pyspark/sql/pandas/types.py                 |   94 +-
 python/pyspark/sql/pandas/utils.py                 |   39 +-
 python/pyspark/sql/readwriter.py                   |  260 ++--
 python/pyspark/sql/session.py                      |  192 ++-
 python/pyspark/sql/streaming.py                    |  191 ++-
 python/pyspark/sql/tests/test_arrow.py             |  258 +++-
 python/pyspark/sql/tests/test_catalog.py           |  227 +--
 python/pyspark/sql/tests/test_column.py            |  111 +-
 python/pyspark/sql/tests/test_conf.py              |    4 +-
 python/pyspark/sql/tests/test_context.py           |   75 +-
 python/pyspark/sql/tests/test_dataframe.py         |  659 +++++---
 python/pyspark/sql/tests/test_datasources.py       |  147 +-
 python/pyspark/sql/tests/test_functions.py         |  482 +++---
 python/pyspark/sql/tests/test_group.py             |   12 +-
 .../pyspark/sql/tests/test_pandas_cogrouped_map.py |  215 +--
 .../pyspark/sql/tests/test_pandas_grouped_map.py   |  567 ++++---
 python/pyspark/sql/tests/test_pandas_map.py        |   47 +-
 python/pyspark/sql/tests/test_pandas_udf.py        |  136 +-
 .../sql/tests/test_pandas_udf_grouped_agg.py       |  462 +++---
 python/pyspark/sql/tests/test_pandas_udf_scalar.py |  860 ++++++-----
 .../pyspark/sql/tests/test_pandas_udf_typehints.py |  141 +-
 python/pyspark/sql/tests/test_pandas_udf_window.py |  244 +--
 python/pyspark/sql/tests/test_readwriter.py        |   66 +-
 python/pyspark/sql/tests/test_serde.py             |   28 +-
 python/pyspark/sql/tests/test_session.py           |  111 +-
 python/pyspark/sql/tests/test_streaming.py         |  200 +--
 python/pyspark/sql/tests/test_types.py             |  690 +++++----
 python/pyspark/sql/tests/test_udf.py               |  281 ++--
 python/pyspark/sql/tests/test_utils.py             |   33 +-
 python/pyspark/sql/types.py                        |  474 +++---
 python/pyspark/sql/udf.py                          |  125 +-
 python/pyspark/sql/utils.py                        |   84 +-
 python/pyspark/sql/window.py                       |   11 +-
 python/pyspark/statcounter.py                      |   30 +-
 python/pyspark/status.py                           |    3 +
 python/pyspark/storagelevel.py                     |    8 +-
 python/pyspark/streaming/__init__.py               |    2 +-
 python/pyspark/streaming/context.py                |   33 +-
 python/pyspark/streaming/context.pyi               |    8 +-
 python/pyspark/streaming/dstream.py                |   99 +-
 python/pyspark/streaming/dstream.pyi               |   16 +-
 python/pyspark/streaming/kinesis.py                |   18 +-
 python/pyspark/streaming/listener.py               |    1 -
 python/pyspark/streaming/tests/test_context.py     |   13 +-
 python/pyspark/streaming/tests/test_dstream.py     |  241 +--
 python/pyspark/streaming/tests/test_kinesis.py     |   57 +-
 python/pyspark/streaming/tests/test_listener.py    |    7 +-
 python/pyspark/streaming/util.py                   |   18 +-
 python/pyspark/taskcontext.py                      |   25 +-
 python/pyspark/testing/mllibutils.py               |    2 +-
 python/pyspark/testing/mlutils.py                  |   85 +-
 python/pyspark/testing/pandasutils.py              |   22 +-
 python/pyspark/testing/sqlutils.py                 |   16 +-
 python/pyspark/testing/streamingutils.py           |   18 +-
 python/pyspark/testing/utils.py                    |   22 +-
 python/pyspark/tests/test_appsubmit.py             |  168 +-
 python/pyspark/tests/test_broadcast.py             |    8 +-
 python/pyspark/tests/test_conf.py                  |    3 +-
 python/pyspark/tests/test_context.py               |   66 +-
 python/pyspark/tests/test_daemon.py                |    7 +-
 python/pyspark/tests/test_install_spark.py         |   59 +-
 python/pyspark/tests/test_join.py                  |    4 +-
 python/pyspark/tests/test_pin_thread.py            |   15 +-
 python/pyspark/tests/test_profiler.py              |   12 +-
 python/pyspark/tests/test_rdd.py                   |  172 ++-
 python/pyspark/tests/test_rddbarrier.py            |    4 +-
 python/pyspark/tests/test_readwrite.py             |  286 ++--
 python/pyspark/tests/test_serializers.py           |   86 +-
 python/pyspark/tests/test_shuffle.py               |   40 +-
 python/pyspark/tests/test_taskcontext.py           |   37 +-
 python/pyspark/tests/test_util.py                  |    6 +-
 python/pyspark/tests/test_worker.py                |   31 +-
 python/pyspark/traceback_utils.py                  |    5 +-
 python/pyspark/util.py                             |   56 +-
 python/pyspark/worker.py                           |  193 ++-
 224 files changed, 15237 insertions(+), 9349 deletions(-)

diff --git a/dev/lint-python b/dev/lint-python
index d77d645..30b1816 100755
--- a/dev/lint-python
+++ b/dev/lint-python
@@ -256,7 +256,7 @@ function black_test {
 
     echo "starting black test..."
     # Black is only applied for pandas API on Spark for now.
-    BLACK_REPORT=$( ($BLACK_BUILD python/pyspark/pandas --line-length 100 --check ) 2>&1)
+    BLACK_REPORT=$( ($BLACK_BUILD  --config dev/pyproject.toml --check python/pyspark) 2>&1)
     BLACK_STATUS=$?
 
     if [ "$BLACK_STATUS" -ne 0 ]; then
diff --git a/dev/pyproject.toml b/dev/pyproject.toml
index 286b728..b90257f 100644
--- a/dev/pyproject.toml
+++ b/dev/pyproject.toml
@@ -23,3 +23,9 @@ testpaths = [
   "pyspark/sql/tests/typing",
   "pyspark/ml/typing",
 ]
+
+[tool.black]
+line-length = 100
+target-version = ['py37']
+include = '\.pyi?$'
+extend-exclude = 'cloudpickle'
diff --git a/dev/reformat-python b/dev/reformat-python
index 1b5ee65..0b712d5 100755
--- a/dev/reformat-python
+++ b/dev/reformat-python
@@ -28,5 +28,4 @@ if [ $? -ne 0 ]; then
     exit 1
 fi
 
-# This script is only applied for pandas API on Spark for now.
-$BLACK_BUILD python/pyspark/pandas --line-length 100
+$BLACK_BUILD --config python/pyproject.toml python/pyspark
diff --git a/dev/tox.ini b/dev/tox.ini
index e1a4cf5..bd69a3f 100644
--- a/dev/tox.ini
+++ b/dev/tox.ini
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 [pycodestyle]
-ignore=E203,E226,E241,E305,E402,E722,E731,E741,W503,W504
+ignore=E203,E226,E241,E305,E402,E722,E731,E741,W503,W504,E501
 max-line-length=100
 exclude=*/target/*,python/pyspark/cloudpickle/*.py,shared.py,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*,dev/ansible-for-test-node/*
 
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 9651b53..70392fb 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -69,13 +69,15 @@ def since(version):
     A decorator that annotates a function to append the version of Spark the function was added.
     """
     import re
-    indent_p = re.compile(r'\n( +)')
+
+    indent_p = re.compile(r"\n( +)")
 
     def deco(f):
         indents = indent_p.findall(f.__doc__)
-        indent = ' ' * (min(len(m) for m in indents) if indents else 0)
+        indent = " " * (min(len(m) for m in indents) if indents else 0)
         f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
         return f
+
     return deco
 
 
@@ -86,8 +88,9 @@ def copy_func(f, name=None, sinceversion=None, doc=None):
     """
     # See
     # http://stackoverflow.com/questions/6527633/how-can-i-make-a-deepcopy-of-a-function-in-python
-    fn = types.FunctionType(f.__code__, f.__globals__, name or f.__name__, f.__defaults__,
-                            f.__closure__)
+    fn = types.FunctionType(
+        f.__code__, f.__globals__, name or f.__name__, f.__defaults__, f.__closure__
+    )
     # in case f was given attrs (note this dict is a shallow copy):
     fn.__dict__.update(f.__dict__)
     if doc is not None:
@@ -106,14 +109,17 @@ def keyword_only(func):
     -----
     Should only be used to wrap a method where first arg is `self`
     """
+
     @wraps(func)
     def wrapper(self, *args, **kwargs):
         if len(args) > 0:
             raise TypeError("Method %s forces keyword arguments." % func.__name__)
         self._input_kwargs = kwargs
         return func(self, **kwargs)
+
     return wrapper
 
+
 # To avoid circular dependencies
 from pyspark.context import SparkContext
 
@@ -121,9 +127,26 @@ from pyspark.context import SparkContext
 from pyspark.sql import SQLContext, HiveContext, Row  # noqa: F401
 
 __all__ = [
-    "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
-    "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
-    "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext",
-    "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo", "InheritableThread",
-    "inheritable_thread_target", "__version__",
+    "SparkConf",
+    "SparkContext",
+    "SparkFiles",
+    "RDD",
+    "StorageLevel",
+    "Broadcast",
+    "Accumulator",
+    "AccumulatorParam",
+    "MarshalSerializer",
+    "PickleSerializer",
+    "StatusTracker",
+    "SparkJobInfo",
+    "SparkStageInfo",
+    "Profiler",
+    "BasicProfiler",
+    "TaskContext",
+    "RDDBarrier",
+    "BarrierTaskContext",
+    "BarrierTaskInfo",
+    "InheritableThread",
+    "inheritable_thread_target",
+    "__version__",
 ]
diff --git a/python/pyspark/_globals.py b/python/pyspark/_globals.py
index 8e6099d..a635972 100644
--- a/python/pyspark/_globals.py
+++ b/python/pyspark/_globals.py
@@ -32,13 +32,13 @@ See gh-7844 for a discussion of the reload problem that motivated this module.
 Note that this approach is taken after from NumPy.
 """
 
-__ALL__ = ['_NoValue']
+__ALL__ = ["_NoValue"]
 
 
 # Disallow reloading this module so as to preserve the identities of the
 # classes defined here.
-if '_is_loaded' in globals():
-    raise RuntimeError('Reloading pyspark._globals is not allowed')
+if "_is_loaded" in globals():
+    raise RuntimeError("Reloading pyspark._globals is not allowed")
 _is_loaded = True
 
 
@@ -51,6 +51,7 @@ class _NoValueType(object):
 
     This class was copied from NumPy.
     """
+
     __instance = None
 
     def __new__(cls):
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index e6106b3..c43ebe4 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -23,7 +23,7 @@ import threading
 from pyspark.serializers import read_int, PickleSerializer
 
 
-__all__ = ['Accumulator', 'AccumulatorParam']
+__all__ = ["Accumulator", "AccumulatorParam"]
 
 
 pickleSer = PickleSerializer()
@@ -35,6 +35,7 @@ _accumulatorRegistry = {}
 
 def _deserialize_accumulator(aid, zero_value, accum_param):
     from pyspark.accumulators import _accumulatorRegistry
+
     # If this certain accumulator was deserialized, don't overwrite it.
     if aid in _accumulatorRegistry:
         return _accumulatorRegistry[aid]
@@ -108,6 +109,7 @@ class Accumulator(object):
     def __init__(self, aid, value, accum_param):
         """Create a new Accumulator with a given initial value and AccumulatorParam object"""
         from pyspark.accumulators import _accumulatorRegistry
+
         self.aid = aid
         self.accum_param = accum_param
         self._value = value
@@ -225,6 +227,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
 
     def handle(self):
         from pyspark.accumulators import _accumulatorRegistry
+
         auth_token = self.server.auth_token
 
         def poll(func):
@@ -248,13 +251,14 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
             received_token = self.rfile.read(len(auth_token))
             if isinstance(received_token, bytes):
                 received_token = received_token.decode("utf-8")
-            if (received_token == auth_token):
+            if received_token == auth_token:
                 accum_updates()
                 # we've authenticated, we can break out of the first loop now
                 return True
             else:
                 raise ValueError(
-                    "The value of the provided token to the AccumulatorServer is not correct.")
+                    "The value of the provided token to the AccumulatorServer is not correct."
+                )
 
         # first we keep polling till we've received the authentication token
         poll(authenticate_and_accum_updates)
@@ -263,7 +267,6 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
 
 
 class AccumulatorServer(SocketServer.TCPServer):
-
     def __init__(self, server_address, RequestHandlerClass, auth_token):
         SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
         self.auth_token = auth_token
@@ -288,16 +291,17 @@ def _start_update_server(auth_token):
     thread.start()
     return server
 
+
 if __name__ == "__main__":
     import doctest
 
     from pyspark.context import SparkContext
+
     globs = globals().copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    globs['sc'] = SparkContext('local', 'test')
-    (failure_count, test_count) = doctest.testmod(
-        globs=globs, optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
+    globs["sc"] = SparkContext("local", "test")
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+    globs["sc"].stop()
     if failure_count:
         sys.exit(-1)
diff --git a/python/pyspark/accumulators.pyi b/python/pyspark/accumulators.pyi
index 13a1792cd..3159792 100644
--- a/python/pyspark/accumulators.pyi
+++ b/python/pyspark/accumulators.pyi
@@ -32,9 +32,7 @@ _accumulatorRegistry: Dict[int, Accumulator]
 class Accumulator(Generic[T]):
     aid: int
     accum_param: AccumulatorParam[T]
-    def __init__(
-        self, aid: int, value: T, accum_param: AccumulatorParam[T]
-    ) -> None: ...
+    def __init__(self, aid: int, value: T, accum_param: AccumulatorParam[T]) -> None: ...
     def __reduce__(
         self,
     ) -> Tuple[
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index aecd71f..995da33 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -27,7 +27,7 @@ from pyspark.serializers import ChunkedStream, pickle_protocol
 from pyspark.util import print_exec
 
 
-__all__ = ['Broadcast']
+__all__ = ["Broadcast"]
 
 
 # Holds broadcasted data received from Java, keyed by its id.
@@ -36,6 +36,7 @@ _broadcastRegistry = {}
 
 def _from_id(bid):
     from pyspark.broadcast import _broadcastRegistry
+
     if bid not in _broadcastRegistry:
         raise RuntimeError("Broadcast variable '%s' not loaded!" % bid)
     return _broadcastRegistry[bid]
@@ -61,8 +62,7 @@ class Broadcast(object):
     >>> large_broadcast = sc.broadcast(range(10000))
     """
 
-    def __init__(self, sc=None, value=None, pickle_registry=None, path=None,
-                 sock_file=None):
+    def __init__(self, sc=None, value=None, pickle_registry=None, path=None, sock_file=None):
         """
         Should not be called directly by users -- use :meth:`SparkContext.broadcast`
         instead.
@@ -99,7 +99,7 @@ class Broadcast(object):
             else:
                 # the jvm just dumps the pickled data in path -- we'll unpickle lazily when
                 # the value is requested
-                assert(path is not None)
+                assert path is not None
                 self._path = path
 
     def dump(self, value, f):
@@ -108,14 +108,13 @@ class Broadcast(object):
         except pickle.PickleError:
             raise
         except Exception as e:
-            msg = "Could not serialize broadcast: %s: %s" \
-                  % (e.__class__.__name__, str(e))
+            msg = "Could not serialize broadcast: %s: %s" % (e.__class__.__name__, str(e))
             print_exec(sys.stderr)
             raise pickle.PicklingError(msg)
         f.close()
 
     def load_from_path(self, path):
-        with open(path, 'rb', 1 << 20) as f:
+        with open(path, "rb", 1 << 20) as f:
             return self.load(f)
 
     def load(self, file):
@@ -128,8 +127,7 @@ class Broadcast(object):
 
     @property
     def value(self):
-        """ Return the broadcasted value
-        """
+        """Return the broadcasted value"""
         if not hasattr(self, "_value") and self._path is not None:
             # we only need to decrypt it here when encryption is enabled and
             # if its on the driver, since executor decryption is handled already
@@ -185,8 +183,7 @@ class Broadcast(object):
 
 
 class BroadcastPickleRegistry(threading.local):
-    """ Thread-local registry for broadcast variables that have been pickled
-    """
+    """Thread-local registry for broadcast variables that have been pickled"""
 
     def __init__(self):
         self.__dict__.setdefault("_registry", set())
@@ -204,6 +201,7 @@ class BroadcastPickleRegistry(threading.local):
 
 if __name__ == "__main__":
     import doctest
+
     (failure_count, test_count) = doctest.testmod()
     if failure_count:
         sys.exit(-1)
diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py
index a6b2784..47ea8b6 100644
--- a/python/pyspark/conf.py
+++ b/python/pyspark/conf.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-__all__ = ['SparkConf']
+__all__ = ["SparkConf"]
 
 import sys
 from typing import Dict, List, Optional, Tuple, cast, overload
@@ -110,8 +110,12 @@ class SparkConf(object):
     _jconf: Optional[JavaObject]
     _conf: Optional[Dict[str, str]]
 
-    def __init__(self, loadDefaults: bool = True, _jvm: Optional[JVMView] = None,
-                 _jconf: Optional[JavaObject] = None):
+    def __init__(
+        self,
+        loadDefaults: bool = True,
+        _jvm: Optional[JVMView] = None,
+        _jconf: Optional[JavaObject] = None,
+    ):
         """
         Create a new Spark configuration.
         """
@@ -119,6 +123,7 @@ class SparkConf(object):
             self._jconf = _jconf
         else:
             from pyspark.context import SparkContext
+
             _jvm = _jvm or SparkContext._jvm  # type: ignore[attr-defined]
 
             if _jvm is not None:
@@ -169,8 +174,12 @@ class SparkConf(object):
     def setExecutorEnv(self, *, pairs: List[Tuple[str, str]]) -> "SparkConf":
         ...
 
-    def setExecutorEnv(self, key: Optional[str] = None, value: Optional[str] = None,
-                       pairs: Optional[List[Tuple[str, str]]] = None) -> "SparkConf":
+    def setExecutorEnv(
+        self,
+        key: Optional[str] = None,
+        value: Optional[str] = None,
+        pairs: Optional[List[Tuple[str, str]]] = None,
+    ) -> "SparkConf":
         """Set an environment variable to be passed to executors."""
         if (key is not None and pairs is not None) or (key is None and pairs is None):
             raise RuntimeError("Either pass one key-value pair or a list of pairs")
@@ -236,11 +245,12 @@ class SparkConf(object):
             return self._jconf.toDebugString()
         else:
             assert self._conf is not None
-            return '\n'.join('%s=%s' % (k, v) for k, v in self._conf.items())
+            return "\n".join("%s=%s" % (k, v) for k, v in self._conf.items())
 
 
 def _test() -> None:
     import doctest
+
     (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
     if failure_count:
         sys.exit(-1)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 9ed1a82..2c78994 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -34,8 +34,15 @@ from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway, local_connect_and_auth
-from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
-    PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream
+from pyspark.serializers import (
+    PickleSerializer,
+    BatchedSerializer,
+    UTF8Deserializer,
+    PairDeserializer,
+    AutoBatchedSerializer,
+    NoOpSerializer,
+    ChunkedStream,
+)
 from pyspark.storagelevel import StorageLevel
 from pyspark.resource.information import ResourceInformation
 from pyspark.rdd import RDD, _load_from_socket
@@ -45,7 +52,7 @@ from pyspark.status import StatusTracker
 from pyspark.profiler import ProfilerCollector, BasicProfiler
 
 
-__all__ = ['SparkContext']
+__all__ = ["SparkContext"]
 
 
 # These are special default configs for PySpark, they will overwrite
@@ -125,13 +132,23 @@ class SparkContext(object):
     _lock = RLock()
     _python_includes = None  # zip and egg files that need to be added to PYTHONPATH
 
-    PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')
-
-    def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
-                 environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
-                 gateway=None, jsc=None, profiler_cls=BasicProfiler):
-        if (conf is None or
-                conf.get("spark.executor.allowSparkContext", "false").lower() != "true"):
+    PACKAGE_EXTENSIONS = (".zip", ".egg", ".jar")
+
+    def __init__(
+        self,
+        master=None,
+        appName=None,
+        sparkHome=None,
+        pyFiles=None,
+        environment=None,
+        batchSize=0,
+        serializer=PickleSerializer(),
+        conf=None,
+        gateway=None,
+        jsc=None,
+        profiler_cls=BasicProfiler,
+    ):
+        if conf is None or conf.get("spark.executor.allowSparkContext", "false").lower() != "true":
             # In order to prevent SparkContext from being created in executors.
             SparkContext._assert_on_driver()
 
@@ -139,19 +156,41 @@ class SparkContext(object):
         if gateway is not None and gateway.gateway_parameters.auth_token is None:
             raise ValueError(
                 "You are trying to pass an insecure Py4j gateway to Spark. This"
-                " is not allowed as it is a security risk.")
+                " is not allowed as it is a security risk."
+            )
 
         SparkContext._ensure_initialized(self, gateway=gateway, conf=conf)
         try:
-            self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                          conf, jsc, profiler_cls)
+            self._do_init(
+                master,
+                appName,
+                sparkHome,
+                pyFiles,
+                environment,
+                batchSize,
+                serializer,
+                conf,
+                jsc,
+                profiler_cls,
+            )
         except:
             # If an error occurs, clean up in order to allow future SparkContext creation:
             self.stop()
             raise
 
-    def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                 conf, jsc, profiler_cls):
+    def _do_init(
+        self,
+        master,
+        appName,
+        sparkHome,
+        pyFiles,
+        environment,
+        batchSize,
+        serializer,
+        conf,
+        jsc,
+        profiler_cls,
+    ):
         self.environment = environment or {}
         # java gateway must have been launched at this point.
         if conf is not None and conf._jconf is not None:
@@ -170,8 +209,7 @@ class SparkContext(object):
         if batchSize == 0:
             self.serializer = AutoBatchedSerializer(self._unbatched_serializer)
         else:
-            self.serializer = BatchedSerializer(self._unbatched_serializer,
-                                                batchSize)
+            self.serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
 
         # Set any parameters passed directly to us on the conf
         if master:
@@ -200,7 +238,7 @@ class SparkContext(object):
 
         for (k, v) in self._conf.getAll():
             if k.startswith("spark.executorEnv."):
-                varName = k[len("spark.executorEnv."):]
+                varName = k[len("spark.executorEnv.") :]
                 self.environment[varName] = v
 
         self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
@@ -222,21 +260,18 @@ class SparkContext(object):
         # data via a socket.
         # scala's mangled names w/ $ in them require special treatment.
         self._encryption_enabled = self._jvm.PythonUtils.isEncryptionEnabled(self._jsc)
-        os.environ["SPARK_AUTH_SOCKET_TIMEOUT"] = \
-            str(self._jvm.PythonUtils.getPythonAuthSocketTimeout(self._jsc))
-        os.environ["SPARK_BUFFER_SIZE"] = \
-            str(self._jvm.PythonUtils.getSparkBufferSize(self._jsc))
+        os.environ["SPARK_AUTH_SOCKET_TIMEOUT"] = str(
+            self._jvm.PythonUtils.getPythonAuthSocketTimeout(self._jsc)
+        )
+        os.environ["SPARK_BUFFER_SIZE"] = str(self._jvm.PythonUtils.getSparkBufferSize(self._jsc))
 
-        self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python3')
+        self.pythonExec = os.environ.get("PYSPARK_PYTHON", "python3")
         self.pythonVer = "%d.%d" % sys.version_info[:2]
 
         if sys.version_info[:2] < (3, 7):
             with warnings.catch_warnings():
                 warnings.simplefilter("once")
-                warnings.warn(
-                    "Python 3.6 support is deprecated in Spark 3.2.",
-                    FutureWarning
-                )
+                warnings.warn("Python 3.6 support is deprecated in Spark 3.2.", FutureWarning)
 
         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have
@@ -250,7 +285,7 @@ class SparkContext(object):
 
         # Deploy any code dependencies specified in the constructor
         self._python_includes = list()
-        for path in (pyFiles or []):
+        for path in pyFiles or []:
             self.addPyFile(path)
 
         # Deploy code dependencies set by spark-submit; these will already have been added
@@ -272,13 +307,14 @@ class SparkContext(object):
                     warnings.warn(
                         "Failed to add file [%s] specified in 'spark.submit.pyFiles' to "
                         "Python path:\n  %s" % (path, "\n  ".join(sys.path)),
-                        RuntimeWarning)
+                        RuntimeWarning,
+                    )
 
         # Create a temporary directory inside spark.local.dir:
         local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
-        self._temp_dir = \
-            self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir, "pyspark") \
-                .getAbsolutePath()
+        self._temp_dir = self._jvm.org.apache.spark.util.Utils.createTempDir(
+            local_dir, "pyspark"
+        ).getAbsolutePath()
 
         # profiling stats collected for each PythonRDD
         if self._conf.get("spark.python.profile", "false") == "true":
@@ -340,8 +376,10 @@ class SparkContext(object):
                 SparkContext._jvm = SparkContext._gateway.jvm
 
             if instance:
-                if (SparkContext._active_spark_context and
-                        SparkContext._active_spark_context != instance):
+                if (
+                    SparkContext._active_spark_context
+                    and SparkContext._active_spark_context != instance
+                ):
                     currentMaster = SparkContext._active_spark_context.master
                     currentAppName = SparkContext._active_spark_context.appName
                     callsite = SparkContext._active_spark_context._callsite
@@ -351,8 +389,14 @@ class SparkContext(object):
                         "Cannot run multiple SparkContexts at once; "
                         "existing SparkContext(app=%s, master=%s)"
                         " created by %s at %s:%s "
-                        % (currentAppName, currentMaster,
-                            callsite.function, callsite.file, callsite.linenum))
+                        % (
+                            currentAppName,
+                            currentMaster,
+                            callsite.function,
+                            callsite.file,
+                            callsite.linenum,
+                        )
+                    )
                 else:
                     SparkContext._active_spark_context = instance
 
@@ -466,10 +510,10 @@ class SparkContext(object):
             except Py4JError:
                 # Case: SPARK-18523
                 warnings.warn(
-                    'Unable to cleanly shutdown Spark JVM process.'
-                    ' It is possible that the process has crashed,'
-                    ' been killed or may also be in a zombie state.',
-                    RuntimeWarning
+                    "Unable to cleanly shutdown Spark JVM process."
+                    " It is possible that the process has crashed,"
+                    " been killed or may also be in a zombie state.",
+                    RuntimeWarning,
                 )
             finally:
                 self._jsc = None
@@ -561,7 +605,7 @@ class SparkContext(object):
 
         # Make sure we distribute data evenly if it's smaller than self.batchSize
         if "__len__" not in dir(c):
-            c = list(c)    # Make it a list so we can compute its length
+            c = list(c)  # Make it a list so we can compute its length
         batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
         serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
 
@@ -652,8 +696,7 @@ class SparkContext(object):
         ['Hello world!']
         """
         minPartitions = minPartitions or min(self.defaultParallelism, 2)
-        return RDD(self._jsc.textFile(name, minPartitions), self,
-                   UTF8Deserializer(use_unicode))
+        return RDD(self._jsc.textFile(name, minPartitions), self, UTF8Deserializer(use_unicode))
 
     def wholeTextFiles(self, path, minPartitions=None, use_unicode=True):
         """
@@ -704,8 +747,11 @@ class SparkContext(object):
         [('.../1.txt', '1'), ('.../2.txt', '2')]
         """
         minPartitions = minPartitions or self.defaultMinPartitions
-        return RDD(self._jsc.wholeTextFiles(path, minPartitions), self,
-                   PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode)))
+        return RDD(
+            self._jsc.wholeTextFiles(path, minPartitions),
+            self,
+            PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode)),
+        )
 
     def binaryFiles(self, path, minPartitions=None):
         """
@@ -720,8 +766,11 @@ class SparkContext(object):
         Small files are preferred, large file is also allowable, but may cause bad performance.
         """
         minPartitions = minPartitions or self.defaultMinPartitions
-        return RDD(self._jsc.binaryFiles(path, minPartitions), self,
-                   PairDeserializer(UTF8Deserializer(), NoOpSerializer()))
+        return RDD(
+            self._jsc.binaryFiles(path, minPartitions),
+            self,
+            PairDeserializer(UTF8Deserializer(), NoOpSerializer()),
+        )
 
     def binaryRecords(self, path, recordLength):
         """
@@ -746,8 +795,16 @@ class SparkContext(object):
             jm[k] = v
         return jm
 
-    def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None,
-                     valueConverter=None, minSplits=None, batchSize=0):
+    def sequenceFile(
+        self,
+        path,
+        keyClass=None,
+        valueClass=None,
+        keyConverter=None,
+        valueConverter=None,
+        minSplits=None,
+        batchSize=0,
+    ):
         """
         Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS,
         a local file system (available on all nodes), or any Hadoop-supported file system URI.
@@ -779,12 +836,29 @@ class SparkContext(object):
             Java object. (default 0, choose batchSize automatically)
         """
         minSplits = minSplits or min(self.defaultParallelism, 2)
-        jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass,
-                                                keyConverter, valueConverter, minSplits, batchSize)
+        jrdd = self._jvm.PythonRDD.sequenceFile(
+            self._jsc,
+            path,
+            keyClass,
+            valueClass,
+            keyConverter,
+            valueConverter,
+            minSplits,
+            batchSize,
+        )
         return RDD(jrdd, self)
 
-    def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
-                         valueConverter=None, conf=None, batchSize=0):
+    def newAPIHadoopFile(
+        self,
+        path,
+        inputFormatClass,
+        keyClass,
+        valueClass,
+        keyConverter=None,
+        valueConverter=None,
+        conf=None,
+        batchSize=0,
+    ):
         """
         Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS,
         a local file system (available on all nodes), or any Hadoop-supported file system URI.
@@ -820,13 +894,29 @@ class SparkContext(object):
             Java object. (default 0, choose batchSize automatically)
         """
         jconf = self._dictToJavaMap(conf)
-        jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass,
-                                                    valueClass, keyConverter, valueConverter,
-                                                    jconf, batchSize)
+        jrdd = self._jvm.PythonRDD.newAPIHadoopFile(
+            self._jsc,
+            path,
+            inputFormatClass,
+            keyClass,
+            valueClass,
+            keyConverter,
+            valueConverter,
+            jconf,
+            batchSize,
+        )
         return RDD(jrdd, self)
 
-    def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
-                        valueConverter=None, conf=None, batchSize=0):
+    def newAPIHadoopRDD(
+        self,
+        inputFormatClass,
+        keyClass,
+        valueClass,
+        keyConverter=None,
+        valueConverter=None,
+        conf=None,
+        batchSize=0,
+    ):
         """
         Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary
         Hadoop configuration, which is passed in as a Python dict.
@@ -856,13 +946,29 @@ class SparkContext(object):
             Java object. (default 0, choose batchSize automatically)
         """
         jconf = self._dictToJavaMap(conf)
-        jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass,
-                                                   valueClass, keyConverter, valueConverter,
-                                                   jconf, batchSize)
+        jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(
+            self._jsc,
+            inputFormatClass,
+            keyClass,
+            valueClass,
+            keyConverter,
+            valueConverter,
+            jconf,
+            batchSize,
+        )
         return RDD(jrdd, self)
 
-    def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
-                   valueConverter=None, conf=None, batchSize=0):
+    def hadoopFile(
+        self,
+        path,
+        inputFormatClass,
+        keyClass,
+        valueClass,
+        keyConverter=None,
+        valueConverter=None,
+        conf=None,
+        batchSize=0,
+    ):
         """
         Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS,
         a local file system (available on all nodes), or any Hadoop-supported file system URI.
@@ -894,13 +1000,29 @@ class SparkContext(object):
             Java object. (default 0, choose batchSize automatically)
         """
         jconf = self._dictToJavaMap(conf)
-        jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass,
-                                              valueClass, keyConverter, valueConverter,
-                                              jconf, batchSize)
+        jrdd = self._jvm.PythonRDD.hadoopFile(
+            self._jsc,
+            path,
+            inputFormatClass,
+            keyClass,
+            valueClass,
+            keyConverter,
+            valueConverter,
+            jconf,
+            batchSize,
+        )
         return RDD(jrdd, self)
 
-    def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
-                  valueConverter=None, conf=None, batchSize=0):
+    def hadoopRDD(
+        self,
+        inputFormatClass,
+        keyClass,
+        valueClass,
+        keyConverter=None,
+        valueConverter=None,
+        conf=None,
+        batchSize=0,
+    ):
         """
         Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary
         Hadoop configuration, which is passed in as a Python dict.
@@ -930,9 +1052,16 @@ class SparkContext(object):
             Java object. (default 0, choose batchSize automatically)
         """
         jconf = self._dictToJavaMap(conf)
-        jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass,
-                                             valueClass, keyConverter, valueConverter,
-                                             jconf, batchSize)
+        jrdd = self._jvm.PythonRDD.hadoopRDD(
+            self._jsc,
+            inputFormatClass,
+            keyClass,
+            valueClass,
+            keyConverter,
+            valueConverter,
+            jconf,
+            batchSize,
+        )
         return RDD(jrdd, self)
 
     def _checkpointFile(self, name, input_deserializer):
@@ -1087,11 +1216,13 @@ class SparkContext(object):
             raise TypeError("storageLevel must be of type pyspark.StorageLevel")
 
         newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
-        return newStorageLevel(storageLevel.useDisk,
-                               storageLevel.useMemory,
-                               storageLevel.useOffHeap,
-                               storageLevel.deserialized,
-                               storageLevel.replication)
+        return newStorageLevel(
+            storageLevel.useDisk,
+            storageLevel.useMemory,
+            storageLevel.useOffHeap,
+            storageLevel.deserialized,
+            storageLevel.replication,
+        )
 
     def setJobGroup(self, groupId, description, interruptOnCancel=False):
         """
@@ -1228,21 +1359,24 @@ class SparkContext(object):
         return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer))
 
     def show_profiles(self):
-        """ Print the profile stats to stdout """
+        """Print the profile stats to stdout"""
         if self.profiler_collector is not None:
             self.profiler_collector.show_profiles()
         else:
-            raise RuntimeError("'spark.python.profile' configuration must be set "
-                               "to 'true' to enable Python profile.")
+            raise RuntimeError(
+                "'spark.python.profile' configuration must be set "
+                "to 'true' to enable Python profile."
+            )
 
     def dump_profiles(self, path):
-        """ Dump the profile stats into directory `path`
-        """
+        """Dump the profile stats into directory `path`"""
         if self.profiler_collector is not None:
             self.profiler_collector.dump_profiles(path)
         else:
-            raise RuntimeError("'spark.python.profile' configuration must be set "
-                               "to 'true' to enable Python profile.")
+            raise RuntimeError(
+                "'spark.python.profile' configuration must be set "
+                "to 'true' to enable Python profile."
+            )
 
     def getConf(self):
         conf = SparkConf()
@@ -1275,12 +1409,13 @@ def _test():
     import atexit
     import doctest
     import tempfile
+
     globs = globals().copy()
-    globs['sc'] = SparkContext('local[4]', 'PythonTest')
-    globs['tempdir'] = tempfile.mkdtemp()
-    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
+    globs["sc"] = SparkContext("local[4]", "PythonTest")
+    globs["tempdir"] = tempfile.mkdtemp()
+    atexit.register(lambda: shutil.rmtree(globs["tempdir"]))
     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
-    globs['sc'].stop()
+    globs["sc"].stop()
     if failure_count:
         sys.exit(-1)
 
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 97b6b25..b8fd03d 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -88,13 +88,13 @@ def manager():
 
     # Create a listening socket on the AF_INET loopback interface
     listen_sock = socket.socket(AF_INET, SOCK_STREAM)
-    listen_sock.bind(('127.0.0.1', 0))
+    listen_sock.bind(("127.0.0.1", 0))
     listen_sock.listen(max(1024, SOMAXCONN))
     listen_host, listen_port = listen_sock.getsockname()
 
     # re-open stdin/stdout in 'wb' mode
-    stdin_bin = os.fdopen(sys.stdin.fileno(), 'rb', 4)
-    stdout_bin = os.fdopen(sys.stdout.fileno(), 'wb', 4)
+    stdin_bin = os.fdopen(sys.stdin.fileno(), "rb", 4)
+    stdout_bin = os.fdopen(sys.stdout.fileno(), "wb", 4)
     write_int(listen_port, stdout_bin)
     stdout_bin.flush()
 
@@ -106,6 +106,7 @@ def manager():
 
     def handle_sigterm(*args):
         shutdown(1)
+
     signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
     signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP
     signal.signal(SIGCHLD, SIG_IGN)
@@ -150,7 +151,7 @@ def manager():
                         time.sleep(1)
                         pid = os.fork()  # error here will shutdown daemon
                     else:
-                        outfile = sock.makefile(mode='wb')
+                        outfile = sock.makefile(mode="wb")
                         write_int(e.errno, outfile)  # Signal that the fork failed
                         outfile.flush()
                         outfile.close()
@@ -171,7 +172,7 @@ def manager():
                     # Therefore, here we redirects it to '/dev/null' by duplicating
                     # another file descriptor for '/dev/null' to the standard input (0).
                     # See SPARK-26175.
-                    devnull = open(os.devnull, 'r')
+                    devnull = open(os.devnull, "r")
                     os.dup2(devnull.fileno(), 0)
                     devnull.close()
 
@@ -207,5 +208,5 @@ def manager():
         shutdown(1)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     manager()
diff --git a/python/pyspark/files.py b/python/pyspark/files.py
index a0b74ca..59ab0b6 100644
--- a/python/pyspark/files.py
+++ b/python/pyspark/files.py
@@ -18,7 +18,7 @@
 import os
 
 
-__all__ = ['SparkFiles']
+__all__ = ["SparkFiles"]
 
 from typing import cast, ClassVar, Optional, TYPE_CHECKING
 
@@ -61,6 +61,5 @@ class SparkFiles(object):
         else:
             # This will have to change if we support multiple SparkContexts:
             return cast(
-                "SparkContext",
-                cls._sc
+                "SparkContext", cls._sc
             )._jvm.org.apache.spark.SparkFiles.getRootDirectory()  # type: ignore[attr-defined]
diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py
index 62a36d4..09f4551 100755
--- a/python/pyspark/find_spark_home.py
+++ b/python/pyspark/find_spark_home.py
@@ -32,9 +32,10 @@ def _find_spark_home():
 
     def is_spark_home(path):
         """Takes a path and returns true if the provided path could be a reasonable SPARK_HOME"""
-        return (os.path.isfile(os.path.join(path, "bin/spark-submit")) and
-                (os.path.isdir(os.path.join(path, "jars")) or
-                 os.path.isdir(os.path.join(path, "assembly"))))
+        return os.path.isfile(os.path.join(path, "bin/spark-submit")) and (
+            os.path.isdir(os.path.join(path, "jars"))
+            or os.path.isdir(os.path.join(path, "assembly"))
+        )
 
     # Spark distribution can be downloaded when PYSPARK_HADOOP_VERSION environment variable is set.
     # We should look up this directory first, see also SPARK-32017.
@@ -43,11 +44,13 @@ def _find_spark_home():
         "../",  # When we're in spark/python.
         # The two cases below are valid when the current script is called as a library.
         os.path.join(os.path.dirname(os.path.realpath(__file__)), spark_dist_dir),
-        os.path.dirname(os.path.realpath(__file__))]
+        os.path.dirname(os.path.realpath(__file__)),
+    ]
 
     # Add the path of the PySpark module if it exists
     import_error_raised = False
     from importlib.util import find_spec
+
     try:
         module_home = os.path.dirname(find_spec("pyspark").origin)
         paths.append(os.path.join(module_home, spark_dist_dir))
@@ -78,7 +81,9 @@ def _find_spark_home():
                 "for example, 'python -m pip install pyspark [--user]'. Otherwise, you can also\n"
                 "explicitly set the Python executable, that has PySpark installed, to\n"
                 "PYSPARK_PYTHON or PYSPARK_DRIVER_PYTHON environment variables, for example,\n"
-                "'PYSPARK_PYTHON=python3 pyspark'.\n", file=sys.stderr)
+                "'PYSPARK_PYTHON=python3 pyspark'.\n",
+                file=sys.stderr,
+            )
         sys.exit(-1)
 
 
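In is_spark_home() above, black drops the extra wrapping parentheses and moves the boolean operators to the start of each continuation line; the expression itself is unchanged. A self-contained check (not part of this commit) against a throwaway directory layout invented purely for the assertion:

```python
# Illustrative only, not part of the commit; the directory layout below is made up.
import os
import tempfile

with tempfile.TemporaryDirectory() as path:
    os.makedirs(os.path.join(path, "bin"))
    open(os.path.join(path, "bin", "spark-submit"), "w").close()
    os.makedirs(os.path.join(path, "jars"))

    # Pre-black wrapping: operators at the end of each line, outer parentheses.
    old_style = (os.path.isfile(os.path.join(path, "bin/spark-submit")) and
                 (os.path.isdir(os.path.join(path, "jars")) or
                  os.path.isdir(os.path.join(path, "assembly"))))
    # Black's wrapping: operators lead the continuation lines.
    new_style = os.path.isfile(os.path.join(path, "bin/spark-submit")) and (
        os.path.isdir(os.path.join(path, "jars"))
        or os.path.isdir(os.path.join(path, "assembly"))
    )
    assert old_style is new_style is True
```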
diff --git a/python/pyspark/install.py b/python/pyspark/install.py
index 7efee42..bfd7001 100644
--- a/python/pyspark/install.py
+++ b/python/pyspark/install.py
@@ -20,6 +20,7 @@ import tarfile
 import traceback
 import urllib.request
 from shutil import rmtree
+
 # NOTE that we shouldn't import pyspark here because this is used in
 # setup.py, and assume there's no PySpark imported.
 
@@ -27,8 +28,7 @@ DEFAULT_HADOOP = "hadoop3.2"
 DEFAULT_HIVE = "hive2.3"
 SUPPORTED_HADOOP_VERSIONS = ["hadoop2.7", "hadoop3.2", "without-hadoop"]
 SUPPORTED_HIVE_VERSIONS = ["hive2.3"]
-UNSUPPORTED_COMBINATIONS = [  # type: ignore
-]
+UNSUPPORTED_COMBINATIONS = []  # type: ignore
 
 
 def checked_package_name(spark_version, hadoop_version, hive_version):
@@ -60,8 +60,8 @@ def checked_versions(spark_version, hadoop_version, hive_version):
         spark_version = "spark-%s" % spark_version
     if not spark_version.startswith("spark-"):
         raise RuntimeError(
-            "Spark version should start with 'spark-' prefix; however, "
-            "got %s" % spark_version)
+            "Spark version should start with 'spark-' prefix; however, " "got %s" % spark_version
+        )
 
     if hadoop_version == "without":
         hadoop_version = "without-hadoop"
@@ -71,8 +71,8 @@ def checked_versions(spark_version, hadoop_version, hive_version):
     if hadoop_version not in SUPPORTED_HADOOP_VERSIONS:
         raise RuntimeError(
             "Spark distribution of %s is not supported. Hadoop version should be "
-            "one of [%s]" % (hadoop_version, ", ".join(
-                SUPPORTED_HADOOP_VERSIONS)))
+            "one of [%s]" % (hadoop_version, ", ".join(SUPPORTED_HADOOP_VERSIONS))
+        )
 
     if re.match("^[0-9]+\\.[0-9]+$", hive_version):
         hive_version = "hive%s" % hive_version
@@ -80,8 +80,8 @@ def checked_versions(spark_version, hadoop_version, hive_version):
     if hive_version not in SUPPORTED_HIVE_VERSIONS:
         raise RuntimeError(
             "Spark distribution of %s is not supported. Hive version should be "
-            "one of [%s]" % (hive_version, ", ".join(
-                SUPPORTED_HADOOP_VERSIONS)))
+            "one of [%s]" % (hive_version, ", ".join(SUPPORTED_HIVE_VERSIONS))
+        )
 
     return spark_version, hadoop_version, hive_version
 
@@ -114,7 +114,8 @@ def install_spark(dest, spark_version, hadoop_version, hive_version):
 
     pretty_pkg_name = "%s for Hadoop %s" % (
         spark_version,
-        "Free build" if hadoop_version == "without" else hadoop_version)
+        "Free build" if hadoop_version == "without" else hadoop_version,
+    )
 
     for site in sites:
         os.makedirs(dest, exist_ok=True)
@@ -151,19 +152,22 @@ def get_preferred_mirrors():
     for _ in range(3):
         try:
             response = urllib.request.urlopen(
-                "https://www.apache.org/dyn/closer.lua?preferred=true")
-            mirror_urls.append(response.read().decode('utf-8'))
+                "https://www.apache.org/dyn/closer.lua?preferred=true"
+            )
+            mirror_urls.append(response.read().decode("utf-8"))
         except Exception:
             # If we can't get a mirror URL, skip it. No retry.
             pass
 
     default_sites = [
-        "https://archive.apache.org/dist", "https://dist.apache.org/repos/dist/release"]
+        "https://archive.apache.org/dist",
+        "https://dist.apache.org/repos/dist/release",
+    ]
     return list(set(mirror_urls)) + default_sites
 
 
 def download_to_file(response, path, chunk_size=1024 * 1024):
-    total_size = int(response.info().get('Content-Length').strip())
+    total_size = int(response.info().get("Content-Length").strip())
     bytes_so_far = 0
 
     with open(path, mode="wb") as dest:
@@ -173,7 +177,7 @@ def download_to_file(response, path, chunk_size=1024 * 1024):
             if not chunk:
                 break
             dest.write(chunk)
-            print("Downloaded %d of %d bytes (%0.2f%%)" % (
-                bytes_so_far,
-                total_size,
-                round(float(bytes_so_far) / total_size * 100, 2)))
+            print(
+                "Downloaded %d of %d bytes (%0.2f%%)"
+                % (bytes_so_far, total_size, round(float(bytes_so_far) / total_size * 100, 2))
+            )
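The reformatted RuntimeError calls in install.py rely on implicit concatenation of adjacent string literals; the pieces are joined at compile time, so the % operator still formats the complete message. A quick check with a made-up version string (not part of this commit):

```python
# Illustrative only, not part of the commit; "3.2.0" is just an example value.
spark_version = "3.2.0"
message = "Spark version should start with 'spark-' prefix; however, " "got %s" % spark_version
assert message == "Spark version should start with 'spark-' prefix; however, got 3.2.0"
```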
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index bffdc0b..a41ccfa 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -64,13 +64,10 @@ def launch_gateway(conf=None, popen_kwargs=None):
         command = [os.path.join(SPARK_HOME, script)]
         if conf:
             for k, v in conf.getAll():
-                command += ['--conf', '%s=%s' % (k, v)]
+                command += ["--conf", "%s=%s" % (k, v)]
         submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
         if os.environ.get("SPARK_TESTING"):
-            submit_args = ' '.join([
-                "--conf spark.ui.enabled=false",
-                submit_args
-            ])
+            submit_args = " ".join(["--conf spark.ui.enabled=false", submit_args])
         command = command + shlex.split(submit_args)
 
         # Create a temporary directory where the gateway server should write the connection
@@ -87,14 +84,15 @@ def launch_gateway(conf=None, popen_kwargs=None):
             # Launch the Java gateway.
             popen_kwargs = {} if popen_kwargs is None else popen_kwargs
             # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
-            popen_kwargs['stdin'] = PIPE
+            popen_kwargs["stdin"] = PIPE
             # We always set the necessary environment variables.
-            popen_kwargs['env'] = env
+            popen_kwargs["env"] = env
             if not on_windows:
                 # Don't send ctrl-c / SIGINT to the Java gateway:
                 def preexec_func():
                     signal.signal(signal.SIGINT, signal.SIG_IGN)
-                popen_kwargs['preexec_fn'] = preexec_func
+
+                popen_kwargs["preexec_fn"] = preexec_func
                 proc = Popen(command, **popen_kwargs)
             else:
                 # preexec_fn not supported on Windows
@@ -127,24 +125,23 @@ def launch_gateway(conf=None, popen_kwargs=None):
             # child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
             def killChild():
                 Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
+
             atexit.register(killChild)
 
     # Connect to the gateway (or client server to pin the thread between JVM and Python)
     if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
         gateway = ClientServer(
             java_parameters=JavaParameters(
-                port=gateway_port,
-                auth_token=gateway_secret,
-                auto_convert=True),
-            python_parameters=PythonParameters(
-                port=0,
-                eager_load=False))
+                port=gateway_port, auth_token=gateway_secret, auto_convert=True
+            ),
+            python_parameters=PythonParameters(port=0, eager_load=False),
+        )
     else:
         gateway = JavaGateway(
             gateway_parameters=GatewayParameters(
-                port=gateway_port,
-                auth_token=gateway_secret,
-                auto_convert=True))
+                port=gateway_port, auth_token=gateway_secret, auto_convert=True
+            )
+        )
 
     # Store a reference to the Popen object for use by the caller (e.g., in reading stdout/stderr)
     gateway.proc = proc
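Two black habits show up in launch_gateway() above: a blank line is forced after the nested preexec_func definition, and keyword arguments that now fit within the 100-character limit are joined onto a single line. The popen_kwargs pattern itself is untouched; a minimal, self-contained sketch of it (not part of this commit), with a trivial Python child standing in for the Spark launcher script:

```python
# Illustrative only, not part of the commit; the child command is a stand-in.
import sys
from subprocess import PIPE, Popen

popen_kwargs = {}
popen_kwargs["stdin"] = PIPE   # mirrors launch_gateway(); the toy child ignores stdin
popen_kwargs["stdout"] = PIPE  # captured so the result can be checked below

proc = Popen([sys.executable, "-c", "print('ok')"], **popen_kwargs)
out, _ = proc.communicate()
assert out.strip() == b"ok"
```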
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index c1f5362..040c946 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -50,6 +50,7 @@ def python_join(rdd, other, numPartitions):
             elif n == 2:
                 wbuf.append(v)
         return ((v, w) for v in vbuf for w in wbuf)
+
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
@@ -64,6 +65,7 @@ def python_right_outer_join(rdd, other, numPartitions):
         if not vbuf:
             vbuf.append(None)
         return ((v, w) for v in vbuf for w in wbuf)
+
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
@@ -78,6 +80,7 @@ def python_left_outer_join(rdd, other, numPartitions):
         if not wbuf:
             wbuf.append(None)
         return ((v, w) for v in vbuf for w in wbuf)
+
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
@@ -94,12 +97,14 @@ def python_full_outer_join(rdd, other, numPartitions):
         if not wbuf:
             wbuf.append(None)
         return ((v, w) for v in vbuf for w in wbuf)
+
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
 def python_cogroup(rdds, numPartitions):
     def make_mapper(i):
         return lambda v: (i, v)
+
     vrdds = [rdd.mapValues(make_mapper(i)) for i, rdd in enumerate(rdds)]
     union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
     rdd_len = len(vrdds)
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
index 7d0e55a..8a235f7 100644
--- a/python/pyspark/ml/__init__.py
+++ b/python/pyspark/ml/__init__.py
@@ -19,15 +19,51 @@
 DataFrame-based machine learning APIs to let users quickly assemble and configure practical
 machine learning pipelines.
 """
-from pyspark.ml.base import Estimator, Model, Predictor, PredictionModel, \
-    Transformer, UnaryTransformer
+from pyspark.ml.base import (
+    Estimator,
+    Model,
+    Predictor,
+    PredictionModel,
+    Transformer,
+    UnaryTransformer,
+)
 from pyspark.ml.pipeline import Pipeline, PipelineModel
-from pyspark.ml import classification, clustering, evaluation, feature, fpm, \
-    image, recommendation, regression, stat, tuning, util, linalg, param
+from pyspark.ml import (
+    classification,
+    clustering,
+    evaluation,
+    feature,
+    fpm,
+    image,
+    recommendation,
+    regression,
+    stat,
+    tuning,
+    util,
+    linalg,
+    param,
+)
 
 __all__ = [
-    "Transformer", "UnaryTransformer", "Estimator", "Model",
-    "Predictor", "PredictionModel", "Pipeline", "PipelineModel",
-    "classification", "clustering", "evaluation", "feature", "fpm", "image",
-    "recommendation", "regression", "stat", "tuning", "util", "linalg", "param",
+    "Transformer",
+    "UnaryTransformer",
+    "Estimator",
+    "Model",
+    "Predictor",
+    "PredictionModel",
+    "Pipeline",
+    "PipelineModel",
+    "classification",
+    "clustering",
+    "evaluation",
+    "feature",
+    "fpm",
+    "image",
+    "recommendation",
+    "regression",
+    "stat",
+    "tuning",
+    "util",
+    "linalg",
+    "param",
 ]
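The changes to pyspark/ml/__init__.py are layout only: backslash-continued imports and the packed __all__ list become parenthesized, one-name-per-line blocks, and the trailing comma tells black to keep them exploded. A standalone check (not part of this commit) that uses os.path in place of pyspark.ml so it runs without Spark installed:

```python
# Illustrative only, not part of the commit; os.path stands in for pyspark.ml.
from os.path import basename, dirname, join, split  # packed, pre-black style
from os.path import (  # exploded style with a trailing comma, as black writes it
    basename,
    dirname,
    join,
    split,
)

# The exploded, trailing-comma layout evaluates to the same list as the packed one.
assert ["Transformer", "UnaryTransformer"] == [
    "Transformer",
    "UnaryTransformer",
]
```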
diff --git a/python/pyspark/ml/_typing.pyi b/python/pyspark/ml/_typing.pyi
index d966a78..40531d1 100644
--- a/python/pyspark/ml/_typing.pyi
+++ b/python/pyspark/ml/_typing.pyi
@@ -32,9 +32,7 @@ P = TypeVar("P", bound=pyspark.ml.param.Params)
 M = TypeVar("M", bound=pyspark.ml.base.Transformer)
 JM = TypeVar("JM", bound=pyspark.ml.wrapper.JavaTransformer)
 
-BinaryClassificationEvaluatorMetricType = Union[
-    Literal["areaUnderROC"], Literal["areaUnderPR"]
-]
+BinaryClassificationEvaluatorMetricType = Union[Literal["areaUnderROC"], Literal["areaUnderPR"]]
 RegressionEvaluatorMetricType = Union[
     Literal["rmse"], Literal["mse"], Literal["r2"], Literal["mae"], Literal["var"]
 ]
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index 31ce93d..970444d 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -22,8 +22,14 @@ import threading
 
 from pyspark import since
 from pyspark.ml.common import inherit_doc
-from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasLabelCol, HasFeaturesCol, \
-    HasPredictionCol, Params
+from pyspark.ml.param.shared import (
+    HasInputCol,
+    HasOutputCol,
+    HasLabelCol,
+    HasFeaturesCol,
+    HasPredictionCol,
+    Params,
+)
 from pyspark.sql.functions import udf
 from pyspark.sql.types import StructField, StructType
 
@@ -48,10 +54,9 @@ class _FitMultipleIterator(object):
     -----
     See :py:meth:`Estimator.fitMultiple` for more info.
     """
-    def __init__(self, fitSingleModel, numModels):
-        """
 
-        """
+    def __init__(self, fitSingleModel, numModels):
+        """ """
         self.fitSingleModel = fitSingleModel
         self.numModel = numModels
         self.counter = 0
@@ -80,6 +85,7 @@ class Estimator(Params, metaclass=ABCMeta):
 
     .. versionadded:: 1.3.0
     """
+
     pass
 
     @abstractmethod
@@ -160,8 +166,10 @@ class Estimator(Params, metaclass=ABCMeta):
             else:
                 return self._fit(dataset)
         else:
-            raise TypeError("Params must be either a param map or a list/tuple of param maps, "
-                            "but got %s." % type(params))
+            raise TypeError(
+                "Params must be either a param map or a list/tuple of param maps, "
+                "but got %s." % type(params)
+            )
 
 
 @inherit_doc
@@ -171,6 +179,7 @@ class Transformer(Params, metaclass=ABCMeta):
 
     .. versionadded:: 1.3.0
     """
+
     pass
 
     @abstractmethod
@@ -226,6 +235,7 @@ class Model(Transformer, metaclass=ABCMeta):
 
     .. versionadded:: 1.4.0
     """
+
     pass
 
 
@@ -279,16 +289,15 @@ class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
         if self.getOutputCol() in schema.names:
             raise ValueError("Output column %s already exists." % self.getOutputCol())
         outputFields = copy.copy(schema.fields)
-        outputFields.append(StructField(self.getOutputCol(),
-                                        self.outputDataType(),
-                                        nullable=False))
+        outputFields.append(StructField(self.getOutputCol(), self.outputDataType(), nullable=False))
         return StructType(outputFields)
 
     def _transform(self, dataset):
         self.transformSchema(dataset.schema)
         transformUDF = udf(self.createTransformFunc(), self.outputDataType())
-        transformedDataset = dataset.withColumn(self.getOutputCol(),
-                                                transformUDF(dataset[self.getInputCol()]))
+        transformedDataset = dataset.withColumn(
+            self.getOutputCol(), transformUDF(dataset[self.getInputCol()])
+        )
         return transformedDataset
 
 
@@ -299,6 +308,7 @@ class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
 
     .. versionadded:: 3.0.0
     """
+
     pass
 
 
diff --git a/python/pyspark/ml/base.pyi b/python/pyspark/ml/base.pyi
index 4f1c6f9..37ae6de 100644
--- a/python/pyspark/ml/base.pyi
+++ b/python/pyspark/ml/base.pyi
@@ -57,9 +57,7 @@ class _FitMultipleIterator:
     numModel: int
     counter: int = ...
     lock: _thread.LockType
-    def __init__(
-        self, fitSingleModel: Callable[[int], Transformer], numModels: int
-    ) -> None: ...
+    def __init__(self, fitSingleModel: Callable[[int], Transformer], numModels: int) -> None: ...
     def __iter__(self) -> _FitMultipleIterator: ...
     def __next__(self) -> Tuple[int, Transformer]: ...
     def next(self) -> Tuple[int, Transformer]: ...
@@ -76,9 +74,7 @@ class Estimator(Generic[M], Params, metaclass=abc.ABCMeta):
     ) -> Iterable[Tuple[int, M]]: ...
 
 class Transformer(Params, metaclass=abc.ABCMeta):
-    def transform(
-        self, dataset: DataFrame, params: Optional[ParamMap] = ...
-    ) -> DataFrame: ...
+    def transform(self, dataset: DataFrame, params: Optional[ParamMap] = ...) -> DataFrame: ...
 
 class Model(Transformer, metaclass=abc.ABCMeta): ...
 
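Stub files get the same treatment: signatures that were wrapped to respect the old limit now fit within 100 columns, so black joins them back onto one line. A sketch using black's programmatic API (not part of this commit); line_length and is_pyi are Mode fields in 21.5b2, but verify them against whatever version you have installed:

```python
# Illustrative only, not part of the commit. Assumes black 21.5b2 is installed.
import black

stub_src = (
    "class Transformer:\n"
    "    def transform(\n"
    "        self, dataset: DataFrame, params: Optional[ParamMap] = ...\n"
    "    ) -> DataFrame: ...\n"
)
out = black.format_str(stub_src, mode=black.Mode(line_length=100, is_pyi=True))
print(out)
# The wrapped signature fits in 100 columns and has no trailing comma, so black
# collapses it onto one line, matching the Transformer.transform change above.
```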
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 79b57d7..e6ce3e0 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -25,20 +25,54 @@ from multiprocessing.pool import ThreadPool
 
 from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
 from pyspark.ml import Estimator, Predictor, PredictionModel, Model
-from pyspark.ml.param.shared import HasRawPredictionCol, HasProbabilityCol, HasThresholds, \
-    HasRegParam, HasMaxIter, HasFitIntercept, HasTol, HasStandardization, HasWeightCol, \
-    HasAggregationDepth, HasThreshold, HasBlockSize, HasMaxBlockSizeInMB, Param, Params, \
-    TypeConverters, HasElasticNetParam, HasSeed, HasStepSize, HasSolver, HasParallelism
-from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
-    _TreeEnsembleModel, _RandomForestParams, _GBTParams, \
-    _HasVarianceImpurity, _TreeClassifierParams
+from pyspark.ml.param.shared import (
+    HasRawPredictionCol,
+    HasProbabilityCol,
+    HasThresholds,
+    HasRegParam,
+    HasMaxIter,
+    HasFitIntercept,
+    HasTol,
+    HasStandardization,
+    HasWeightCol,
+    HasAggregationDepth,
+    HasThreshold,
+    HasBlockSize,
+    HasMaxBlockSizeInMB,
+    Param,
+    Params,
+    TypeConverters,
+    HasElasticNetParam,
+    HasSeed,
+    HasStepSize,
+    HasSolver,
+    HasParallelism,
+)
+from pyspark.ml.tree import (
+    _DecisionTreeModel,
+    _DecisionTreeParams,
+    _TreeEnsembleModel,
+    _RandomForestParams,
+    _GBTParams,
+    _HasVarianceImpurity,
+    _TreeClassifierParams,
+)
 from pyspark.ml.regression import _FactorizationMachinesParams, DecisionTreeRegressionModel
 from pyspark.ml.base import _PredictorParams
-from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, \
-    JavaMLReadable, JavaMLReader, JavaMLWritable, JavaMLWriter, \
-    MLReader, MLReadable, MLWriter, MLWritable, HasTrainingSummary
-from pyspark.ml.wrapper import JavaParams, \
-    JavaPredictor, JavaPredictionModel, JavaWrapper
+from pyspark.ml.util import (
+    DefaultParamsReader,
+    DefaultParamsWriter,
+    JavaMLReadable,
+    JavaMLReader,
+    JavaMLWritable,
+    JavaMLWriter,
+    MLReader,
+    MLReadable,
+    MLWriter,
+    MLWritable,
+    HasTrainingSummary,
+)
+from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
 from pyspark.ml.common import inherit_doc
 from pyspark.ml.linalg import Vectors, VectorUDT
 from pyspark.sql import DataFrame
@@ -46,24 +80,40 @@ from pyspark.sql.functions import udf, when
 from pyspark.sql.types import ArrayType, DoubleType
 from pyspark.storagelevel import StorageLevel
 
-__all__ = ['LinearSVC', 'LinearSVCModel',
-           'LinearSVCSummary', 'LinearSVCTrainingSummary',
-           'LogisticRegression', 'LogisticRegressionModel',
-           'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary',
-           'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary',
-           'DecisionTreeClassifier', 'DecisionTreeClassificationModel',
-           'GBTClassifier', 'GBTClassificationModel',
-           'RandomForestClassifier', 'RandomForestClassificationModel',
-           'RandomForestClassificationSummary', 'RandomForestClassificationTrainingSummary',
-           'BinaryRandomForestClassificationSummary',
-           'BinaryRandomForestClassificationTrainingSummary',
-           'NaiveBayes', 'NaiveBayesModel',
-           'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel',
-           'MultilayerPerceptronClassificationSummary',
-           'MultilayerPerceptronClassificationTrainingSummary',
-           'OneVsRest', 'OneVsRestModel',
-           'FMClassifier', 'FMClassificationModel', 'FMClassificationSummary',
-           'FMClassificationTrainingSummary']
+__all__ = [
+    "LinearSVC",
+    "LinearSVCModel",
+    "LinearSVCSummary",
+    "LinearSVCTrainingSummary",
+    "LogisticRegression",
+    "LogisticRegressionModel",
+    "LogisticRegressionSummary",
+    "LogisticRegressionTrainingSummary",
+    "BinaryLogisticRegressionSummary",
+    "BinaryLogisticRegressionTrainingSummary",
+    "DecisionTreeClassifier",
+    "DecisionTreeClassificationModel",
+    "GBTClassifier",
+    "GBTClassificationModel",
+    "RandomForestClassifier",
+    "RandomForestClassificationModel",
+    "RandomForestClassificationSummary",
+    "RandomForestClassificationTrainingSummary",
+    "BinaryRandomForestClassificationSummary",
+    "BinaryRandomForestClassificationTrainingSummary",
+    "NaiveBayes",
+    "NaiveBayesModel",
+    "MultilayerPerceptronClassifier",
+    "MultilayerPerceptronClassificationModel",
+    "MultilayerPerceptronClassificationSummary",
+    "MultilayerPerceptronClassificationTrainingSummary",
+    "OneVsRest",
+    "OneVsRestModel",
+    "FMClassifier",
+    "FMClassificationModel",
+    "FMClassificationSummary",
+    "FMClassificationTrainingSummary",
+]
 
 
 class _ClassifierParams(HasRawPredictionCol, _PredictorParams):
@@ -72,6 +122,7 @@ class _ClassifierParams(HasRawPredictionCol, _PredictorParams):
 
     .. versionadded:: 3.0.0
     """
+
     pass
 
 
@@ -128,12 +179,12 @@ class _ProbabilisticClassifierParams(HasProbabilityCol, HasThresholds, _Classifi
 
     .. versionadded:: 3.0.0
     """
+
     pass
 
 
 @inherit_doc
-class ProbabilisticClassifier(Classifier, _ProbabilisticClassifierParams,
-                              metaclass=ABCMeta):
+class ProbabilisticClassifier(Classifier, _ProbabilisticClassifierParams, metaclass=ABCMeta):
     """
     Probabilistic Classifier for classification tasks.
     """
@@ -154,9 +205,9 @@ class ProbabilisticClassifier(Classifier, _ProbabilisticClassifierParams,
 
 
 @inherit_doc
-class ProbabilisticClassificationModel(ClassificationModel,
-                                       _ProbabilisticClassifierParams,
-                                       metaclass=ABCMeta):
+class ProbabilisticClassificationModel(
+    ClassificationModel, _ProbabilisticClassifierParams, metaclass=ABCMeta
+):
     """
     Model produced by a ``ProbabilisticClassifier``.
     """
@@ -224,17 +275,18 @@ class _JavaClassificationModel(ClassificationModel, JavaPredictionModel):
 
 
 @inherit_doc
-class _JavaProbabilisticClassifier(ProbabilisticClassifier, _JavaClassifier,
-                                   metaclass=ABCMeta):
+class _JavaProbabilisticClassifier(ProbabilisticClassifier, _JavaClassifier, metaclass=ABCMeta):
     """
     Java Probabilistic Classifier for classification tasks.
     """
+
     pass
 
 
 @inherit_doc
-class _JavaProbabilisticClassificationModel(ProbabilisticClassificationModel,
-                                            _JavaClassificationModel):
+class _JavaProbabilisticClassificationModel(
+    ProbabilisticClassificationModel, _JavaClassificationModel
+):
     """
     Java Model produced by a ``ProbabilisticClassifier``.
     """
@@ -505,26 +557,45 @@ class _BinaryClassificationSummary(_ClassificationSummary):
         return self._call_java("recallByThreshold")
 
 
-class _LinearSVCParams(_ClassifierParams, HasRegParam, HasMaxIter, HasFitIntercept, HasTol,
-                       HasStandardization, HasWeightCol, HasAggregationDepth, HasThreshold,
-                       HasMaxBlockSizeInMB):
+class _LinearSVCParams(
+    _ClassifierParams,
+    HasRegParam,
+    HasMaxIter,
+    HasFitIntercept,
+    HasTol,
+    HasStandardization,
+    HasWeightCol,
+    HasAggregationDepth,
+    HasThreshold,
+    HasMaxBlockSizeInMB,
+):
     """
     Params for :py:class:`LinearSVC` and :py:class:`LinearSVCModel`.
 
     .. versionadded:: 3.0.0
     """
 
-    threshold = Param(Params._dummy(), "threshold",
-                      "The threshold in binary classification applied to the linear model"
-                      " prediction.  This threshold can be any real number, where Inf will make"
-                      " all predictions 0.0 and -Inf will make all predictions 1.0.",
-                      typeConverter=TypeConverters.toFloat)
+    threshold = Param(
+        Params._dummy(),
+        "threshold",
+        "The threshold in binary classification applied to the linear model"
+        " prediction.  This threshold can be any real number, where Inf will make"
+        " all predictions 0.0 and -Inf will make all predictions 1.0.",
+        typeConverter=TypeConverters.toFloat,
+    )
 
     def __init__(self, *args):
         super(_LinearSVCParams, self).__init__(*args)
-        self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, fitIntercept=True,
-                         standardization=True, threshold=0.0, aggregationDepth=2,
-                         maxBlockSizeInMB=0.0)
+        self._setDefault(
+            maxIter=100,
+            regParam=0.0,
+            tol=1e-6,
+            fitIntercept=True,
+            standardization=True,
+            threshold=0.0,
+            aggregationDepth=2,
+            maxBlockSizeInMB=0.0,
+        )
 
 
 @inherit_doc
@@ -605,10 +676,23 @@ class LinearSVC(_JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadabl
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
-                 fitIntercept=True, standardization=True, threshold=0.0, weightCol=None,
-                 aggregationDepth=2, maxBlockSizeInMB=0.0):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        maxIter=100,
+        regParam=0.0,
+        tol=1e-6,
+        rawPredictionCol="rawPrediction",
+        fitIntercept=True,
+        standardization=True,
+        threshold=0.0,
+        weightCol=None,
+        aggregationDepth=2,
+        maxBlockSizeInMB=0.0,
+    ):
         """
         __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
@@ -617,16 +701,30 @@ class LinearSVC(_JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadabl
         """
         super(LinearSVC, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.classification.LinearSVC", self.uid)
+            "org.apache.spark.ml.classification.LinearSVC", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("2.2.0")
-    def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
-                  fitIntercept=True, standardization=True, threshold=0.0, weightCol=None,
-                  aggregationDepth=2, maxBlockSizeInMB=0.0):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        maxIter=100,
+        regParam=0.0,
+        tol=1e-6,
+        rawPredictionCol="rawPrediction",
+        fitIntercept=True,
+        standardization=True,
+        threshold=0.0,
+        weightCol=None,
+        aggregationDepth=2,
+        maxBlockSizeInMB=0.0,
+    ):
         """
         setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
@@ -704,8 +802,9 @@ class LinearSVC(_JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadabl
         return self._set(maxBlockSizeInMB=value)
 
 
-class LinearSVCModel(_JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable,
-                     HasTrainingSummary):
+class LinearSVCModel(
+    _JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable, HasTrainingSummary
+):
     """
     Model fitted by LinearSVC.
 
@@ -744,8 +843,9 @@ class LinearSVCModel(_JavaClassificationModel, _LinearSVCParams, JavaMLWritable,
         if self.hasSummary:
             return LinearSVCTrainingSummary(super(LinearSVCModel, self).summary)
         else:
-            raise RuntimeError("No training summary available for this %s" %
-                               self.__class__.__name__)
+            raise RuntimeError(
+                "No training summary available for this %s" % self.__class__.__name__
+            )
 
     def evaluate(self, dataset):
         """
@@ -770,6 +870,7 @@ class LinearSVCSummary(_BinaryClassificationSummary):
 
     .. versionadded:: 3.1.0
     """
+
     pass
 
 
@@ -780,66 +881,95 @@ class LinearSVCTrainingSummary(LinearSVCSummary, _TrainingSummary):
 
     .. versionadded:: 3.1.0
     """
+
     pass
 
 
-class _LogisticRegressionParams(_ProbabilisticClassifierParams, HasRegParam,
-                                HasElasticNetParam, HasMaxIter, HasFitIntercept, HasTol,
-                                HasStandardization, HasWeightCol, HasAggregationDepth,
-                                HasThreshold, HasMaxBlockSizeInMB):
+class _LogisticRegressionParams(
+    _ProbabilisticClassifierParams,
+    HasRegParam,
+    HasElasticNetParam,
+    HasMaxIter,
+    HasFitIntercept,
+    HasTol,
+    HasStandardization,
+    HasWeightCol,
+    HasAggregationDepth,
+    HasThreshold,
+    HasMaxBlockSizeInMB,
+):
     """
     Params for :py:class:`LogisticRegression` and :py:class:`LogisticRegressionModel`.
 
     .. versionadded:: 3.0.0
     """
 
-    threshold = Param(Params._dummy(), "threshold",
-                      "Threshold in binary classification prediction, in range [0, 1]." +
-                      " If threshold and thresholds are both set, they must match." +
-                      "e.g. if threshold is p, then thresholds must be equal to [1-p, p].",
-                      typeConverter=TypeConverters.toFloat)
-
-    family = Param(Params._dummy(), "family",
-                   "The name of family which is a description of the label distribution to " +
-                   "be used in the model. Supported options: auto, binomial, multinomial",
-                   typeConverter=TypeConverters.toString)
-
-    lowerBoundsOnCoefficients = Param(Params._dummy(), "lowerBoundsOnCoefficients",
-                                      "The lower bounds on coefficients if fitting under bound "
-                                      "constrained optimization. The bound matrix must be "
-                                      "compatible with the shape "
-                                      "(1, number of features) for binomial regression, or "
-                                      "(number of classes, number of features) "
-                                      "for multinomial regression.",
-                                      typeConverter=TypeConverters.toMatrix)
-
-    upperBoundsOnCoefficients = Param(Params._dummy(), "upperBoundsOnCoefficients",
-                                      "The upper bounds on coefficients if fitting under bound "
-                                      "constrained optimization. The bound matrix must be "
-                                      "compatible with the shape "
-                                      "(1, number of features) for binomial regression, or "
-                                      "(number of classes, number of features) "
-                                      "for multinomial regression.",
-                                      typeConverter=TypeConverters.toMatrix)
-
-    lowerBoundsOnIntercepts = Param(Params._dummy(), "lowerBoundsOnIntercepts",
-                                    "The lower bounds on intercepts if fitting under bound "
-                                    "constrained optimization. The bounds vector size must be"
-                                    "equal with 1 for binomial regression, or the number of"
-                                    "lasses for multinomial regression.",
-                                    typeConverter=TypeConverters.toVector)
-
-    upperBoundsOnIntercepts = Param(Params._dummy(), "upperBoundsOnIntercepts",
-                                    "The upper bounds on intercepts if fitting under bound "
-                                    "constrained optimization. The bound vector size must be "
-                                    "equal with 1 for binomial regression, or the number of "
-                                    "classes for multinomial regression.",
-                                    typeConverter=TypeConverters.toVector)
+    threshold = Param(
+        Params._dummy(),
+        "threshold",
+        "Threshold in binary classification prediction, in range [0, 1]."
+        + " If threshold and thresholds are both set, they must match."
+        + " e.g. if threshold is p, then thresholds must be equal to [1-p, p].",
+        typeConverter=TypeConverters.toFloat,
+    )
+
+    family = Param(
+        Params._dummy(),
+        "family",
+        "The name of family which is a description of the label distribution to "
+        + "be used in the model. Supported options: auto, binomial, multinomial",
+        typeConverter=TypeConverters.toString,
+    )
+
+    lowerBoundsOnCoefficients = Param(
+        Params._dummy(),
+        "lowerBoundsOnCoefficients",
+        "The lower bounds on coefficients if fitting under bound "
+        "constrained optimization. The bound matrix must be "
+        "compatible with the shape "
+        "(1, number of features) for binomial regression, or "
+        "(number of classes, number of features) "
+        "for multinomial regression.",
+        typeConverter=TypeConverters.toMatrix,
+    )
+
+    upperBoundsOnCoefficients = Param(
+        Params._dummy(),
+        "upperBoundsOnCoefficients",
+        "The upper bounds on coefficients if fitting under bound "
+        "constrained optimization. The bound matrix must be "
+        "compatible with the shape "
+        "(1, number of features) for binomial regression, or "
+        "(number of classes, number of features) "
+        "for multinomial regression.",
+        typeConverter=TypeConverters.toMatrix,
+    )
+
+    lowerBoundsOnIntercepts = Param(
+        Params._dummy(),
+        "lowerBoundsOnIntercepts",
+        "The lower bounds on intercepts if fitting under bound "
+        "constrained optimization. The bounds vector size must be "
+        "equal with 1 for binomial regression, or the number of "
+        "classes for multinomial regression.",
+        typeConverter=TypeConverters.toVector,
+    )
+
+    upperBoundsOnIntercepts = Param(
+        Params._dummy(),
+        "upperBoundsOnIntercepts",
+        "The upper bounds on intercepts if fitting under bound "
+        "constrained optimization. The bound vector size must be "
+        "equal with 1 for binomial regression, or the number of "
+        "classes for multinomial regression.",
+        typeConverter=TypeConverters.toVector,
+    )
 
     def __init__(self, *args):
         super(_LogisticRegressionParams, self).__init__(*args)
-        self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5, family="auto",
-                         maxBlockSizeInMB=0.0)
+        self._setDefault(
+            maxIter=100, regParam=0.0, tol=1e-6, threshold=0.5, family="auto", maxBlockSizeInMB=0.0
+        )
 
     @since("1.4.0")
     def setThreshold(self, value):
@@ -865,10 +995,13 @@ class _LogisticRegressionParams(_ProbabilisticClassifierParams, HasRegParam,
         if self.isSet(self.thresholds):
             ts = self.getOrDefault(self.thresholds)
             if len(ts) != 2:
-                raise ValueError("Logistic Regression getThreshold only applies to" +
-                                 " binary classification, but thresholds has length != 2." +
-                                 "  thresholds: " + ",".join(ts))
-            return 1.0/(1.0 + ts[0]/ts[1])
+                raise ValueError(
+                    "Logistic Regression getThreshold only applies to"
+                    + " binary classification, but thresholds has length != 2."
+                    + "  thresholds: "
+                    + ",".join(ts)
+                )
+            return 1.0 / (1.0 + ts[0] / ts[1])
         else:
             return self.getOrDefault(self.threshold)
 
@@ -893,7 +1026,7 @@ class _LogisticRegressionParams(_ProbabilisticClassifierParams, HasRegParam,
         self._checkThresholdConsistency()
         if not self.isSet(self.thresholds) and self.isSet(self.threshold):
             t = self.getOrDefault(self.threshold)
-            return [1.0-t, t]
+            return [1.0 - t, t]
         else:
             return self.getOrDefault(self.thresholds)
 
@@ -901,14 +1034,18 @@ class _LogisticRegressionParams(_ProbabilisticClassifierParams, HasRegParam,
         if self.isSet(self.threshold) and self.isSet(self.thresholds):
             ts = self.getOrDefault(self.thresholds)
             if len(ts) != 2:
-                raise ValueError("Logistic Regression getThreshold only applies to" +
-                                 " binary classification, but thresholds has length != 2." +
-                                 " thresholds: {0}".format(str(ts)))
-            t = 1.0/(1.0 + ts[0]/ts[1])
+                raise ValueError(
+                    "Logistic Regression getThreshold only applies to"
+                    + " binary classification, but thresholds has length != 2."
+                    + " thresholds: {0}".format(str(ts))
+                )
+            t = 1.0 / (1.0 + ts[0] / ts[1])
             t2 = self.getOrDefault(self.threshold)
-            if abs(t2 - t) >= 1E-5:
-                raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
-                                 " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
+            if abs(t2 - t) >= 1e-5:
+                raise ValueError(
+                    "Logistic Regression getThreshold found inconsistent values for"
+                    + " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)
+                )
 
     @since("2.1.0")
     def getFamily(self):
@@ -947,8 +1084,9 @@ class _LogisticRegressionParams(_ProbabilisticClassifierParams, HasRegParam,
 
 
 @inherit_doc
-class LogisticRegression(_JavaProbabilisticClassifier, _LogisticRegressionParams, JavaMLWritable,
-                         JavaMLReadable):
+class LogisticRegression(
+    _JavaProbabilisticClassifier, _LogisticRegressionParams, JavaMLWritable, JavaMLReadable
+):
     """
     Logistic regression.
     This class supports multinomial logistic (softmax) and binomial logistic regression.
@@ -1043,14 +1181,31 @@ class LogisticRegression(_JavaProbabilisticClassifier, _LogisticRegressionParams
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
-                 threshold=0.5, thresholds=None, probabilityCol="probability",
-                 rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
-                 aggregationDepth=2, family="auto",
-                 lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None,
-                 lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None,
-                 maxBlockSizeInMB=0.0):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        maxIter=100,
+        regParam=0.0,
+        elasticNetParam=0.0,
+        tol=1e-6,
+        fitIntercept=True,
+        threshold=0.5,
+        thresholds=None,
+        probabilityCol="probability",
+        rawPredictionCol="rawPrediction",
+        standardization=True,
+        weightCol=None,
+        aggregationDepth=2,
+        family="auto",
+        lowerBoundsOnCoefficients=None,
+        upperBoundsOnCoefficients=None,
+        lowerBoundsOnIntercepts=None,
+        upperBoundsOnIntercepts=None,
+        maxBlockSizeInMB=0.0,
+    ):
 
         """
         __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
@@ -1065,21 +1220,39 @@ class LogisticRegression(_JavaProbabilisticClassifier, _LogisticRegressionParams
         """
         super(LogisticRegression, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.classification.LogisticRegression", self.uid)
+            "org.apache.spark.ml.classification.LogisticRegression", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
         self._checkThresholdConsistency()
 
     @keyword_only
     @since("1.3.0")
-    def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
-                  threshold=0.5, thresholds=None, probabilityCol="probability",
-                  rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
-                  aggregationDepth=2, family="auto",
-                  lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None,
-                  lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None,
-                  maxBlockSizeInMB=0.0):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        maxIter=100,
+        regParam=0.0,
+        elasticNetParam=0.0,
+        tol=1e-6,
+        fitIntercept=True,
+        threshold=0.5,
+        thresholds=None,
+        probabilityCol="probability",
+        rawPredictionCol="rawPrediction",
+        standardization=True,
+        weightCol=None,
+        aggregationDepth=2,
+        family="auto",
+        lowerBoundsOnCoefficients=None,
+        upperBoundsOnCoefficients=None,
+        lowerBoundsOnIntercepts=None,
+        upperBoundsOnIntercepts=None,
+        maxBlockSizeInMB=0.0,
+    ):
         """
         setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
@@ -1191,8 +1364,13 @@ class LogisticRegression(_JavaProbabilisticClassifier, _LogisticRegressionParams
         return self._set(maxBlockSizeInMB=value)
 
 
-class LogisticRegressionModel(_JavaProbabilisticClassificationModel, _LogisticRegressionParams,
-                              JavaMLWritable, JavaMLReadable, HasTrainingSummary):
+class LogisticRegressionModel(
+    _JavaProbabilisticClassificationModel,
+    _LogisticRegressionParams,
+    JavaMLWritable,
+    JavaMLReadable,
+    HasTrainingSummary,
+):
     """
     Model fitted by LogisticRegression.
 
@@ -1242,14 +1420,17 @@ class LogisticRegressionModel(_JavaProbabilisticClassificationModel, _LogisticRe
         """
         if self.hasSummary:
             if self.numClasses <= 2:
-                return BinaryLogisticRegressionTrainingSummary(super(LogisticRegressionModel,
-                                                                     self).summary)
+                return BinaryLogisticRegressionTrainingSummary(
+                    super(LogisticRegressionModel, self).summary
+                )
             else:
-                return LogisticRegressionTrainingSummary(super(LogisticRegressionModel,
-                                                               self).summary)
+                return LogisticRegressionTrainingSummary(
+                    super(LogisticRegressionModel, self).summary
+                )
         else:
-            raise RuntimeError("No training summary available for this %s" %
-                               self.__class__.__name__)
+            raise RuntimeError(
+                "No training summary available for this %s" % self.__class__.__name__
+            )
 
     def evaluate(self, dataset):
         """
@@ -1304,28 +1485,31 @@ class LogisticRegressionTrainingSummary(LogisticRegressionSummary, _TrainingSumm
 
     .. versionadded:: 2.0.0
     """
+
     pass
 
 
 @inherit_doc
-class BinaryLogisticRegressionSummary(_BinaryClassificationSummary,
-                                      LogisticRegressionSummary):
+class BinaryLogisticRegressionSummary(_BinaryClassificationSummary, LogisticRegressionSummary):
     """
     Binary Logistic regression results for a given model.
 
     .. versionadded:: 2.0.0
     """
+
     pass
 
 
 @inherit_doc
-class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary,
-                                              LogisticRegressionTrainingSummary):
+class BinaryLogisticRegressionTrainingSummary(
+    BinaryLogisticRegressionSummary, LogisticRegressionTrainingSummary
+):
     """
     Binary Logistic regression training results for a given model.
 
     .. versionadded:: 2.0.0
     """
+
     pass
 
 
@@ -1337,14 +1521,24 @@ class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams):
 
     def __init__(self, *args):
         super(_DecisionTreeClassifierParams, self).__init__(*args)
-        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
-                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
-                         impurity="gini", leafCol="", minWeightFractionPerNode=0.0)
+        self._setDefault(
+            maxDepth=5,
+            maxBins=32,
+            minInstancesPerNode=1,
+            minInfoGain=0.0,
+            maxMemoryInMB=256,
+            cacheNodeIds=False,
+            checkpointInterval=10,
+            impurity="gini",
+            leafCol="",
+            minWeightFractionPerNode=0.0,
+        )
 
 
 @inherit_doc
-class DecisionTreeClassifier(_JavaProbabilisticClassifier, _DecisionTreeClassifierParams,
-                             JavaMLWritable, JavaMLReadable):
+class DecisionTreeClassifier(
+    _JavaProbabilisticClassifier, _DecisionTreeClassifierParams, JavaMLWritable, JavaMLReadable
+):
     """
     `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
     learning algorithm for classification.
@@ -1426,11 +1620,27 @@ class DecisionTreeClassifier(_JavaProbabilisticClassifier, _DecisionTreeClassifi
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 probabilityCol="probability", rawPredictionCol="rawPrediction",
-                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
-                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
-                 seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        probabilityCol="probability",
+        rawPredictionCol="rawPrediction",
+        maxDepth=5,
+        maxBins=32,
+        minInstancesPerNode=1,
+        minInfoGain=0.0,
+        maxMemoryInMB=256,
+        cacheNodeIds=False,
+        checkpointInterval=10,
+        impurity="gini",
+        seed=None,
+        weightCol=None,
+        leafCol="",
+        minWeightFractionPerNode=0.0,
+    ):
         """
         __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
@@ -1440,18 +1650,34 @@ class DecisionTreeClassifier(_JavaProbabilisticClassifier, _DecisionTreeClassifi
         """
         super(DecisionTreeClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
+            "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  probabilityCol="probability", rawPredictionCol="rawPrediction",
-                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
-                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
-                  impurity="gini", seed=None, weightCol=None, leafCol="",
-                  minWeightFractionPerNode=0.0):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        probabilityCol="probability",
+        rawPredictionCol="rawPrediction",
+        maxDepth=5,
+        maxBins=32,
+        minInstancesPerNode=1,
+        minInfoGain=0.0,
+        maxMemoryInMB=256,
+        cacheNodeIds=False,
+        checkpointInterval=10,
+        impurity="gini",
+        seed=None,
+        weightCol=None,
+        leafCol="",
+        minWeightFractionPerNode=0.0,
+    ):
         """
         setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   probabilityCol="probability", rawPredictionCol="rawPrediction", \
@@ -1538,9 +1764,13 @@ class DecisionTreeClassifier(_JavaProbabilisticClassifier, _DecisionTreeClassifi
 
 
 @inherit_doc
-class DecisionTreeClassificationModel(_DecisionTreeModel, _JavaProbabilisticClassificationModel,
-                                      _DecisionTreeClassifierParams, JavaMLWritable,
-                                      JavaMLReadable):
+class DecisionTreeClassificationModel(
+    _DecisionTreeModel,
+    _JavaProbabilisticClassificationModel,
+    _DecisionTreeClassifierParams,
+    JavaMLWritable,
+    JavaMLReadable,
+):
     """
     Model fitted by DecisionTreeClassifier.
 
@@ -1580,16 +1810,28 @@ class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams):
 
     def __init__(self, *args):
         super(_RandomForestClassifierParams, self).__init__(*args)
-        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
-                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
-                         impurity="gini", numTrees=20, featureSubsetStrategy="auto",
-                         subsamplingRate=1.0, leafCol="", minWeightFractionPerNode=0.0,
-                         bootstrap=True)
+        self._setDefault(
+            maxDepth=5,
+            maxBins=32,
+            minInstancesPerNode=1,
+            minInfoGain=0.0,
+            maxMemoryInMB=256,
+            cacheNodeIds=False,
+            checkpointInterval=10,
+            impurity="gini",
+            numTrees=20,
+            featureSubsetStrategy="auto",
+            subsamplingRate=1.0,
+            leafCol="",
+            minWeightFractionPerNode=0.0,
+            bootstrap=True,
+        )
 
 
 @inherit_doc
-class RandomForestClassifier(_JavaProbabilisticClassifier, _RandomForestClassifierParams,
-                             JavaMLWritable, JavaMLReadable):
+class RandomForestClassifier(
+    _JavaProbabilisticClassifier, _RandomForestClassifierParams, JavaMLWritable, JavaMLReadable
+):
     """
     `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
     learning algorithm for classification.
@@ -1665,12 +1907,31 @@ class RandomForestClassifier(_JavaProbabilisticClassifier, _RandomForestClassifi
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 probabilityCol="probability", rawPredictionCol="rawPrediction",
-                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
-                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
-                 numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0,
-                 leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        probabilityCol="probability",
+        rawPredictionCol="rawPrediction",
+        maxDepth=5,
+        maxBins=32,
+        minInstancesPerNode=1,
+        minInfoGain=0.0,
+        maxMemoryInMB=256,
+        cacheNodeIds=False,
+        checkpointInterval=10,
+        impurity="gini",
+        numTrees=20,
+        featureSubsetStrategy="auto",
+        seed=None,
+        subsamplingRate=1.0,
+        leafCol="",
+        minWeightFractionPerNode=0.0,
+        weightCol=None,
+        bootstrap=True,
+    ):
         """
         __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
@@ -1681,18 +1942,38 @@ class RandomForestClassifier(_JavaProbabilisticClassifier, _RandomForestClassifi
         """
         super(RandomForestClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.classification.RandomForestClassifier", self.uid)
+            "org.apache.spark.ml.classification.RandomForestClassifier", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  probabilityCol="probability", rawPredictionCol="rawPrediction",
-                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
-                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
-                  impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0,
-                  leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        probabilityCol="probability",
+        rawPredictionCol="rawPrediction",
+        maxDepth=5,
+        maxBins=32,
+        minInstancesPerNode=1,
+        minInfoGain=0.0,
+        maxMemoryInMB=256,
+        cacheNodeIds=False,
+        checkpointInterval=10,
+        seed=None,
+        impurity="gini",
+        numTrees=20,
+        featureSubsetStrategy="auto",
+        subsamplingRate=1.0,
+        leafCol="",
+        minWeightFractionPerNode=0.0,
+        weightCol=None,
+        bootstrap=True,
+    ):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
@@ -1806,9 +2087,14 @@ class RandomForestClassifier(_JavaProbabilisticClassifier, _RandomForestClassifi
         return self._set(minWeightFractionPerNode=value)
 
 
-class RandomForestClassificationModel(_TreeEnsembleModel, _JavaProbabilisticClassificationModel,
-                                      _RandomForestClassifierParams, JavaMLWritable,
-                                      JavaMLReadable, HasTrainingSummary):
+class RandomForestClassificationModel(
+    _TreeEnsembleModel,
+    _JavaProbabilisticClassificationModel,
+    _RandomForestClassifierParams,
+    JavaMLWritable,
+    JavaMLReadable,
+    HasTrainingSummary,
+):
     """
     Model fitted by RandomForestClassifier.
 
@@ -1849,13 +2135,16 @@ class RandomForestClassificationModel(_TreeEnsembleModel, _JavaProbabilisticClas
         if self.hasSummary:
             if self.numClasses <= 2:
                 return BinaryRandomForestClassificationTrainingSummary(
-                    super(RandomForestClassificationModel, self).summary)
+                    super(RandomForestClassificationModel, self).summary
+                )
             else:
                 return RandomForestClassificationTrainingSummary(
-                    super(RandomForestClassificationModel, self).summary)
+                    super(RandomForestClassificationModel, self).summary
+                )
         else:
-            raise RuntimeError("No training summary available for this %s" %
-                               self.__class__.__name__)
+            raise RuntimeError(
+                "No training summary available for this %s" % self.__class__.__name__
+            )
 
     def evaluate(self, dataset):
         """
@@ -1883,17 +2172,20 @@ class RandomForestClassificationSummary(_ClassificationSummary):
 
     .. versionadded:: 3.1.0
     """
+
     pass
 
 
 @inherit_doc
-class RandomForestClassificationTrainingSummary(RandomForestClassificationSummary,
-                                                _TrainingSummary):
+class RandomForestClassificationTrainingSummary(
+    RandomForestClassificationSummary, _TrainingSummary
+):
     """
     Abstraction for RandomForestClassification Training results.
 
     .. versionadded:: 3.1.0
     """
+
     pass
 
 
@@ -1904,17 +2196,20 @@ class BinaryRandomForestClassificationSummary(_BinaryClassificationSummary):
 
     .. versionadded:: 3.1.0
     """
+
     pass
 
 
 @inherit_doc
-class BinaryRandomForestClassificationTrainingSummary(BinaryRandomForestClassificationSummary,
-                                                      RandomForestClassificationTrainingSummary):
+class BinaryRandomForestClassificationTrainingSummary(
+    BinaryRandomForestClassificationSummary, RandomForestClassificationTrainingSummary
+):
     """
     BinaryRandomForestClassification training results for a given model.
 
     .. versionadded:: 3.1.0
     """
+
     pass
 
 
@@ -1927,18 +2222,35 @@ class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
 
     supportedLossTypes = ["logistic"]
 
-    lossType = Param(Params._dummy(), "lossType",
-                     "Loss function which GBT tries to minimize (case-insensitive). " +
-                     "Supported options: " + ", ".join(supportedLossTypes),
-                     typeConverter=TypeConverters.toString)
+    lossType = Param(
+        Params._dummy(),
+        "lossType",
+        "Loss function which GBT tries to minimize (case-insensitive). "
+        + "Supported options: "
+        + ", ".join(supportedLossTypes),
+        typeConverter=TypeConverters.toString,
+    )
 
     def __init__(self, *args):
         super(_GBTClassifierParams, self).__init__(*args)
-        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
-                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
-                         lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0,
-                         impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
-                         leafCol="", minWeightFractionPerNode=0.0)
+        self._setDefault(
+            maxDepth=5,
+            maxBins=32,
+            minInstancesPerNode=1,
+            minInfoGain=0.0,
+            maxMemoryInMB=256,
+            cacheNodeIds=False,
+            checkpointInterval=10,
+            lossType="logistic",
+            maxIter=20,
+            stepSize=0.1,
+            subsamplingRate=1.0,
+            impurity="variance",
+            featureSubsetStrategy="all",
+            validationTol=0.01,
+            leafCol="",
+            minWeightFractionPerNode=0.0,
+        )
 
     @since("1.4.0")
     def getLossType(self):
@@ -1949,8 +2261,9 @@ class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
 
 
 @inherit_doc
-class GBTClassifier(_JavaProbabilisticClassifier, _GBTClassifierParams,
-                    JavaMLWritable, JavaMLReadable):
+class GBTClassifier(
+    _JavaProbabilisticClassifier, _GBTClassifierParams, JavaMLWritable, JavaMLReadable
+):
     """
     `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
     learning algorithm for classification.
@@ -2056,12 +2369,32 @@ class GBTClassifier(_JavaProbabilisticClassifier, _GBTClassifierParams,
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
-                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
-                 maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
-                 featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None,
-                 leafCol="", minWeightFractionPerNode=0.0, weightCol=None):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        maxDepth=5,
+        maxBins=32,
+        minInstancesPerNode=1,
+        minInfoGain=0.0,
+        maxMemoryInMB=256,
+        cacheNodeIds=False,
+        checkpointInterval=10,
+        lossType="logistic",
+        maxIter=20,
+        stepSize=0.1,
+        seed=None,
+        subsamplingRate=1.0,
+        impurity="variance",
+        featureSubsetStrategy="all",
+        validationTol=0.01,
+        validationIndicatorCol=None,
+        leafCol="",
+        minWeightFractionPerNode=0.0,
+        weightCol=None,
+    ):
         """
         __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
@@ -2073,19 +2406,39 @@ class GBTClassifier(_JavaProbabilisticClassifier, _GBTClassifierParams,
         """
         super(GBTClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.classification.GBTClassifier", self.uid)
+            "org.apache.spark.ml.classification.GBTClassifier", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
-                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
-                  lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
-                  impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
-                  validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0,
-                  weightCol=None):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        maxDepth=5,
+        maxBins=32,
+        minInstancesPerNode=1,
+        minInfoGain=0.0,
+        maxMemoryInMB=256,
+        cacheNodeIds=False,
+        checkpointInterval=10,
+        lossType="logistic",
+        maxIter=20,
+        stepSize=0.1,
+        seed=None,
+        subsamplingRate=1.0,
+        impurity="variance",
+        featureSubsetStrategy="all",
+        validationTol=0.01,
+        validationIndicatorCol=None,
+        leafCol="",
+        minWeightFractionPerNode=0.0,
+        weightCol=None,
+    ):
         """
         setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
@@ -2216,8 +2569,13 @@ class GBTClassifier(_JavaProbabilisticClassifier, _GBTClassifierParams,
         return self._set(minWeightFractionPerNode=value)
 
 
-class GBTClassificationModel(_TreeEnsembleModel, _JavaProbabilisticClassificationModel,
-                             _GBTClassifierParams, JavaMLWritable, JavaMLReadable):
+class GBTClassificationModel(
+    _TreeEnsembleModel,
+    _JavaProbabilisticClassificationModel,
+    _GBTClassifierParams,
+    JavaMLWritable,
+    JavaMLReadable,
+):
     """
     Model fitted by GBTClassifier.
 
@@ -2269,12 +2627,20 @@ class _NaiveBayesParams(_PredictorParams, HasWeightCol):
     .. versionadded:: 3.0.0
     """
 
-    smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " +
-                      "default is 1.0", typeConverter=TypeConverters.toFloat)
-    modelType = Param(Params._dummy(), "modelType", "The model type which is a string " +
-                      "(case-sensitive). Supported options: multinomial (default), bernoulli " +
-                      "and gaussian.",
-                      typeConverter=TypeConverters.toString)
+    smoothing = Param(
+        Params._dummy(),
+        "smoothing",
+        "The smoothing parameter, should be >= 0, " + "default is 1.0",
+        typeConverter=TypeConverters.toFloat,
+    )
+    modelType = Param(
+        Params._dummy(),
+        "modelType",
+        "The model type which is a string "
+        + "(case-sensitive). Supported options: multinomial (default), bernoulli "
+        + "and gaussian.",
+        typeConverter=TypeConverters.toString,
+    )
 
     def __init__(self, *args):
         super(_NaiveBayesParams, self).__init__(*args)
@@ -2296,8 +2662,14 @@ class _NaiveBayesParams(_PredictorParams, HasWeightCol):
 
 
 @inherit_doc
-class NaiveBayes(_JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds, HasWeightCol,
-                 JavaMLWritable, JavaMLReadable):
+class NaiveBayes(
+    _JavaProbabilisticClassifier,
+    _NaiveBayesParams,
+    HasThresholds,
+    HasWeightCol,
+    JavaMLWritable,
+    JavaMLReadable,
+):
     """
     Naive Bayes Classifiers.
     It supports both Multinomial and Bernoulli NB. `Multinomial NB \
@@ -2392,9 +2764,19 @@ class NaiveBayes(_JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
-                 modelType="multinomial", thresholds=None, weightCol=None):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        probabilityCol="probability",
+        rawPredictionCol="rawPrediction",
+        smoothing=1.0,
+        modelType="multinomial",
+        thresholds=None,
+        weightCol=None,
+    ):
         """
         __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
@@ -2402,15 +2784,26 @@ class NaiveBayes(_JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
         """
         super(NaiveBayes, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.classification.NaiveBayes", self.uid)
+            "org.apache.spark.ml.classification.NaiveBayes", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.5.0")
-    def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
-                  modelType="multinomial", thresholds=None, weightCol=None):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        probabilityCol="probability",
+        rawPredictionCol="rawPrediction",
+        smoothing=1.0,
+        modelType="multinomial",
+        thresholds=None,
+        weightCol=None,
+    ):
         """
         setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
@@ -2444,8 +2837,9 @@ class NaiveBayes(_JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
         return self._set(weightCol=value)
 
 
-class NaiveBayesModel(_JavaProbabilisticClassificationModel, _NaiveBayesParams, JavaMLWritable,
-                      JavaMLReadable):
+class NaiveBayesModel(
+    _JavaProbabilisticClassificationModel, _NaiveBayesParams, JavaMLWritable, JavaMLReadable
+):
     """
     Model fitted by NaiveBayes.
 
@@ -2477,26 +2871,45 @@ class NaiveBayesModel(_JavaProbabilisticClassificationModel, _NaiveBayesParams,
         return self._call_java("sigma")
 
 
-class _MultilayerPerceptronParams(_ProbabilisticClassifierParams, HasSeed, HasMaxIter,
-                                  HasTol, HasStepSize, HasSolver, HasBlockSize):
+class _MultilayerPerceptronParams(
+    _ProbabilisticClassifierParams,
+    HasSeed,
+    HasMaxIter,
+    HasTol,
+    HasStepSize,
+    HasSolver,
+    HasBlockSize,
+):
     """
     Params for :py:class:`MultilayerPerceptronClassifier`.
 
     .. versionadded:: 3.0.0
     """
 
-    layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " +
-                   "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " +
-                   "neurons and output layer of 10 neurons.",
-                   typeConverter=TypeConverters.toListInt)
-    solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
-                   "options: l-bfgs, gd.", typeConverter=TypeConverters.toString)
-    initialWeights = Param(Params._dummy(), "initialWeights", "The initial weights of the model.",
-                           typeConverter=TypeConverters.toVector)
+    layers = Param(
+        Params._dummy(),
+        "layers",
+        "Sizes of layers from input layer to output layer "
+        + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 "
+        + "neurons and output layer of 10 neurons.",
+        typeConverter=TypeConverters.toListInt,
+    )
+    solver = Param(
+        Params._dummy(),
+        "solver",
+        "The solver algorithm for optimization. Supported " + "options: l-bfgs, gd.",
+        typeConverter=TypeConverters.toString,
+    )
+    initialWeights = Param(
+        Params._dummy(),
+        "initialWeights",
+        "The initial weights of the model.",
+        typeConverter=TypeConverters.toVector,
+    )
 
     def __init__(self, *args):
         super(_MultilayerPerceptronParams, self).__init__(*args)
-        self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
+        self._setDefault(maxIter=100, tol=1e-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
 
     @since("1.6.0")
     def getLayers(self):
@@ -2514,8 +2927,9 @@ class _MultilayerPerceptronParams(_ProbabilisticClassifierParams, HasSeed, HasMa
 
 
 @inherit_doc
-class MultilayerPerceptronClassifier(_JavaProbabilisticClassifier, _MultilayerPerceptronParams,
-                                     JavaMLWritable, JavaMLReadable):
+class MultilayerPerceptronClassifier(
+    _JavaProbabilisticClassifier, _MultilayerPerceptronParams, JavaMLWritable, JavaMLReadable
+):
     """
     Classifier trainer based on the Multilayer Perceptron.
     Each layer has sigmoid activation function, output layer has softmax.
@@ -2592,10 +3006,23 @@ class MultilayerPerceptronClassifier(_JavaProbabilisticClassifier, _MultilayerPe
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
-                 solver="l-bfgs", initialWeights=None, probabilityCol="probability",
-                 rawPredictionCol="rawPrediction"):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        maxIter=100,
+        tol=1e-6,
+        seed=None,
+        layers=None,
+        blockSize=128,
+        stepSize=0.03,
+        solver="l-bfgs",
+        initialWeights=None,
+        probabilityCol="probability",
+        rawPredictionCol="rawPrediction",
+    ):
         """
         __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
@@ -2604,16 +3031,30 @@ class MultilayerPerceptronClassifier(_JavaProbabilisticClassifier, _MultilayerPe
         """
         super(MultilayerPerceptronClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
+            "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.6.0")
-    def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
-                  solver="l-bfgs", initialWeights=None, probabilityCol="probability",
-                  rawPredictionCol="rawPrediction"):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        maxIter=100,
+        tol=1e-6,
+        seed=None,
+        layers=None,
+        blockSize=128,
+        stepSize=0.03,
+        solver="l-bfgs",
+        initialWeights=None,
+        probabilityCol="probability",
+        rawPredictionCol="rawPrediction",
+    ):
         """
         setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
@@ -2680,9 +3121,13 @@ class MultilayerPerceptronClassifier(_JavaProbabilisticClassifier, _MultilayerPe
         return self._set(solver=value)
 
 
-class MultilayerPerceptronClassificationModel(_JavaProbabilisticClassificationModel,
-                                              _MultilayerPerceptronParams, JavaMLWritable,
-                                              JavaMLReadable, HasTrainingSummary):
+class MultilayerPerceptronClassificationModel(
+    _JavaProbabilisticClassificationModel,
+    _MultilayerPerceptronParams,
+    JavaMLWritable,
+    JavaMLReadable,
+    HasTrainingSummary,
+):
     """
     Model fitted by MultilayerPerceptronClassifier.
 
@@ -2705,10 +3150,12 @@ class MultilayerPerceptronClassificationModel(_JavaProbabilisticClassificationMo
         """
         if self.hasSummary:
             return MultilayerPerceptronClassificationTrainingSummary(
-                super(MultilayerPerceptronClassificationModel, self).summary)
+                super(MultilayerPerceptronClassificationModel, self).summary
+            )
         else:
-            raise RuntimeError("No training summary available for this %s" %
-                               self.__class__.__name__)
+            raise RuntimeError(
+                "No training summary available for this %s" % self.__class__.__name__
+            )
 
     def evaluate(self, dataset):
         """
@@ -2733,17 +3180,20 @@ class MultilayerPerceptronClassificationSummary(_ClassificationSummary):
 
     .. versionadded:: 3.1.0
     """
+
     pass
 
 
 @inherit_doc
-class MultilayerPerceptronClassificationTrainingSummary(MultilayerPerceptronClassificationSummary,
-                                                        _TrainingSummary):
+class MultilayerPerceptronClassificationTrainingSummary(
+    MultilayerPerceptronClassificationSummary, _TrainingSummary
+):
     """
     Abstraction for MultilayerPerceptronClassifier Training results.
 
     .. versionadded:: 3.1.0
     """
+
     pass
 
 
@@ -2815,8 +3265,17 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        rawPredictionCol="rawPrediction",
+        classifier=None,
+        weightCol=None,
+        parallelism=1,
+    ):
         """
         __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
@@ -2828,8 +3287,17 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
 
     @keyword_only
     @since("2.0.0")
-    def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        rawPredictionCol="rawPrediction",
+        classifier=None,
+        weightCol=None,
+        parallelism=1,
+    ):
         """
         setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
@@ -2887,15 +3355,16 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
         predictionCol = self.getPredictionCol()
         classifier = self.getClassifier()
 
-        numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
+        numClasses = int(dataset.agg({labelCol: "max"}).head()["max(" + labelCol + ")"]) + 1
 
         weightCol = None
-        if (self.isDefined(self.weightCol) and self.getWeightCol()):
+        if self.isDefined(self.weightCol) and self.getWeightCol():
             if isinstance(classifier, HasWeightCol):
                 weightCol = self.getWeightCol()
             else:
-                warnings.warn("weightCol is ignored, "
-                              "as it is not supported by {} now.".format(classifier))
+                warnings.warn(
+                    "weightCol is ignored, " "as it is not supported by {} now.".format(classifier)
+                )
 
         if weightCol:
             multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
@@ -2911,10 +3380,15 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
             binaryLabelCol = "mc2b$" + str(index)
             trainingDataset = multiclassLabeled.withColumn(
                 binaryLabelCol,
-                when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0))
-            paramMap = dict([(classifier.labelCol, binaryLabelCol),
-                            (classifier.featuresCol, featuresCol),
-                            (classifier.predictionCol, predictionCol)])
+                when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
+            )
+            paramMap = dict(
+                [
+                    (classifier.labelCol, binaryLabelCol),
+                    (classifier.featuresCol, featuresCol),
+                    (classifier.predictionCol, predictionCol),
+                ]
+            )
             if weightCol:
                 paramMap[classifier.weightCol] = weightCol
             return classifier.fit(trainingDataset, paramMap)
@@ -2965,9 +3439,14 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
         rawPredictionCol = java_stage.getRawPredictionCol()
         classifier = JavaParams._from_java(java_stage.getClassifier())
         parallelism = java_stage.getParallelism()
-        py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol,
-                       rawPredictionCol=rawPredictionCol, classifier=classifier,
-                       parallelism=parallelism)
+        py_stage = cls(
+            featuresCol=featuresCol,
+            labelCol=labelCol,
+            predictionCol=predictionCol,
+            rawPredictionCol=rawPredictionCol,
+            classifier=classifier,
+            parallelism=parallelism,
+        )
         if java_stage.isDefined(java_stage.getParam("weightCol")):
             py_stage.setWeightCol(java_stage.getWeightCol())
         py_stage._resetUid(java_stage.uid())
@@ -2982,14 +3461,15 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
         py4j.java_gateway.JavaObject
             Java object equivalent to this instance.
         """
-        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
-                                             self.uid)
+        _java_obj = JavaParams._new_java_obj(
+            "org.apache.spark.ml.classification.OneVsRest", self.uid
+        )
         _java_obj.setClassifier(self.getClassifier()._to_java())
         _java_obj.setParallelism(self.getParallelism())
         _java_obj.setFeaturesCol(self.getFeaturesCol())
         _java_obj.setLabelCol(self.getLabelCol())
         _java_obj.setPredictionCol(self.getPredictionCol())
-        if (self.isDefined(self.weightCol) and self.getWeightCol()):
+        if self.isDefined(self.weightCol) and self.getWeightCol():
             _java_obj.setWeightCol(self.getWeightCol())
         _java_obj.setRawPredictionCol(self.getRawPredictionCol())
         return _java_obj
@@ -3008,16 +3488,17 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
 class _OneVsRestSharedReadWrite:
     @staticmethod
     def saveImpl(instance, sc, path, extraMetadata=None):
-        skipParams = ['classifier']
+        skipParams = ["classifier"]
         jsonParams = DefaultParamsWriter.extractJsonParams(instance, skipParams)
-        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams,
-                                         extraMetadata=extraMetadata)
-        classifierPath = os.path.join(path, 'classifier')
+        DefaultParamsWriter.saveMetadata(
+            instance, path, sc, paramMap=jsonParams, extraMetadata=extraMetadata
+        )
+        classifierPath = os.path.join(path, "classifier")
         instance.getClassifier().save(classifierPath)
 
     @staticmethod
     def loadClassifier(path, sc):
-        classifierPath = os.path.join(path, 'classifier')
+        classifierPath = os.path.join(path, "classifier")
         return DefaultParamsReader.loadParamsInstance(classifierPath, sc)
 
     @staticmethod
@@ -3028,8 +3509,10 @@ class _OneVsRestSharedReadWrite:
 
         for elem in elems_to_check:
             if not isinstance(elem, MLWritable):
-                raise ValueError(f'OneVsRest write will fail because it contains {elem.uid} '
-                                 f'which is not writable.')
+                raise ValueError(
+                    f"OneVsRest write will fail because it contains {elem.uid} "
+                    f"which is not writable."
+                )
 
 
 @inherit_doc
@@ -3044,8 +3527,8 @@ class OneVsRestReader(MLReader):
             return JavaMLReader(self.cls).load(path)
         else:
             classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
-            ova = OneVsRest(classifier=classifier)._resetUid(metadata['uid'])
-            DefaultParamsReader.getAndSetParams(ova, metadata, skipParams=['classifier'])
+            ova = OneVsRest(classifier=classifier)._resetUid(metadata["uid"])
+            DefaultParamsReader.getAndSetParams(ova, metadata, skipParams=["classifier"])
             return ova
 
 
@@ -3096,14 +3579,17 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
         # set java instance
         java_models = [model._to_java() for model in self.models]
         sc = SparkContext._active_spark_context
-        java_models_array = JavaWrapper._new_java_array(java_models,
-                                                        sc._gateway.jvm.org.apache.spark.ml
-                                                        .classification.ClassificationModel)
+        java_models_array = JavaWrapper._new_java_array(
+            java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel
+        )
         # TODO: need to set metadata
         metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
-        self._java_obj = \
-            JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
-                                     self.uid, metadata.empty(), java_models_array)
+        self._java_obj = JavaParams._new_java_obj(
+            "org.apache.spark.ml.classification.OneVsRestModel",
+            self.uid,
+            metadata.empty(),
+            java_models_array,
+        )
 
     def _transform(self, dataset):
         # determine the input columns: these need to be passed through
@@ -3130,21 +3616,25 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
             tmpColName = "mbc$tmp" + str(uuid.uuid4())
             updateUDF = udf(
                 lambda predictions, prediction: predictions + [prediction.tolist()[1]],
-                ArrayType(DoubleType()))
+                ArrayType(DoubleType()),
+            )
             transformedDataset = model.transform(aggregatedDataset).select(*columns)
             updatedDataset = transformedDataset.withColumn(
                 tmpColName,
-                updateUDF(transformedDataset[accColName], transformedDataset[rawPredictionCol]))
+                updateUDF(transformedDataset[accColName], transformedDataset[rawPredictionCol]),
+            )
             newColumns = origCols + [tmpColName]
 
             # switch out the intermediate column with the accumulator column
-            aggregatedDataset = updatedDataset\
-                .select(*newColumns).withColumnRenamed(tmpColName, accColName)
+            aggregatedDataset = updatedDataset.select(*newColumns).withColumnRenamed(
+                tmpColName, accColName
+            )
 
         if handlePersistence:
             newDataset.unpersist()
 
         if self.getRawPredictionCol():
+
             def func(predictions):
                 predArray = []
                 for x in predictions:
@@ -3153,14 +3643,20 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
 
             rawPredictionUDF = udf(func, VectorUDT())
             aggregatedDataset = aggregatedDataset.withColumn(
-                self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName]))
+                self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName])
+            )
 
         if self.getPredictionCol():
             # output the index of the classifier with highest confidence as prediction
-            labelUDF = udf(lambda predictions: float(max(enumerate(predictions),
-                           key=operator.itemgetter(1))[0]), DoubleType())
+            labelUDF = udf(
+                lambda predictions: float(
+                    max(enumerate(predictions), key=operator.itemgetter(1))[0]
+                ),
+                DoubleType(),
+            )
             aggregatedDataset = aggregatedDataset.withColumn(
-                self.getPredictionCol(), labelUDF(aggregatedDataset[accColName]))
+                self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])
+            )
         return aggregatedDataset.drop(accColName)
 
     def copy(self, extra=None):
@@ -3198,8 +3694,7 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
         predictionCol = java_stage.getPredictionCol()
         classifier = JavaParams._from_java(java_stage.getClassifier())
         models = [JavaParams._from_java(model) for model in java_stage.models()]
-        py_stage = cls(models=models).setPredictionCol(predictionCol)\
-            .setFeaturesCol(featuresCol)
+        py_stage = cls(models=models).setPredictionCol(predictionCol).setFeaturesCol(featuresCol)
         py_stage._set(labelCol=labelCol)
         if java_stage.isDefined(java_stage.getParam("weightCol")):
             py_stage._set(weightCol=java_stage.getWeightCol())
@@ -3219,15 +3714,20 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
         sc = SparkContext._active_spark_context
         java_models = [model._to_java() for model in self.models]
         java_models_array = JavaWrapper._new_java_array(
-            java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel)
+            java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel
+        )
         metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
-        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
-                                             self.uid, metadata.empty(), java_models_array)
+        _java_obj = JavaParams._new_java_obj(
+            "org.apache.spark.ml.classification.OneVsRestModel",
+            self.uid,
+            metadata.empty(),
+            java_models_array,
+        )
         _java_obj.set("classifier", self.getClassifier()._to_java())
         _java_obj.set("featuresCol", self.getFeaturesCol())
         _java_obj.set("labelCol", self.getLabelCol())
         _java_obj.set("predictionCol", self.getPredictionCol())
-        if (self.isDefined(self.weightCol) and self.getWeightCol()):
+        if self.isDefined(self.weightCol) and self.getWeightCol():
             _java_obj.set("weightCol", self.getWeightCol())
         return _java_obj
 
@@ -3236,8 +3736,9 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
         return OneVsRestModelReader(cls)
 
     def write(self):
-        if all(map(lambda elem: isinstance(elem, JavaMLWritable),
-                   [self.getClassifier()] + self.models)):
+        if all(
+            map(lambda elem: isinstance(elem, JavaMLWritable), [self.getClassifier()] + self.models)
+        ):
             return JavaMLWriter(self)
         else:
             return OneVsRestModelWriter(self)
@@ -3255,14 +3756,14 @@ class OneVsRestModelReader(MLReader):
             return JavaMLReader(self.cls).load(path)
         else:
             classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
-            numClasses = metadata['numClasses']
+            numClasses = metadata["numClasses"]
             subModels = [None] * numClasses
             for idx in range(numClasses):
-                subModelPath = os.path.join(path, f'model_{idx}')
+                subModelPath = os.path.join(path, f"model_{idx}")
                 subModels[idx] = DefaultParamsReader.loadParamsInstance(subModelPath, self.sc)
-            ovaModel = OneVsRestModel(subModels)._resetUid(metadata['uid'])
+            ovaModel = OneVsRestModel(subModels)._resetUid(metadata["uid"])
             ovaModel.set(ovaModel.classifier, classifier)
-            DefaultParamsReader.getAndSetParams(ovaModel, metadata, skipParams=['classifier'])
+            DefaultParamsReader.getAndSetParams(ovaModel, metadata, skipParams=["classifier"])
             return ovaModel
 
 
@@ -3276,16 +3777,17 @@ class OneVsRestModelWriter(MLWriter):
         _OneVsRestSharedReadWrite.validateParams(self.instance)
         instance = self.instance
         numClasses = len(instance.models)
-        extraMetadata = {'numClasses': numClasses}
+        extraMetadata = {"numClasses": numClasses}
         _OneVsRestSharedReadWrite.saveImpl(instance, self.sc, path, extraMetadata=extraMetadata)
         for idx in range(numClasses):
-            subModelPath = os.path.join(path, f'model_{idx}')
+            subModelPath = os.path.join(path, f"model_{idx}")
             instance.models[idx].save(subModelPath)
 
 
 @inherit_doc
-class FMClassifier(_JavaProbabilisticClassifier, _FactorizationMachinesParams, JavaMLWritable,
-                   JavaMLReadable):
+class FMClassifier(
+    _JavaProbabilisticClassifier, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable
+):
     """
     Factorization Machines learning algorithm for classification.
 
@@ -3348,11 +3850,27 @@ class FMClassifier(_JavaProbabilisticClassifier, _FactorizationMachinesParams, J
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 probabilityCol="probability", rawPredictionCol="rawPrediction",
-                 factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
-                 miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
-                 tol=1e-6, solver="adamW", thresholds=None, seed=None):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        probabilityCol="probability",
+        rawPredictionCol="rawPrediction",
+        factorSize=8,
+        fitIntercept=True,
+        fitLinear=True,
+        regParam=0.0,
+        miniBatchFraction=1.0,
+        initStd=0.01,
+        maxIter=100,
+        stepSize=1.0,
+        tol=1e-6,
+        solver="adamW",
+        thresholds=None,
+        seed=None,
+    ):
         """
         __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
@@ -3362,17 +3880,34 @@ class FMClassifier(_JavaProbabilisticClassifier, _FactorizationMachinesParams, J
         """
         super(FMClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.classification.FMClassifier", self.uid)
+            "org.apache.spark.ml.classification.FMClassifier", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("3.0.0")
-    def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  probabilityCol="probability", rawPredictionCol="rawPrediction",
-                  factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
-                  miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
-                  tol=1e-6, solver="adamW", thresholds=None, seed=None):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        labelCol="label",
+        predictionCol="prediction",
+        probabilityCol="probability",
+        rawPredictionCol="rawPrediction",
+        factorSize=8,
+        fitIntercept=True,
+        fitLinear=True,
+        regParam=0.0,
+        miniBatchFraction=1.0,
+        initStd=0.01,
+        maxIter=100,
+        stepSize=1.0,
+        tol=1e-6,
+        solver="adamW",
+        thresholds=None,
+        seed=None,
+    ):
         """
         setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   probabilityCol="probability", rawPredictionCol="rawPrediction", \
@@ -3465,8 +4000,13 @@ class FMClassifier(_JavaProbabilisticClassifier, _FactorizationMachinesParams, J
         return self._set(regParam=value)
 
 
-class FMClassificationModel(_JavaProbabilisticClassificationModel, _FactorizationMachinesParams,
-                            JavaMLWritable, JavaMLReadable, HasTrainingSummary):
+class FMClassificationModel(
+    _JavaProbabilisticClassificationModel,
+    _FactorizationMachinesParams,
+    JavaMLWritable,
+    JavaMLReadable,
+    HasTrainingSummary,
+):
     """
     Model fitted by :class:`FMClassifier`.
 
@@ -3506,8 +4046,9 @@ class FMClassificationModel(_JavaProbabilisticClassificationModel, _Factorizatio
         if self.hasSummary:
             return FMClassificationTrainingSummary(super(FMClassificationModel, self).summary)
         else:
-            raise RuntimeError("No training summary available for this %s" %
-                               self.__class__.__name__)
+            raise RuntimeError(
+                "No training summary available for this %s" % self.__class__.__name__
+            )
 
     def evaluate(self, dataset):
         """
@@ -3532,6 +4073,7 @@ class FMClassificationSummary(_BinaryClassificationSummary):
 
     .. versionadded:: 3.1.0
     """
+
     pass
 
 
@@ -3542,6 +4084,7 @@ class FMClassificationTrainingSummary(FMClassificationSummary, _TrainingSummary)
 
     .. versionadded:: 3.1.0
     """
+
     pass
 
 
@@ -3549,24 +4092,24 @@ if __name__ == "__main__":
     import doctest
     import pyspark.ml.classification
     from pyspark.sql import SparkSession
+
     globs = pyspark.ml.classification.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    spark = SparkSession.builder\
-        .master("local[2]")\
-        .appName("ml.classification tests")\
-        .getOrCreate()
+    spark = SparkSession.builder.master("local[2]").appName("ml.classification tests").getOrCreate()
     sc = spark.sparkContext
-    globs['sc'] = sc
-    globs['spark'] = spark
+    globs["sc"] = sc
+    globs["spark"] = spark
     import tempfile
+
     temp_path = tempfile.mkdtemp()
-    globs['temp_path'] = temp_path
+    globs["temp_path"] = temp_path
     try:
         (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
         spark.stop()
     finally:
         from shutil import rmtree
+
         try:
             rmtree(temp_path)
         except OSError:
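
The classification.py hunks above all show the same two mechanical rewrites: long keyword-only
signatures and multi-base class headers are exploded one item per line with a trailing comma, and
quote/continuation style is normalized. A minimal sketch of the signature rewrite, not part of this
patch: the toy setParams below is hypothetical, only the black API calls (format_str, Mode) are
real, and a 100-character line length is assumed.

```python
import black

# A hypothetical long keyword-only signature, similar in shape to the
# estimator signatures reformatted in the hunks above (not Spark code).
SRC = (
    "def setParams(self, *, featuresCol='features', labelCol='label', "
    "predictionCol='prediction',\n"
    "              maxDepth=5, maxBins=32, seed=None, weightCol=None):\n"
    "    return self\n"
)

# format_str is black's in-process entry point; Mode(line_length=100) assumes
# the 100-character limit used for the PySpark sources.
print(black.format_str(SRC, mode=black.Mode(line_length=100)))
# The joined signature no longer fits on one line, so black places each
# parameter on its own line and appends a trailing comma -- the pattern
# visible in the RandomForestClassifier, GBTClassifier, NaiveBayes and
# MultilayerPerceptronClassifier hunks above.  It also normalizes the quotes.
```
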
diff --git a/python/pyspark/ml/classification.pyi b/python/pyspark/ml/classification.pyi
index a4a3d21..bb4fb05 100644
--- a/python/pyspark/ml/classification.pyi
+++ b/python/pyspark/ml/classification.pyi
@@ -53,8 +53,15 @@ from pyspark.ml.tree import (
     _TreeClassifierParams,
     _TreeEnsembleModel,
 )
-from pyspark.ml.util import HasTrainingSummary, JavaMLReadable, JavaMLWritable, \
-    MLReader, MLReadable, MLWriter, MLWritable
+from pyspark.ml.util import (
+    HasTrainingSummary,
+    JavaMLReadable,
+    JavaMLWritable,
+    MLReader,
+    MLReadable,
+    MLWriter,
+    MLWritable,
+)
 from pyspark.ml.wrapper import JavaPredictionModel, JavaPredictor, JavaWrapper
 
 from pyspark.ml.linalg import Matrix, Vector
@@ -75,13 +82,9 @@ class ClassificationModel(PredictionModel, _ClassifierParams, metaclass=abc.ABCM
     @abstractmethod
     def predictRaw(self, value: Vector) -> Vector: ...
 
-class _ProbabilisticClassifierParams(
-    HasProbabilityCol, HasThresholds, _ClassifierParams
-): ...
+class _ProbabilisticClassifierParams(HasProbabilityCol, HasThresholds, _ClassifierParams): ...
 
-class ProbabilisticClassifier(
-    Classifier, _ProbabilisticClassifierParams, metaclass=abc.ABCMeta
-):
+class ProbabilisticClassifier(Classifier, _ProbabilisticClassifierParams, metaclass=abc.ABCMeta):
     def setProbabilityCol(self: P, value: str) -> P: ...
     def setThresholds(self: P, value: List[float]) -> P: ...
 
@@ -200,7 +203,7 @@ class LinearSVC(
         threshold: float = ...,
         weightCol: Optional[str] = ...,
         aggregationDepth: int = ...,
-        maxBlockSizeInMB: float = ...
+        maxBlockSizeInMB: float = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -217,7 +220,7 @@ class LinearSVC(
         threshold: float = ...,
         weightCol: Optional[str] = ...,
         aggregationDepth: int = ...,
-        maxBlockSizeInMB: float = ...
+        maxBlockSizeInMB: float = ...,
     ) -> LinearSVC: ...
     def setMaxIter(self, value: int) -> LinearSVC: ...
     def setRegParam(self, value: float) -> LinearSVC: ...
@@ -306,7 +309,7 @@ class LogisticRegression(
         upperBoundsOnCoefficients: Optional[Matrix] = ...,
         lowerBoundsOnIntercepts: Optional[Vector] = ...,
         upperBoundsOnIntercepts: Optional[Vector] = ...,
-        maxBlockSizeInMB: float = ...
+        maxBlockSizeInMB: float = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -331,7 +334,7 @@ class LogisticRegression(
         upperBoundsOnCoefficients: Optional[Matrix] = ...,
         lowerBoundsOnIntercepts: Optional[Vector] = ...,
         upperBoundsOnIntercepts: Optional[Vector] = ...,
-        maxBlockSizeInMB: float = ...
+        maxBlockSizeInMB: float = ...,
     ) -> LogisticRegression: ...
     def setFamily(self, value: str) -> LogisticRegression: ...
     def setLowerBoundsOnCoefficients(self, value: Matrix) -> LogisticRegression: ...
@@ -373,12 +376,8 @@ class LogisticRegressionSummary(_ClassificationSummary):
     @property
     def featuresCol(self) -> str: ...
 
-class LogisticRegressionTrainingSummary(
-    LogisticRegressionSummary, _TrainingSummary
-): ...
-class BinaryLogisticRegressionSummary(
-    _BinaryClassificationSummary, LogisticRegressionSummary
-): ...
+class LogisticRegressionTrainingSummary(LogisticRegressionSummary, _TrainingSummary): ...
+class BinaryLogisticRegressionSummary(_BinaryClassificationSummary, LogisticRegressionSummary): ...
 class BinaryLogisticRegressionTrainingSummary(
     BinaryLogisticRegressionSummary, LogisticRegressionTrainingSummary
 ): ...
@@ -411,7 +410,7 @@ class DecisionTreeClassifier(
         seed: Optional[int] = ...,
         weightCol: Optional[str] = ...,
         leafCol: str = ...,
-        minWeightFractionPerNode: float = ...
+        minWeightFractionPerNode: float = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -432,7 +431,7 @@ class DecisionTreeClassifier(
         seed: Optional[int] = ...,
         weightCol: Optional[str] = ...,
         leafCol: str = ...,
-        minWeightFractionPerNode: float = ...
+        minWeightFractionPerNode: float = ...,
     ) -> DecisionTreeClassifier: ...
     def setMaxDepth(self, value: int) -> DecisionTreeClassifier: ...
     def setMaxBins(self, value: int) -> DecisionTreeClassifier: ...
@@ -488,7 +487,7 @@ class RandomForestClassifier(
         leafCol: str = ...,
         minWeightFractionPerNode: float = ...,
         weightCol: Optional[str] = ...,
-        bootstrap: Optional[bool] = ...
+        bootstrap: Optional[bool] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -513,7 +512,7 @@ class RandomForestClassifier(
         leafCol: str = ...,
         minWeightFractionPerNode: float = ...,
         weightCol: Optional[str] = ...,
-        bootstrap: Optional[bool] = ...
+        bootstrap: Optional[bool] = ...,
     ) -> RandomForestClassifier: ...
     def setMaxDepth(self, value: int) -> RandomForestClassifier: ...
     def setMaxBins(self, value: int) -> RandomForestClassifier: ...
@@ -590,7 +589,7 @@ class GBTClassifier(
         validationIndicatorCol: Optional[str] = ...,
         leafCol: str = ...,
         minWeightFractionPerNode: float = ...,
-        weightCol: Optional[str] = ...
+        weightCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -615,7 +614,7 @@ class GBTClassifier(
         validationIndicatorCol: Optional[str] = ...,
         leafCol: str = ...,
         minWeightFractionPerNode: float = ...,
-        weightCol: Optional[str] = ...
+        weightCol: Optional[str] = ...,
     ) -> GBTClassifier: ...
     def setMaxDepth(self, value: int) -> GBTClassifier: ...
     def setMaxBins(self, value: int) -> GBTClassifier: ...
@@ -674,7 +673,7 @@ class NaiveBayes(
         smoothing: float = ...,
         modelType: str = ...,
         thresholds: Optional[List[float]] = ...,
-        weightCol: Optional[str] = ...
+        weightCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -687,7 +686,7 @@ class NaiveBayes(
         smoothing: float = ...,
         modelType: str = ...,
         thresholds: Optional[List[float]] = ...,
-        weightCol: Optional[str] = ...
+        weightCol: Optional[str] = ...,
     ) -> NaiveBayes: ...
     def setSmoothing(self, value: float) -> NaiveBayes: ...
     def setModelType(self, value: str) -> NaiveBayes: ...
@@ -743,7 +742,7 @@ class MultilayerPerceptronClassifier(
         solver: str = ...,
         initialWeights: Optional[Vector] = ...,
         probabilityCol: str = ...,
-        rawPredictionCol: str = ...
+        rawPredictionCol: str = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -760,7 +759,7 @@ class MultilayerPerceptronClassifier(
         solver: str = ...,
         initialWeights: Optional[Vector] = ...,
         probabilityCol: str = ...,
-        rawPredictionCol: str = ...
+        rawPredictionCol: str = ...,
     ) -> MultilayerPerceptronClassifier: ...
     def setLayers(self, value: List[int]) -> MultilayerPerceptronClassifier: ...
     def setBlockSize(self, value: int) -> MultilayerPerceptronClassifier: ...
@@ -781,9 +780,7 @@ class MultilayerPerceptronClassificationModel(
     @property
     def weights(self) -> Vector: ...
     def summary(self) -> MultilayerPerceptronClassificationTrainingSummary: ...
-    def evaluate(
-        self, dataset: DataFrame
-    ) -> MultilayerPerceptronClassificationSummary: ...
+    def evaluate(self, dataset: DataFrame) -> MultilayerPerceptronClassificationSummary: ...
 
 class MultilayerPerceptronClassificationSummary(_ClassificationSummary): ...
 class MultilayerPerceptronClassificationTrainingSummary(
@@ -810,7 +807,7 @@ class OneVsRest(
         rawPredictionCol: str = ...,
         classifier: Optional[Estimator[M]] = ...,
         weightCol: Optional[str] = ...,
-        parallelism: int = ...
+        parallelism: int = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -821,7 +818,7 @@ class OneVsRest(
         rawPredictionCol: str = ...,
         classifier: Optional[Estimator[M]] = ...,
         weightCol: Optional[str] = ...,
-        parallelism: int = ...
+        parallelism: int = ...,
     ) -> OneVsRest: ...
     def setClassifier(self, value: Estimator[M]) -> OneVsRest: ...
     def setLabelCol(self, value: str) -> OneVsRest: ...
@@ -832,9 +829,7 @@ class OneVsRest(
     def setParallelism(self, value: int) -> OneVsRest: ...
     def copy(self, extra: Optional[ParamMap] = ...) -> OneVsRest: ...
 
-class OneVsRestModel(
-    Model, _OneVsRestParams, MLReadable[OneVsRestModel], MLWritable
-):
+class OneVsRestModel(Model, _OneVsRestParams, MLReadable[OneVsRestModel], MLWritable):
     models: List[Transformer]
     def __init__(self, models: List[Transformer]) -> None: ...
     def setFeaturesCol(self, value: str) -> OneVsRestModel: ...
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index f2a248c..11fbdf5 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -19,20 +19,49 @@ import sys
 import warnings
 
 from pyspark import since, keyword_only
-from pyspark.ml.param.shared import HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, \
-    HasAggregationDepth, HasWeightCol, HasTol, HasProbabilityCol, HasDistanceMeasure, \
-    HasCheckpointInterval, Param, Params, TypeConverters
-from pyspark.ml.util import JavaMLWritable, JavaMLReadable, GeneralJavaMLWritable, \
-    HasTrainingSummary, SparkContext
+from pyspark.ml.param.shared import (
+    HasMaxIter,
+    HasFeaturesCol,
+    HasSeed,
+    HasPredictionCol,
+    HasAggregationDepth,
+    HasWeightCol,
+    HasTol,
+    HasProbabilityCol,
+    HasDistanceMeasure,
+    HasCheckpointInterval,
+    Param,
+    Params,
+    TypeConverters,
+)
+from pyspark.ml.util import (
+    JavaMLWritable,
+    JavaMLReadable,
+    GeneralJavaMLWritable,
+    HasTrainingSummary,
+    SparkContext,
+)
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper
 from pyspark.ml.common import inherit_doc, _java2py
 from pyspark.ml.stat import MultivariateGaussian
 from pyspark.sql import DataFrame
 
-__all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary',
-           'KMeans', 'KMeansModel', 'KMeansSummary',
-           'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary',
-           'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel', 'PowerIterationClustering']
+__all__ = [
+    "BisectingKMeans",
+    "BisectingKMeansModel",
+    "BisectingKMeansSummary",
+    "KMeans",
+    "KMeansModel",
+    "KMeansSummary",
+    "GaussianMixture",
+    "GaussianMixtureModel",
+    "GaussianMixtureSummary",
+    "LDA",
+    "LDAModel",
+    "LocalLDAModel",
+    "DistributedLDAModel",
+    "PowerIterationClustering",
+]
 
 
 class ClusteringSummary(JavaWrapper):
@@ -100,16 +129,28 @@ class ClusteringSummary(JavaWrapper):
 
 
 @inherit_doc
-class _GaussianMixtureParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol,
-                             HasProbabilityCol, HasTol, HasAggregationDepth, HasWeightCol):
+class _GaussianMixtureParams(
+    HasMaxIter,
+    HasFeaturesCol,
+    HasSeed,
+    HasPredictionCol,
+    HasProbabilityCol,
+    HasTol,
+    HasAggregationDepth,
+    HasWeightCol,
+):
     """
     Params for :py:class:`GaussianMixture` and :py:class:`GaussianMixtureModel`.
 
     .. versionadded:: 3.0.0
     """
 
-    k = Param(Params._dummy(), "k", "Number of independent Gaussians in the mixture model. " +
-              "Must be > 1.", typeConverter=TypeConverters.toInt)
+    k = Param(
+        Params._dummy(),
+        "k",
+        "Number of independent Gaussians in the mixture model. " + "Must be > 1.",
+        typeConverter=TypeConverters.toInt,
+    )
 
     def __init__(self, *args):
         super(_GaussianMixtureParams, self).__init__(*args)
@@ -123,8 +164,9 @@ class _GaussianMixtureParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionC
         return self.getOrDefault(self.k)
 
 
-class GaussianMixtureModel(JavaModel, _GaussianMixtureParams, JavaMLWritable, JavaMLReadable,
-                           HasTrainingSummary):
+class GaussianMixtureModel(
+    JavaModel, _GaussianMixtureParams, JavaMLWritable, JavaMLReadable, HasTrainingSummary
+):
     """
     Model fitted by GaussianMixture.
 
@@ -173,7 +215,8 @@ class GaussianMixtureModel(JavaModel, _GaussianMixtureParams, JavaMLWritable, Ja
         jgaussians = self._java_obj.gaussians()
         return [
             MultivariateGaussian(_java2py(sc, jgaussian.mean()), _java2py(sc, jgaussian.cov()))
-            for jgaussian in jgaussians]
+            for jgaussian in jgaussians
+        ]
 
     @property
     @since("2.0.0")
@@ -195,8 +238,9 @@ class GaussianMixtureModel(JavaModel, _GaussianMixtureParams, JavaMLWritable, Ja
         if self.hasSummary:
             return GaussianMixtureSummary(super(GaussianMixtureModel, self).summary)
         else:
-            raise RuntimeError("No training summary available for this %s" %
-                               self.__class__.__name__)
+            raise RuntimeError(
+                "No training summary available for this %s" % self.__class__.__name__
+            )
 
     @since("3.0.0")
     def predict(self, value):
@@ -336,17 +380,28 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", predictionCol="prediction", k=2,
-                 probabilityCol="probability", tol=0.01, maxIter=100, seed=None,
-                 aggregationDepth=2, weightCol=None):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        predictionCol="prediction",
+        k=2,
+        probabilityCol="probability",
+        tol=0.01,
+        maxIter=100,
+        seed=None,
+        aggregationDepth=2,
+        weightCol=None,
+    ):
         """
         __init__(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \
                  probabilityCol="probability", tol=0.01, maxIter=100, seed=None, \
                  aggregationDepth=2, weightCol=None)
         """
         super(GaussianMixture, self).__init__()
-        self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.GaussianMixture",
-                                            self.uid)
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.clustering.GaussianMixture", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
@@ -355,9 +410,19 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav
 
     @keyword_only
     @since("2.0.0")
-    def setParams(self, *, featuresCol="features", predictionCol="prediction", k=2,
-                  probabilityCol="probability", tol=0.01, maxIter=100, seed=None,
-                  aggregationDepth=2, weightCol=None):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        predictionCol="prediction",
+        k=2,
+        probabilityCol="probability",
+        tol=0.01,
+        maxIter=100,
+        seed=None,
+        aggregationDepth=2,
+        weightCol=None,
+    ):
         """
         setParams(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \
                   probabilityCol="probability", tol=0.01, maxIter=100, seed=None, \
@@ -482,28 +547,46 @@ class KMeansSummary(ClusteringSummary):
 
 
 @inherit_doc
-class _KMeansParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, HasTol,
-                    HasDistanceMeasure, HasWeightCol):
+class _KMeansParams(
+    HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, HasTol, HasDistanceMeasure, HasWeightCol
+):
     """
     Params for :py:class:`KMeans` and :py:class:`KMeansModel`.
 
     .. versionadded:: 3.0.0
     """
 
-    k = Param(Params._dummy(), "k", "The number of clusters to create. Must be > 1.",
-              typeConverter=TypeConverters.toInt)
-    initMode = Param(Params._dummy(), "initMode",
-                     "The initialization algorithm. This can be either \"random\" to " +
-                     "choose random points as initial cluster centers, or \"k-means||\" " +
-                     "to use a parallel variant of k-means++",
-                     typeConverter=TypeConverters.toString)
-    initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " +
-                      "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt)
+    k = Param(
+        Params._dummy(),
+        "k",
+        "The number of clusters to create. Must be > 1.",
+        typeConverter=TypeConverters.toInt,
+    )
+    initMode = Param(
+        Params._dummy(),
+        "initMode",
+        'The initialization algorithm. This can be either "random" to '
+        + 'choose random points as initial cluster centers, or "k-means||" '
+        + "to use a parallel variant of k-means++",
+        typeConverter=TypeConverters.toString,
+    )
+    initSteps = Param(
+        Params._dummy(),
+        "initSteps",
+        "The number of steps for k-means|| " + "initialization mode. Must be > 0.",
+        typeConverter=TypeConverters.toInt,
+    )
 
     def __init__(self, *args):
         super(_KMeansParams, self).__init__(*args)
-        self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20,
-                         distanceMeasure="euclidean")
+        self._setDefault(
+            k=2,
+            initMode="k-means||",
+            initSteps=2,
+            tol=1e-4,
+            maxIter=20,
+            distanceMeasure="euclidean",
+        )
 
     @since("1.5.0")
     def getK(self):
@@ -527,8 +610,9 @@ class _KMeansParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, HasTo
         return self.getOrDefault(self.initSteps)
 
 
-class KMeansModel(JavaModel, _KMeansParams, GeneralJavaMLWritable, JavaMLReadable,
-                  HasTrainingSummary):
+class KMeansModel(
+    JavaModel, _KMeansParams, GeneralJavaMLWritable, JavaMLReadable, HasTrainingSummary
+):
     """
     Model fitted by KMeans.
 
@@ -564,8 +648,9 @@ class KMeansModel(JavaModel, _KMeansParams, GeneralJavaMLWritable, JavaMLReadabl
         if self.hasSummary:
             return KMeansSummary(super(KMeansModel, self).summary)
         else:
-            raise RuntimeError("No training summary available for this %s" %
-                               self.__class__.__name__)
+            raise RuntimeError(
+                "No training summary available for this %s" % self.__class__.__name__
+            )
 
     @since("3.0.0")
     def predict(self, value):
@@ -643,9 +728,20 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", predictionCol="prediction", k=2,
-                 initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None,
-                 distanceMeasure="euclidean", weightCol=None):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        predictionCol="prediction",
+        k=2,
+        initMode="k-means||",
+        initSteps=2,
+        tol=1e-4,
+        maxIter=20,
+        seed=None,
+        distanceMeasure="euclidean",
+        weightCol=None,
+    ):
         """
         __init__(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \
                  initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \
@@ -661,9 +757,20 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
 
     @keyword_only
     @since("1.5.0")
-    def setParams(self, *, featuresCol="features", predictionCol="prediction", k=2,
-                  initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None,
-                  distanceMeasure="euclidean", weightCol=None):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        predictionCol="prediction",
+        k=2,
+        initMode="k-means||",
+        initSteps=2,
+        tol=1e-4,
+        maxIter=20,
+        seed=None,
+        distanceMeasure="euclidean",
+        weightCol=None,
+    ):
         """
         setParams(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \
                   initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \
@@ -746,20 +853,28 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
 
 
 @inherit_doc
-class _BisectingKMeansParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol,
-                             HasDistanceMeasure, HasWeightCol):
+class _BisectingKMeansParams(
+    HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, HasDistanceMeasure, HasWeightCol
+):
     """
     Params for :py:class:`BisectingKMeans` and :py:class:`BisectingKMeansModel`.
 
     .. versionadded:: 3.0.0
     """
 
-    k = Param(Params._dummy(), "k", "The desired number of leaf clusters. Must be > 1.",
-              typeConverter=TypeConverters.toInt)
-    minDivisibleClusterSize = Param(Params._dummy(), "minDivisibleClusterSize",
-                                    "The minimum number of points (if >= 1.0) or the minimum " +
-                                    "proportion of points (if < 1.0) of a divisible cluster.",
-                                    typeConverter=TypeConverters.toFloat)
+    k = Param(
+        Params._dummy(),
+        "k",
+        "The desired number of leaf clusters. Must be > 1.",
+        typeConverter=TypeConverters.toInt,
+    )
+    minDivisibleClusterSize = Param(
+        Params._dummy(),
+        "minDivisibleClusterSize",
+        "The minimum number of points (if >= 1.0) or the minimum "
+        + "proportion of points (if < 1.0) of a divisible cluster.",
+        typeConverter=TypeConverters.toFloat,
+    )
 
     def __init__(self, *args):
         super(_BisectingKMeansParams, self).__init__(*args)
@@ -780,8 +895,9 @@ class _BisectingKMeansParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionC
         return self.getOrDefault(self.minDivisibleClusterSize)
 
 
-class BisectingKMeansModel(JavaModel, _BisectingKMeansParams, JavaMLWritable, JavaMLReadable,
-                           HasTrainingSummary):
+class BisectingKMeansModel(
+    JavaModel, _BisectingKMeansParams, JavaMLWritable, JavaMLReadable, HasTrainingSummary
+):
     """
     Model fitted by BisectingKMeans.
 
@@ -817,9 +933,12 @@ class BisectingKMeansModel(JavaModel, _BisectingKMeansParams, JavaMLWritable, Ja
             It will be removed in future versions. Use :py:class:`ClusteringEvaluator` instead.
             You can also get the cost on the training dataset in the summary.
         """
-        warnings.warn("Deprecated in 3.0.0. It will be removed in future versions. Use "
-                      "ClusteringEvaluator instead. You can also get the cost on the training "
-                      "dataset in the summary.", FutureWarning)
+        warnings.warn(
+            "Deprecated in 3.0.0. It will be removed in future versions. Use "
+            "ClusteringEvaluator instead. You can also get the cost on the training "
+            "dataset in the summary.",
+            FutureWarning,
+        )
         return self._call_java("computeCost", dataset)
 
     @property
@@ -832,8 +951,9 @@ class BisectingKMeansModel(JavaModel, _BisectingKMeansParams, JavaMLWritable, Ja
         if self.hasSummary:
             return BisectingKMeansSummary(super(BisectingKMeansModel, self).summary)
         else:
-            raise RuntimeError("No training summary available for this %s" %
-                               self.__class__.__name__)
+            raise RuntimeError(
+                "No training summary available for this %s" % self.__class__.__name__
+            )
 
     @since("3.0.0")
     def predict(self, value):
@@ -924,25 +1044,44 @@ class BisectingKMeans(JavaEstimator, _BisectingKMeansParams, JavaMLWritable, Jav
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", predictionCol="prediction", maxIter=20,
-                 seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean",
-                 weightCol=None):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        predictionCol="prediction",
+        maxIter=20,
+        seed=None,
+        k=4,
+        minDivisibleClusterSize=1.0,
+        distanceMeasure="euclidean",
+        weightCol=None,
+    ):
         """
         __init__(self, \\*, featuresCol="features", predictionCol="prediction", maxIter=20, \
                  seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean", \
                  weightCol=None)
         """
         super(BisectingKMeans, self).__init__()
-        self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.BisectingKMeans",
-                                            self.uid)
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.clustering.BisectingKMeans", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("2.0.0")
-    def setParams(self, *, featuresCol="features", predictionCol="prediction", maxIter=20,
-                  seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean",
-                  weightCol=None):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        predictionCol="prediction",
+        maxIter=20,
+        seed=None,
+        k=4,
+        minDivisibleClusterSize=1.0,
+        distanceMeasure="euclidean",
+        weightCol=None,
+    ):
         """
         setParams(self, \\*, featuresCol="features", predictionCol="prediction", maxIter=20, \
                   seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean", \
@@ -1037,52 +1176,95 @@ class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval):
     .. versionadded:: 3.0.0
     """
 
-    k = Param(Params._dummy(), "k", "The number of topics (clusters) to infer. Must be > 1.",
-              typeConverter=TypeConverters.toInt)
-    optimizer = Param(Params._dummy(), "optimizer",
-                      "Optimizer or inference algorithm used to estimate the LDA model.  "
-                      "Supported: online, em", typeConverter=TypeConverters.toString)
-    learningOffset = Param(Params._dummy(), "learningOffset",
-                           "A (positive) learning parameter that downweights early iterations."
-                           " Larger values make early iterations count less",
-                           typeConverter=TypeConverters.toFloat)
-    learningDecay = Param(Params._dummy(), "learningDecay", "Learning rate, set as an"
-                          "exponential decay rate. This should be between (0.5, 1.0] to "
-                          "guarantee asymptotic convergence.", typeConverter=TypeConverters.toFloat)
-    subsamplingRate = Param(Params._dummy(), "subsamplingRate",
-                            "Fraction of the corpus to be sampled and used in each iteration "
-                            "of mini-batch gradient descent, in range (0, 1].",
-                            typeConverter=TypeConverters.toFloat)
-    optimizeDocConcentration = Param(Params._dummy(), "optimizeDocConcentration",
-                                     "Indicates whether the docConcentration (Dirichlet parameter "
-                                     "for document-topic distribution) will be optimized during "
-                                     "training.", typeConverter=TypeConverters.toBoolean)
-    docConcentration = Param(Params._dummy(), "docConcentration",
-                             "Concentration parameter (commonly named \"alpha\") for the "
-                             "prior placed on documents' distributions over topics (\"theta\").",
-                             typeConverter=TypeConverters.toListFloat)
-    topicConcentration = Param(Params._dummy(), "topicConcentration",
-                               "Concentration parameter (commonly named \"beta\" or \"eta\") for "
-                               "the prior placed on topic' distributions over terms.",
-                               typeConverter=TypeConverters.toFloat)
-    topicDistributionCol = Param(Params._dummy(), "topicDistributionCol",
-                                 "Output column with estimates of the topic mixture distribution "
-                                 "for each document (often called \"theta\" in the literature). "
-                                 "Returns a vector of zeros for an empty document.",
-                                 typeConverter=TypeConverters.toString)
-    keepLastCheckpoint = Param(Params._dummy(), "keepLastCheckpoint",
-                               "(For EM optimizer) If using checkpointing, this indicates whether"
-                               " to keep the last checkpoint. If false, then the checkpoint will be"
-                               " deleted. Deleting the checkpoint can cause failures if a data"
-                               " partition is lost, so set this bit with care.",
-                               TypeConverters.toBoolean)
+    k = Param(
+        Params._dummy(),
+        "k",
+        "The number of topics (clusters) to infer. Must be > 1.",
+        typeConverter=TypeConverters.toInt,
+    )
+    optimizer = Param(
+        Params._dummy(),
+        "optimizer",
+        "Optimizer or inference algorithm used to estimate the LDA model.  "
+        "Supported: online, em",
+        typeConverter=TypeConverters.toString,
+    )
+    learningOffset = Param(
+        Params._dummy(),
+        "learningOffset",
+        "A (positive) learning parameter that downweights early iterations."
+        " Larger values make early iterations count less",
+        typeConverter=TypeConverters.toFloat,
+    )
+    learningDecay = Param(
+        Params._dummy(),
+        "learningDecay",
+        "Learning rate, set as an"
+        "exponential decay rate. This should be between (0.5, 1.0] to "
+        "guarantee asymptotic convergence.",
+        typeConverter=TypeConverters.toFloat,
+    )
+    subsamplingRate = Param(
+        Params._dummy(),
+        "subsamplingRate",
+        "Fraction of the corpus to be sampled and used in each iteration "
+        "of mini-batch gradient descent, in range (0, 1].",
+        typeConverter=TypeConverters.toFloat,
+    )
+    optimizeDocConcentration = Param(
+        Params._dummy(),
+        "optimizeDocConcentration",
+        "Indicates whether the docConcentration (Dirichlet parameter "
+        "for document-topic distribution) will be optimized during "
+        "training.",
+        typeConverter=TypeConverters.toBoolean,
+    )
+    docConcentration = Param(
+        Params._dummy(),
+        "docConcentration",
+        'Concentration parameter (commonly named "alpha") for the '
+        'prior placed on documents\' distributions over topics ("theta").',
+        typeConverter=TypeConverters.toListFloat,
+    )
+    topicConcentration = Param(
+        Params._dummy(),
+        "topicConcentration",
+        'Concentration parameter (commonly named "beta" or "eta") for '
+        "the prior placed on topic' distributions over terms.",
+        typeConverter=TypeConverters.toFloat,
+    )
+    topicDistributionCol = Param(
+        Params._dummy(),
+        "topicDistributionCol",
+        "Output column with estimates of the topic mixture distribution "
+        'for each document (often called "theta" in the literature). '
+        "Returns a vector of zeros for an empty document.",
+        typeConverter=TypeConverters.toString,
+    )
+    keepLastCheckpoint = Param(
+        Params._dummy(),
+        "keepLastCheckpoint",
+        "(For EM optimizer) If using checkpointing, this indicates whether"
+        " to keep the last checkpoint. If false, then the checkpoint will be"
+        " deleted. Deleting the checkpoint can cause failures if a data"
+        " partition is lost, so set this bit with care.",
+        TypeConverters.toBoolean,
+    )
 
     def __init__(self, *args):
         super(_LDAParams, self).__init__(*args)
-        self._setDefault(maxIter=20, checkpointInterval=10,
-                         k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,
-                         subsamplingRate=0.05, optimizeDocConcentration=True,
-                         topicDistributionCol="topicDistribution", keepLastCheckpoint=True)
+        self._setDefault(
+            maxIter=20,
+            checkpointInterval=10,
+            k=10,
+            optimizer="online",
+            learningOffset=1024.0,
+            learningDecay=0.51,
+            subsamplingRate=0.05,
+            optimizeDocConcentration=True,
+            topicDistributionCol="topicDistribution",
+            keepLastCheckpoint=True,
+        )
 
     @since("2.0.0")
     def getK(self):
@@ -1336,6 +1518,7 @@ class LocalLDAModel(LDAModel, JavaMLReadable, JavaMLWritable):
 
     .. versionadded:: 2.0.0
     """
+
     pass
 
 
@@ -1411,11 +1594,24 @@ class LDA(JavaEstimator, _LDAParams, JavaMLReadable, JavaMLWritable):
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,
-                 k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,
-                 subsamplingRate=0.05, optimizeDocConcentration=True,
-                 docConcentration=None, topicConcentration=None,
-                 topicDistributionCol="topicDistribution", keepLastCheckpoint=True):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        maxIter=20,
+        seed=None,
+        checkpointInterval=10,
+        k=10,
+        optimizer="online",
+        learningOffset=1024.0,
+        learningDecay=0.51,
+        subsamplingRate=0.05,
+        optimizeDocConcentration=True,
+        docConcentration=None,
+        topicConcentration=None,
+        topicDistributionCol="topicDistribution",
+        keepLastCheckpoint=True,
+    ):
         """
         __init__(self, \\*, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,\
                   k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\
@@ -1436,11 +1632,24 @@ class LDA(JavaEstimator, _LDAParams, JavaMLReadable, JavaMLWritable):
 
     @keyword_only
     @since("2.0.0")
-    def setParams(self, *, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,
-                  k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,
-                  subsamplingRate=0.05, optimizeDocConcentration=True,
-                  docConcentration=None, topicConcentration=None,
-                  topicDistributionCol="topicDistribution", keepLastCheckpoint=True):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        maxIter=20,
+        seed=None,
+        checkpointInterval=10,
+        k=10,
+        optimizer="online",
+        learningOffset=1024.0,
+        learningDecay=0.51,
+        subsamplingRate=0.05,
+        optimizeDocConcentration=True,
+        docConcentration=None,
+        topicConcentration=None,
+        topicDistributionCol="topicDistribution",
+        keepLastCheckpoint=True,
+    ):
         """
         setParams(self, \\*, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,\
                   k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\
@@ -1619,21 +1828,33 @@ class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol):
     .. versionadded:: 3.0.0
     """
 
-    k = Param(Params._dummy(), "k",
-              "The number of clusters to create. Must be > 1.",
-              typeConverter=TypeConverters.toInt)
-    initMode = Param(Params._dummy(), "initMode",
-                     "The initialization algorithm. This can be either " +
-                     "'random' to use a random vector as vertex properties, or 'degree' to use " +
-                     "a normalized sum of similarities with other vertices.  Supported options: " +
-                     "'random' and 'degree'.",
-                     typeConverter=TypeConverters.toString)
-    srcCol = Param(Params._dummy(), "srcCol",
-                   "Name of the input column for source vertex IDs.",
-                   typeConverter=TypeConverters.toString)
-    dstCol = Param(Params._dummy(), "dstCol",
-                   "Name of the input column for destination vertex IDs.",
-                   typeConverter=TypeConverters.toString)
+    k = Param(
+        Params._dummy(),
+        "k",
+        "The number of clusters to create. Must be > 1.",
+        typeConverter=TypeConverters.toInt,
+    )
+    initMode = Param(
+        Params._dummy(),
+        "initMode",
+        "The initialization algorithm. This can be either "
+        + "'random' to use a random vector as vertex properties, or 'degree' to use "
+        + "a normalized sum of similarities with other vertices.  Supported options: "
+        + "'random' and 'degree'.",
+        typeConverter=TypeConverters.toString,
+    )
+    srcCol = Param(
+        Params._dummy(),
+        "srcCol",
+        "Name of the input column for source vertex IDs.",
+        typeConverter=TypeConverters.toString,
+    )
+    dstCol = Param(
+        Params._dummy(),
+        "dstCol",
+        "Name of the input column for destination vertex IDs.",
+        typeConverter=TypeConverters.toString,
+    )
 
     def __init__(self, *args):
         super(_PowerIterationClusteringParams, self).__init__(*args)
@@ -1669,8 +1890,9 @@ class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol):
 
 
 @inherit_doc
-class PowerIterationClustering(_PowerIterationClusteringParams, JavaParams, JavaMLReadable,
-                               JavaMLWritable):
+class PowerIterationClustering(
+    _PowerIterationClusteringParams, JavaParams, JavaMLReadable, JavaMLWritable
+):
     """
     Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by
     `Lin and Cohen <http://www.cs.cmu.edu/~frank/papers/icml2010-pic-final.pdf>`_. From the
@@ -1722,22 +1944,25 @@ class PowerIterationClustering(_PowerIterationClusteringParams, JavaParams, Java
     """
 
     @keyword_only
-    def __init__(self, *, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",
-                 weightCol=None):
+    def __init__(
+        self, *, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", weightCol=None
+    ):
         """
         __init__(self, \\*, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\
                  weightCol=None)
         """
         super(PowerIterationClustering, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.clustering.PowerIterationClustering", self.uid)
+            "org.apache.spark.ml.clustering.PowerIterationClustering", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("2.4.0")
-    def setParams(self, *, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",
-                  weightCol=None):
+    def setParams(
+        self, *, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", weightCol=None
+    ):
         """
         setParams(self, \\*, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\
                   weightCol=None)
@@ -1822,29 +2047,29 @@ if __name__ == "__main__":
     import numpy
     import pyspark.ml.clustering
     from pyspark.sql import SparkSession
+
     try:
         # Numpy 1.14+ changed its string format.
-        numpy.set_printoptions(legacy='1.13')
+        numpy.set_printoptions(legacy="1.13")
     except TypeError:
         pass
     globs = pyspark.ml.clustering.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    spark = SparkSession.builder\
-        .master("local[2]")\
-        .appName("ml.clustering tests")\
-        .getOrCreate()
+    spark = SparkSession.builder.master("local[2]").appName("ml.clustering tests").getOrCreate()
     sc = spark.sparkContext
-    globs['sc'] = sc
-    globs['spark'] = spark
+    globs["sc"] = sc
+    globs["spark"] = spark
     import tempfile
+
     temp_path = tempfile.mkdtemp()
-    globs['temp_path'] = temp_path
+    globs["temp_path"] = temp_path
     try:
         (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
         spark.stop()
     finally:
         from shutil import rmtree
+
         try:
             rmtree(temp_path)
         except OSError:
diff --git a/python/pyspark/ml/clustering.pyi b/python/pyspark/ml/clustering.pyi
index e899b60..81074fc 100644
--- a/python/pyspark/ml/clustering.pyi
+++ b/python/pyspark/ml/clustering.pyi
@@ -113,7 +113,7 @@ class GaussianMixture(
         maxIter: int = ...,
         seed: Optional[int] = ...,
         aggregationDepth: int = ...,
-        weightCol: Optional[str] = ...
+        weightCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -126,7 +126,7 @@ class GaussianMixture(
         maxIter: int = ...,
         seed: Optional[int] = ...,
         aggregationDepth: int = ...,
-        weightCol: Optional[str] = ...
+        weightCol: Optional[str] = ...,
     ) -> GaussianMixture: ...
     def setK(self, value: int) -> GaussianMixture: ...
     def setMaxIter(self, value: int) -> GaussianMixture: ...
@@ -180,9 +180,7 @@ class KMeansModel(
     def summary(self) -> KMeansSummary: ...
     def predict(self, value: Vector) -> int: ...
 
-class KMeans(
-    JavaEstimator[KMeansModel], _KMeansParams, JavaMLWritable, JavaMLReadable[KMeans]
-):
+class KMeans(JavaEstimator[KMeansModel], _KMeansParams, JavaMLWritable, JavaMLReadable[KMeans]):
     def __init__(
         self,
         *,
@@ -195,7 +193,7 @@ class KMeans(
         maxIter: int = ...,
         seed: Optional[int] = ...,
         distanceMeasure: str = ...,
-        weightCol: Optional[str] = ...
+        weightCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -209,7 +207,7 @@ class KMeans(
         maxIter: int = ...,
         seed: Optional[int] = ...,
         distanceMeasure: str = ...,
-        weightCol: Optional[str] = ...
+        weightCol: Optional[str] = ...,
     ) -> KMeans: ...
     def setK(self, value: int) -> KMeans: ...
     def setInitMode(self, value: str) -> KMeans: ...
@@ -267,7 +265,7 @@ class BisectingKMeans(
         k: int = ...,
         minDivisibleClusterSize: float = ...,
         distanceMeasure: str = ...,
-        weightCol: Optional[str] = ...
+        weightCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -279,7 +277,7 @@ class BisectingKMeans(
         k: int = ...,
         minDivisibleClusterSize: float = ...,
         distanceMeasure: str = ...,
-        weightCol: Optional[str] = ...
+        weightCol: Optional[str] = ...,
     ) -> BisectingKMeans: ...
     def setK(self, value: int) -> BisectingKMeans: ...
     def setMinDivisibleClusterSize(self, value: float) -> BisectingKMeans: ...
@@ -329,9 +327,7 @@ class LDAModel(JavaModel, _LDAParams):
     def describeTopics(self, maxTermsPerTopic: int = ...) -> DataFrame: ...
     def estimatedDocConcentration(self) -> Vector: ...
 
-class DistributedLDAModel(
-    LDAModel, JavaMLReadable[DistributedLDAModel], JavaMLWritable
-):
+class DistributedLDAModel(LDAModel, JavaMLReadable[DistributedLDAModel], JavaMLWritable):
     def toLocal(self) -> LDAModel: ...
     def trainingLogLikelihood(self) -> float: ...
     def logPrior(self) -> float: ...
@@ -356,7 +352,7 @@ class LDA(JavaEstimator[LDAModel], _LDAParams, JavaMLReadable[LDA], JavaMLWritab
         docConcentration: Optional[List[float]] = ...,
         topicConcentration: Optional[float] = ...,
         topicDistributionCol: str = ...,
-        keepLastCheckpoint: bool = ...
+        keepLastCheckpoint: bool = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -374,7 +370,7 @@ class LDA(JavaEstimator[LDAModel], _LDAParams, JavaMLReadable[LDA], JavaMLWritab
         docConcentration: Optional[List[float]] = ...,
         topicConcentration: Optional[float] = ...,
         topicDistributionCol: str = ...,
-        keepLastCheckpoint: bool = ...
+        keepLastCheckpoint: bool = ...,
     ) -> LDA: ...
     def setCheckpointInterval(self, value: int) -> LDA: ...
     def setSeed(self, value: int) -> LDA: ...
@@ -416,7 +412,7 @@ class PowerIterationClustering(
         initMode: str = ...,
         srcCol: str = ...,
         dstCol: str = ...,
-        weightCol: Optional[str] = ...
+        weightCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -426,7 +422,7 @@ class PowerIterationClustering(
         initMode: str = ...,
         srcCol: str = ...,
         dstCol: str = ...,
-        weightCol: Optional[str] = ...
+        weightCol: Optional[str] = ...,
     ) -> PowerIterationClustering: ...
     def setK(self, value: int) -> PowerIterationClustering: ...
     def setInitMode(self, value: str) -> PowerIterationClustering: ...
diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py
index 165dc29..8107f91 100644
--- a/python/pyspark/ml/common.py
+++ b/python/pyspark/ml/common.py
@@ -28,9 +28,9 @@ from pyspark.sql import DataFrame, SQLContext
 _old_smart_decode = py4j.protocol.smart_decode
 
 _float_str_mapping = {
-    'nan': 'NaN',
-    'inf': 'Infinity',
-    '-inf': '-Infinity',
+    "nan": "NaN",
+    "inf": "Infinity",
+    "-inf": "-Infinity",
 }
 
 
@@ -40,20 +40,21 @@ def _new_smart_decode(obj):
         return _float_str_mapping.get(s, s)
     return _old_smart_decode(obj)
 
+
 py4j.protocol.smart_decode = _new_smart_decode
 
 
 _picklable_classes = [
-    'SparseVector',
-    'DenseVector',
-    'SparseMatrix',
-    'DenseMatrix',
+    "SparseVector",
+    "DenseVector",
+    "SparseMatrix",
+    "DenseMatrix",
 ]
 
 
 # this will call the ML version of pythonToJava()
 def _to_java_object_rdd(rdd):
-    """ Return an JavaRDD of Object by unpickling
+    """Return an JavaRDD of Object by unpickling
 
     It will convert each Python object into a Java object via Pickle, whether or not
     the RDD is serialized in batches.
@@ -63,7 +64,7 @@ def _to_java_object_rdd(rdd):
 
 
 def _py2java(sc, obj):
-    """ Convert Python object into Java """
+    """Convert Python object into Java"""
     if isinstance(obj, RDD):
         obj = _to_java_object_rdd(obj)
     elif isinstance(obj, DataFrame):
@@ -86,15 +87,15 @@ def _java2py(sc, r, encoding="bytes"):
     if isinstance(r, JavaObject):
         clsName = r.getClass().getSimpleName()
         # convert RDD into JavaRDD
-        if clsName != 'JavaRDD' and clsName.endswith("RDD"):
+        if clsName != "JavaRDD" and clsName.endswith("RDD"):
             r = r.toJavaRDD()
-            clsName = 'JavaRDD'
+            clsName = "JavaRDD"
 
-        if clsName == 'JavaRDD':
+        if clsName == "JavaRDD":
             jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r)
             return RDD(jrdd, sc)
 
-        if clsName == 'Dataset':
+        if clsName == "Dataset":
             return DataFrame(r, SQLContext.getOrCreate(sc))
 
         if clsName in _picklable_classes:
@@ -111,7 +112,7 @@ def _java2py(sc, r, encoding="bytes"):
 
 
 def callJavaFunc(sc, func, *args):
-    """ Call Java Function """
+    """Call Java Function"""
     args = [_py2java(sc, a) for a in args]
     return _java2py(sc, func(*args))
 
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index e8cada9..be63a8f 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -21,14 +21,26 @@ from abc import abstractmethod, ABCMeta
 from pyspark import since, keyword_only
 from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.param import Param, Params, TypeConverters
-from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasProbabilityCol, \
-    HasRawPredictionCol, HasFeaturesCol, HasWeightCol
+from pyspark.ml.param.shared import (
+    HasLabelCol,
+    HasPredictionCol,
+    HasProbabilityCol,
+    HasRawPredictionCol,
+    HasFeaturesCol,
+    HasWeightCol,
+)
 from pyspark.ml.common import inherit_doc
 from pyspark.ml.util import JavaMLReadable, JavaMLWritable
 
-__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
-           'MulticlassClassificationEvaluator', 'MultilabelClassificationEvaluator',
-           'ClusteringEvaluator', 'RankingEvaluator']
+__all__ = [
+    "Evaluator",
+    "BinaryClassificationEvaluator",
+    "RegressionEvaluator",
+    "MulticlassClassificationEvaluator",
+    "MultilabelClassificationEvaluator",
+    "ClusteringEvaluator",
+    "RankingEvaluator",
+]
 
 
 @inherit_doc
@@ -38,6 +50,7 @@ class Evaluator(Params, metaclass=ABCMeta):
 
     .. versionadded:: 1.4.0
     """
+
     pass
 
     @abstractmethod
@@ -125,8 +138,9 @@ class JavaEvaluator(JavaParams, Evaluator, metaclass=ABCMeta):
 
 
 @inherit_doc
-class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol, HasWeightCol,
-                                    JavaMLReadable, JavaMLWritable):
+class BinaryClassificationEvaluator(
+    JavaEvaluator, HasLabelCol, HasRawPredictionCol, HasWeightCol, JavaMLReadable, JavaMLWritable
+):
     """
     Evaluator for binary classification, which expects input columns rawPrediction, label
     and an optional weight column.
@@ -168,25 +182,40 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
     1000
     """
 
-    metricName = Param(Params._dummy(), "metricName",
-                       "metric name in evaluation (areaUnderROC|areaUnderPR)",
-                       typeConverter=TypeConverters.toString)
-
-    numBins = Param(Params._dummy(), "numBins", "Number of bins to down-sample the curves "
-                    "(ROC curve, PR curve) in area computation. If 0, no down-sampling will "
-                    "occur. Must be >= 0.",
-                    typeConverter=TypeConverters.toInt)
+    metricName = Param(
+        Params._dummy(),
+        "metricName",
+        "metric name in evaluation (areaUnderROC|areaUnderPR)",
+        typeConverter=TypeConverters.toString,
+    )
+
+    numBins = Param(
+        Params._dummy(),
+        "numBins",
+        "Number of bins to down-sample the curves "
+        "(ROC curve, PR curve) in area computation. If 0, no down-sampling will "
+        "occur. Must be >= 0.",
+        typeConverter=TypeConverters.toInt,
+    )
 
     @keyword_only
-    def __init__(self, *, rawPredictionCol="rawPrediction", labelCol="label",
-                 metricName="areaUnderROC", weightCol=None, numBins=1000):
+    def __init__(
+        self,
+        *,
+        rawPredictionCol="rawPrediction",
+        labelCol="label",
+        metricName="areaUnderROC",
+        weightCol=None,
+        numBins=1000,
+    ):
         """
         __init__(self, \\*, rawPredictionCol="rawPrediction", labelCol="label", \
                  metricName="areaUnderROC", weightCol=None, numBins=1000)
         """
         super(BinaryClassificationEvaluator, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid)
+            "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid
+        )
         self._setDefault(metricName="areaUnderROC", numBins=1000)
         kwargs = self._input_kwargs
         self._set(**kwargs)
@@ -240,8 +269,15 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, *, rawPredictionCol="rawPrediction", labelCol="label",
-                  metricName="areaUnderROC", weightCol=None, numBins=1000):
+    def setParams(
+        self,
+        *,
+        rawPredictionCol="rawPrediction",
+        labelCol="label",
+        metricName="areaUnderROC",
+        weightCol=None,
+        numBins=1000,
+    ):
         """
         setParams(self, \\*, rawPredictionCol="rawPrediction", labelCol="label", \
                   metricName="areaUnderROC", weightCol=None, numBins=1000)
@@ -252,8 +288,9 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
 
 
 @inherit_doc
-class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeightCol,
-                          JavaMLReadable, JavaMLWritable):
+class RegressionEvaluator(
+    JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeightCol, JavaMLReadable, JavaMLWritable
+):
     """
     Evaluator for Regression, which expects input columns prediction, label
     and an optional weight column.
@@ -290,29 +327,44 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeigh
     >>> evaluator.getThroughOrigin()
     False
     """
-    metricName = Param(Params._dummy(), "metricName",
-                       """metric name in evaluation - one of:
+
+    metricName = Param(
+        Params._dummy(),
+        "metricName",
+        """metric name in evaluation - one of:
                        rmse - root mean squared error (default)
                        mse - mean squared error
                        r2 - r^2 metric
                        mae - mean absolute error
                        var - explained variance.""",
-                       typeConverter=TypeConverters.toString)
+        typeConverter=TypeConverters.toString,
+    )
 
-    throughOrigin = Param(Params._dummy(), "throughOrigin",
-                          "whether the regression is through the origin.",
-                          typeConverter=TypeConverters.toBoolean)
+    throughOrigin = Param(
+        Params._dummy(),
+        "throughOrigin",
+        "whether the regression is through the origin.",
+        typeConverter=TypeConverters.toBoolean,
+    )
 
     @keyword_only
-    def __init__(self, *, predictionCol="prediction", labelCol="label",
-                 metricName="rmse", weightCol=None, throughOrigin=False):
+    def __init__(
+        self,
+        *,
+        predictionCol="prediction",
+        labelCol="label",
+        metricName="rmse",
+        weightCol=None,
+        throughOrigin=False,
+    ):
         """
         __init__(self, \\*, predictionCol="prediction", labelCol="label", \
                  metricName="rmse", weightCol=None, throughOrigin=False)
         """
         super(RegressionEvaluator, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid)
+            "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid
+        )
         self._setDefault(metricName="rmse", throughOrigin=False)
         kwargs = self._input_kwargs
         self._set(**kwargs)
@@ -366,8 +418,15 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeigh
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, *, predictionCol="prediction", labelCol="label",
-                  metricName="rmse", weightCol=None, throughOrigin=False):
+    def setParams(
+        self,
+        *,
+        predictionCol="prediction",
+        labelCol="label",
+        metricName="rmse",
+        weightCol=None,
+        throughOrigin=False,
+    ):
         """
         setParams(self, \\*, predictionCol="prediction", labelCol="label", \
                   metricName="rmse", weightCol=None, throughOrigin=False)
@@ -378,8 +437,15 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeigh
 
 
 @inherit_doc
-class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeightCol,
-                                        HasProbabilityCol, JavaMLReadable, JavaMLWritable):
+class MulticlassClassificationEvaluator(
+    JavaEvaluator,
+    HasLabelCol,
+    HasPredictionCol,
+    HasWeightCol,
+    HasProbabilityCol,
+    JavaMLReadable,
+    JavaMLWritable,
+):
     """
     Evaluator for Multiclass Classification, which expects input
     columns: prediction, label, weight (optional) and probabilityCol (only for logLoss).
@@ -432,32 +498,54 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
     >>> evaluator.evaluate(dataset)
     0.9682...
     """
-    metricName = Param(Params._dummy(), "metricName",
-                       "metric name in evaluation "
-                       "(f1|accuracy|weightedPrecision|weightedRecall|weightedTruePositiveRate| "
-                       "weightedFalsePositiveRate|weightedFMeasure|truePositiveRateByLabel| "
-                       "falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel| "
-                       "logLoss|hammingLoss)",
-                       typeConverter=TypeConverters.toString)
-    metricLabel = Param(Params._dummy(), "metricLabel",
-                        "The class whose metric will be computed in truePositiveRateByLabel|"
-                        "falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel."
-                        " Must be >= 0. The default value is 0.",
-                        typeConverter=TypeConverters.toFloat)
-    beta = Param(Params._dummy(), "beta",
-                 "The beta value used in weightedFMeasure|fMeasureByLabel."
-                 " Must be > 0. The default value is 1.",
-                 typeConverter=TypeConverters.toFloat)
-    eps = Param(Params._dummy(), "eps",
-                "log-loss is undefined for p=0 or p=1, so probabilities are clipped to "
-                "max(eps, min(1 - eps, p)). "
-                "Must be in range (0, 0.5). The default value is 1e-15.",
-                typeConverter=TypeConverters.toFloat)
+
+    metricName = Param(
+        Params._dummy(),
+        "metricName",
+        "metric name in evaluation "
+        "(f1|accuracy|weightedPrecision|weightedRecall|weightedTruePositiveRate| "
+        "weightedFalsePositiveRate|weightedFMeasure|truePositiveRateByLabel| "
+        "falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel| "
+        "logLoss|hammingLoss)",
+        typeConverter=TypeConverters.toString,
+    )
+    metricLabel = Param(
+        Params._dummy(),
+        "metricLabel",
+        "The class whose metric will be computed in truePositiveRateByLabel|"
+        "falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel."
+        " Must be >= 0. The default value is 0.",
+        typeConverter=TypeConverters.toFloat,
+    )
+    beta = Param(
+        Params._dummy(),
+        "beta",
+        "The beta value used in weightedFMeasure|fMeasureByLabel."
+        " Must be > 0. The default value is 1.",
+        typeConverter=TypeConverters.toFloat,
+    )
+    eps = Param(
+        Params._dummy(),
+        "eps",
+        "log-loss is undefined for p=0 or p=1, so probabilities are clipped to "
+        "max(eps, min(1 - eps, p)). "
+        "Must be in range (0, 0.5). The default value is 1e-15.",
+        typeConverter=TypeConverters.toFloat,
+    )
 
     @keyword_only
-    def __init__(self, *, predictionCol="prediction", labelCol="label",
-                 metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0,
-                 probabilityCol="probability", eps=1e-15):
+    def __init__(
+        self,
+        *,
+        predictionCol="prediction",
+        labelCol="label",
+        metricName="f1",
+        weightCol=None,
+        metricLabel=0.0,
+        beta=1.0,
+        probabilityCol="probability",
+        eps=1e-15,
+    ):
         """
         __init__(self, \\*, predictionCol="prediction", labelCol="label", \
                  metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0, \
@@ -465,7 +553,8 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
         """
         super(MulticlassClassificationEvaluator, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid)
+            "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid
+        )
         self._setDefault(metricName="f1", metricLabel=0.0, beta=1.0, eps=1e-15)
         kwargs = self._input_kwargs
         self._set(**kwargs)
@@ -554,9 +643,18 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
 
     @keyword_only
     @since("1.5.0")
-    def setParams(self, *, predictionCol="prediction", labelCol="label",
-                  metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0,
-                  probabilityCol="probability", eps=1e-15):
+    def setParams(
+        self,
+        *,
+        predictionCol="prediction",
+        labelCol="label",
+        metricName="f1",
+        weightCol=None,
+        metricLabel=0.0,
+        beta=1.0,
+        probabilityCol="probability",
+        eps=1e-15,
+    ):
         """
         setParams(self, \\*, predictionCol="prediction", labelCol="label", \
                   metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0, \
@@ -568,8 +666,9 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
 
 
 @inherit_doc
-class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
-                                        JavaMLReadable, JavaMLWritable):
+class MultilabelClassificationEvaluator(
+    JavaEvaluator, HasLabelCol, HasPredictionCol, JavaMLReadable, JavaMLWritable
+):
     """
     Evaluator for Multilabel Classification, which expects two input
     columns: prediction and label.
@@ -600,28 +699,42 @@ class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
     >>> str(evaluator2.getPredictionCol())
     'prediction'
     """
-    metricName = Param(Params._dummy(), "metricName",
-                       "metric name in evaluation "
-                       "(subsetAccuracy|accuracy|hammingLoss|precision|recall|f1Measure|"
-                       "precisionByLabel|recallByLabel|f1MeasureByLabel|microPrecision|"
-                       "microRecall|microF1Measure)",
-                       typeConverter=TypeConverters.toString)
-    metricLabel = Param(Params._dummy(), "metricLabel",
-                        "The class whose metric will be computed in precisionByLabel|"
-                        "recallByLabel|f1MeasureByLabel. "
-                        "Must be >= 0. The default value is 0.",
-                        typeConverter=TypeConverters.toFloat)
+
+    metricName = Param(
+        Params._dummy(),
+        "metricName",
+        "metric name in evaluation "
+        "(subsetAccuracy|accuracy|hammingLoss|precision|recall|f1Measure|"
+        "precisionByLabel|recallByLabel|f1MeasureByLabel|microPrecision|"
+        "microRecall|microF1Measure)",
+        typeConverter=TypeConverters.toString,
+    )
+    metricLabel = Param(
+        Params._dummy(),
+        "metricLabel",
+        "The class whose metric will be computed in precisionByLabel|"
+        "recallByLabel|f1MeasureByLabel. "
+        "Must be >= 0. The default value is 0.",
+        typeConverter=TypeConverters.toFloat,
+    )
 
     @keyword_only
-    def __init__(self, *, predictionCol="prediction", labelCol="label",
-                 metricName="f1Measure", metricLabel=0.0):
+    def __init__(
+        self,
+        *,
+        predictionCol="prediction",
+        labelCol="label",
+        metricName="f1Measure",
+        metricLabel=0.0,
+    ):
         """
         __init__(self, \\*, predictionCol="prediction", labelCol="label", \
                  metricName="f1Measure", metricLabel=0.0)
         """
         super(MultilabelClassificationEvaluator, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator", self.uid)
+            "org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator", self.uid
+        )
         self._setDefault(metricName="f1Measure", metricLabel=0.0)
         kwargs = self._input_kwargs
         self._set(**kwargs)
@@ -670,8 +783,14 @@ class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
 
     @keyword_only
     @since("3.0.0")
-    def setParams(self, *, predictionCol="prediction", labelCol="label",
-                  metricName="f1Measure", metricLabel=0.0):
+    def setParams(
+        self,
+        *,
+        predictionCol="prediction",
+        labelCol="label",
+        metricName="f1Measure",
+        metricLabel=0.0,
+    ):
         """
         setParams(self, \\*, predictionCol="prediction", labelCol="label", \
                   metricName="f1Measure", metricLabel=0.0)
@@ -682,8 +801,9 @@ class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
 
 
 @inherit_doc
-class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, HasWeightCol,
-                          JavaMLReadable, JavaMLWritable):
+class ClusteringEvaluator(
+    JavaEvaluator, HasPredictionCol, HasFeaturesCol, HasWeightCol, JavaMLReadable, JavaMLWritable
+):
     """
     Evaluator for Clustering results, which expects two input
     columns: prediction and features. The metric computes the Silhouette
@@ -727,31 +847,53 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, HasWe
     >>> str(evaluator2.getPredictionCol())
     'prediction'
     """
-    metricName = Param(Params._dummy(), "metricName",
-                       "metric name in evaluation (silhouette)",
-                       typeConverter=TypeConverters.toString)
-    distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " +
-                            "Supported options: 'squaredEuclidean' and 'cosine'.",
-                            typeConverter=TypeConverters.toString)
+
+    metricName = Param(
+        Params._dummy(),
+        "metricName",
+        "metric name in evaluation (silhouette)",
+        typeConverter=TypeConverters.toString,
+    )
+    distanceMeasure = Param(
+        Params._dummy(),
+        "distanceMeasure",
+        "The distance measure. " + "Supported options: 'squaredEuclidean' and 'cosine'.",
+        typeConverter=TypeConverters.toString,
+    )
 
     @keyword_only
-    def __init__(self, *, predictionCol="prediction", featuresCol="features",
-                 metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None):
+    def __init__(
+        self,
+        *,
+        predictionCol="prediction",
+        featuresCol="features",
+        metricName="silhouette",
+        distanceMeasure="squaredEuclidean",
+        weightCol=None,
+    ):
         """
         __init__(self, \\*, predictionCol="prediction", featuresCol="features", \
                  metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None)
         """
         super(ClusteringEvaluator, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid)
+            "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid
+        )
         self._setDefault(metricName="silhouette", distanceMeasure="squaredEuclidean")
         kwargs = self._input_kwargs
         self._set(**kwargs)
 
     @keyword_only
     @since("2.3.0")
-    def setParams(self, *, predictionCol="prediction", featuresCol="features",
-                  metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None):
+    def setParams(
+        self,
+        *,
+        predictionCol="prediction",
+        featuresCol="features",
+        metricName="silhouette",
+        distanceMeasure="squaredEuclidean",
+        weightCol=None,
+    ):
         """
         setParams(self, \\*, predictionCol="prediction", featuresCol="features", \
                   metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None)
@@ -809,8 +951,9 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, HasWe
 
 
 @inherit_doc
-class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
-                       JavaMLReadable, JavaMLWritable):
+class RankingEvaluator(
+    JavaEvaluator, HasLabelCol, HasPredictionCol, JavaMLReadable, JavaMLWritable
+):
     """
     Evaluator for Ranking, which expects two input
     columns: prediction and label.
@@ -842,26 +985,40 @@ class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
     >>> str(evaluator2.getPredictionCol())
     'prediction'
     """
-    metricName = Param(Params._dummy(), "metricName",
-                       "metric name in evaluation "
-                       "(meanAveragePrecision|meanAveragePrecisionAtK|"
-                       "precisionAtK|ndcgAtK|recallAtK)",
-                       typeConverter=TypeConverters.toString)
-    k = Param(Params._dummy(), "k",
-              "The ranking position value used in meanAveragePrecisionAtK|precisionAtK|"
-              "ndcgAtK|recallAtK. Must be > 0. The default value is 10.",
-              typeConverter=TypeConverters.toInt)
+
+    metricName = Param(
+        Params._dummy(),
+        "metricName",
+        "metric name in evaluation "
+        "(meanAveragePrecision|meanAveragePrecisionAtK|"
+        "precisionAtK|ndcgAtK|recallAtK)",
+        typeConverter=TypeConverters.toString,
+    )
+    k = Param(
+        Params._dummy(),
+        "k",
+        "The ranking position value used in meanAveragePrecisionAtK|precisionAtK|"
+        "ndcgAtK|recallAtK. Must be > 0. The default value is 10.",
+        typeConverter=TypeConverters.toInt,
+    )
 
     @keyword_only
-    def __init__(self, *, predictionCol="prediction", labelCol="label",
-                 metricName="meanAveragePrecision", k=10):
+    def __init__(
+        self,
+        *,
+        predictionCol="prediction",
+        labelCol="label",
+        metricName="meanAveragePrecision",
+        k=10,
+    ):
         """
         __init__(self, \\*, predictionCol="prediction", labelCol="label", \
                  metricName="meanAveragePrecision", k=10)
         """
         super(RankingEvaluator, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.evaluation.RankingEvaluator", self.uid)
+            "org.apache.spark.ml.evaluation.RankingEvaluator", self.uid
+        )
         self._setDefault(metricName="meanAveragePrecision", k=10)
         kwargs = self._input_kwargs
         self._set(**kwargs)
@@ -910,8 +1067,14 @@ class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
 
     @keyword_only
     @since("3.0.0")
-    def setParams(self, *, predictionCol="prediction", labelCol="label",
-                  metricName="meanAveragePrecision", k=10):
+    def setParams(
+        self,
+        *,
+        predictionCol="prediction",
+        labelCol="label",
+        metricName="meanAveragePrecision",
+        k=10,
+    ):
         """
         setParams(self, \\*, predictionCol="prediction", labelCol="label", \
                   metricName="meanAveragePrecision", k=10)
@@ -926,21 +1089,20 @@ if __name__ == "__main__":
     import tempfile
     import pyspark.ml.evaluation
     from pyspark.sql import SparkSession
+
     globs = pyspark.ml.evaluation.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    spark = SparkSession.builder\
-        .master("local[2]")\
-        .appName("ml.evaluation tests")\
-        .getOrCreate()
-    globs['spark'] = spark
+    spark = SparkSession.builder.master("local[2]").appName("ml.evaluation tests").getOrCreate()
+    globs["spark"] = spark
     temp_path = tempfile.mkdtemp()
-    globs['temp_path'] = temp_path
+    globs["temp_path"] = temp_path
     try:
         (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
         spark.stop()
     finally:
         from shutil import rmtree
+
         try:
             rmtree(temp_path)
         except OSError:
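
For context, here is a minimal usage sketch (not part of the commit) that exercises the keyword-only `RankingEvaluator` signature shown in the hunks above; it assumes a local PySpark installation and simply mirrors the doctest session setup from this file.

```python
from pyspark.sql import SparkSession
from pyspark.ml.evaluation import RankingEvaluator

# Same session setup as the doctest runner reformatted above.
spark = SparkSession.builder.master("local[2]").appName("ml.evaluation tests").getOrCreate()

# prediction and label are arrays of doubles, as RankingEvaluator expects.
scoreAndLabels = [
    ([1.0, 6.0, 2.0, 7.0, 8.0], [1.0, 2.0, 3.0, 4.0, 5.0]),
    ([4.0, 1.0, 5.0, 6.0, 2.0], [1.0, 2.0, 3.0]),
]
dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])

# The `*` in the reformatted signature makes every argument keyword-only.
evaluator = RankingEvaluator(
    predictionCol="prediction",
    labelCol="label",
    metricName="meanAveragePrecision",
    k=10,
)
print(evaluator.evaluate(dataset))
spark.stop()
```
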
diff --git a/python/pyspark/ml/evaluation.pyi b/python/pyspark/ml/evaluation.pyi
index 55a3ae2..d7883f4 100644
--- a/python/pyspark/ml/evaluation.pyi
+++ b/python/pyspark/ml/evaluation.pyi
@@ -42,9 +42,7 @@ from pyspark.ml.util import JavaMLReadable, JavaMLWritable
 from pyspark.sql.dataframe import DataFrame
 
 class Evaluator(Params, metaclass=abc.ABCMeta):
-    def evaluate(
-        self, dataset: DataFrame, params: Optional[ParamMap] = ...
-    ) -> float: ...
+    def evaluate(self, dataset: DataFrame, params: Optional[ParamMap] = ...) -> float: ...
     def isLargerBetter(self) -> bool: ...
 
 class JavaEvaluator(JavaParams, Evaluator, metaclass=abc.ABCMeta):
@@ -67,7 +65,7 @@ class BinaryClassificationEvaluator(
         labelCol: str = ...,
         metricName: BinaryClassificationEvaluatorMetricType = ...,
         weightCol: Optional[str] = ...,
-        numBins: int = ...
+        numBins: int = ...,
     ) -> None: ...
     def setMetricName(
         self, value: BinaryClassificationEvaluatorMetricType
@@ -85,7 +83,7 @@ class BinaryClassificationEvaluator(
         labelCol: str = ...,
         metricName: BinaryClassificationEvaluatorMetricType = ...,
         weightCol: Optional[str] = ...,
-        numBins: int = ...
+        numBins: int = ...,
     ) -> BinaryClassificationEvaluator: ...
 
 class RegressionEvaluator(
@@ -105,11 +103,9 @@ class RegressionEvaluator(
         labelCol: str = ...,
         metricName: RegressionEvaluatorMetricType = ...,
         weightCol: Optional[str] = ...,
-        throughOrigin: bool = ...
+        throughOrigin: bool = ...,
     ) -> None: ...
-    def setMetricName(
-        self, value: RegressionEvaluatorMetricType
-    ) -> RegressionEvaluator: ...
+    def setMetricName(self, value: RegressionEvaluatorMetricType) -> RegressionEvaluator: ...
     def getMetricName(self) -> RegressionEvaluatorMetricType: ...
     def setThroughOrigin(self, value: bool) -> RegressionEvaluator: ...
     def getThroughOrigin(self) -> bool: ...
@@ -123,7 +119,7 @@ class RegressionEvaluator(
         labelCol: str = ...,
         metricName: RegressionEvaluatorMetricType = ...,
         weightCol: Optional[str] = ...,
-        throughOrigin: bool = ...
+        throughOrigin: bool = ...,
     ) -> RegressionEvaluator: ...
 
 class MulticlassClassificationEvaluator(
@@ -149,7 +145,7 @@ class MulticlassClassificationEvaluator(
         metricLabel: float = ...,
         beta: float = ...,
         probabilityCol: str = ...,
-        eps: float = ...
+        eps: float = ...,
     ) -> None: ...
     def setMetricName(
         self, value: MulticlassClassificationEvaluatorMetricType
@@ -175,7 +171,7 @@ class MulticlassClassificationEvaluator(
         metricLabel: float = ...,
         beta: float = ...,
         probabilityCol: str = ...,
-        eps: float = ...
+        eps: float = ...,
     ) -> MulticlassClassificationEvaluator: ...
 
 class MultilabelClassificationEvaluator(
@@ -193,7 +189,7 @@ class MultilabelClassificationEvaluator(
         predictionCol: str = ...,
         labelCol: str = ...,
         metricName: MultilabelClassificationEvaluatorMetricType = ...,
-        metricLabel: float = ...
+        metricLabel: float = ...,
     ) -> None: ...
     def setMetricName(
         self, value: MultilabelClassificationEvaluatorMetricType
@@ -209,7 +205,7 @@ class MultilabelClassificationEvaluator(
         predictionCol: str = ...,
         labelCol: str = ...,
         metricName: MultilabelClassificationEvaluatorMetricType = ...,
-        metricLabel: float = ...
+        metricLabel: float = ...,
     ) -> MultilabelClassificationEvaluator: ...
 
 class ClusteringEvaluator(
@@ -229,7 +225,7 @@ class ClusteringEvaluator(
         featuresCol: str = ...,
         metricName: ClusteringEvaluatorMetricType = ...,
         distanceMeasure: str = ...,
-        weightCol: Optional[str] = ...
+        weightCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -238,11 +234,9 @@ class ClusteringEvaluator(
         featuresCol: str = ...,
         metricName: ClusteringEvaluatorMetricType = ...,
         distanceMeasure: str = ...,
-        weightCol: Optional[str] = ...
-    ) -> ClusteringEvaluator: ...
-    def setMetricName(
-        self, value: ClusteringEvaluatorMetricType
+        weightCol: Optional[str] = ...,
     ) -> ClusteringEvaluator: ...
+    def setMetricName(self, value: ClusteringEvaluatorMetricType) -> ClusteringEvaluator: ...
     def getMetricName(self) -> ClusteringEvaluatorMetricType: ...
     def setDistanceMeasure(self, value: str) -> ClusteringEvaluator: ...
     def getDistanceMeasure(self) -> str: ...
@@ -265,7 +259,7 @@ class RankingEvaluator(
         predictionCol: str = ...,
         labelCol: str = ...,
         metricName: RankingEvaluatorMetricType = ...,
-        k: int = ...
+        k: int = ...,
     ) -> None: ...
     def setMetricName(self, value: RankingEvaluatorMetricType) -> RankingEvaluator: ...
     def getMetricName(self) -> RankingEvaluatorMetricType: ...
@@ -279,5 +273,5 @@ class RankingEvaluator(
         predictionCol: str = ...,
         labelCol: str = ...,
         metricName: RankingEvaluatorMetricType = ...,
-        k: int = ...
+        k: int = ...,
     ) -> RankingEvaluator: ...
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index cf6b91c..18731ae 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -17,55 +17,100 @@
 
 from pyspark import since, keyword_only, SparkContext
 from pyspark.ml.linalg import _convert_to_vector
-from pyspark.ml.param.shared import HasThreshold, HasThresholds, HasInputCol, HasOutputCol, \
-    HasInputCols, HasOutputCols, HasHandleInvalid, HasRelativeError, HasFeaturesCol, HasLabelCol, \
-    HasSeed, HasNumFeatures, HasStepSize, HasMaxIter, TypeConverters, Param, Params
+from pyspark.ml.param.shared import (
+    HasThreshold,
+    HasThresholds,
+    HasInputCol,
+    HasOutputCol,
+    HasInputCols,
+    HasOutputCols,
+    HasHandleInvalid,
+    HasRelativeError,
+    HasFeaturesCol,
+    HasLabelCol,
+    HasSeed,
+    HasNumFeatures,
+    HasStepSize,
+    HasMaxIter,
+    TypeConverters,
+    Param,
+    Params,
+)
 from pyspark.ml.util import JavaMLReadable, JavaMLWritable
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm
 from pyspark.ml.common import inherit_doc
 
-__all__ = ['Binarizer',
-           'BucketedRandomProjectionLSH', 'BucketedRandomProjectionLSHModel',
-           'Bucketizer',
-           'ChiSqSelector', 'ChiSqSelectorModel',
-           'CountVectorizer', 'CountVectorizerModel',
-           'DCT',
-           'ElementwiseProduct',
-           'FeatureHasher',
-           'HashingTF',
-           'IDF', 'IDFModel',
-           'Imputer', 'ImputerModel',
-           'IndexToString',
-           'Interaction',
-           'MaxAbsScaler', 'MaxAbsScalerModel',
-           'MinHashLSH', 'MinHashLSHModel',
-           'MinMaxScaler', 'MinMaxScalerModel',
-           'NGram',
-           'Normalizer',
-           'OneHotEncoder', 'OneHotEncoderModel',
-           'PCA', 'PCAModel',
-           'PolynomialExpansion',
-           'QuantileDiscretizer',
-           'RobustScaler', 'RobustScalerModel',
-           'RegexTokenizer',
-           'RFormula', 'RFormulaModel',
-           'SQLTransformer',
-           'StandardScaler', 'StandardScalerModel',
-           'StopWordsRemover',
-           'StringIndexer', 'StringIndexerModel',
-           'Tokenizer',
-           'UnivariateFeatureSelector', 'UnivariateFeatureSelectorModel',
-           'VarianceThresholdSelector', 'VarianceThresholdSelectorModel',
-           'VectorAssembler',
-           'VectorIndexer', 'VectorIndexerModel',
-           'VectorSizeHint',
-           'VectorSlicer',
-           'Word2Vec', 'Word2VecModel']
+__all__ = [
+    "Binarizer",
+    "BucketedRandomProjectionLSH",
+    "BucketedRandomProjectionLSHModel",
+    "Bucketizer",
+    "ChiSqSelector",
+    "ChiSqSelectorModel",
+    "CountVectorizer",
+    "CountVectorizerModel",
+    "DCT",
+    "ElementwiseProduct",
+    "FeatureHasher",
+    "HashingTF",
+    "IDF",
+    "IDFModel",
+    "Imputer",
+    "ImputerModel",
+    "IndexToString",
+    "Interaction",
+    "MaxAbsScaler",
+    "MaxAbsScalerModel",
+    "MinHashLSH",
+    "MinHashLSHModel",
+    "MinMaxScaler",
+    "MinMaxScalerModel",
+    "NGram",
+    "Normalizer",
+    "OneHotEncoder",
+    "OneHotEncoderModel",
+    "PCA",
+    "PCAModel",
+    "PolynomialExpansion",
+    "QuantileDiscretizer",
+    "RobustScaler",
+    "RobustScalerModel",
+    "RegexTokenizer",
+    "RFormula",
+    "RFormulaModel",
+    "SQLTransformer",
+    "StandardScaler",
+    "StandardScalerModel",
+    "StopWordsRemover",
+    "StringIndexer",
+    "StringIndexerModel",
+    "Tokenizer",
+    "UnivariateFeatureSelector",
+    "UnivariateFeatureSelectorModel",
+    "VarianceThresholdSelector",
+    "VarianceThresholdSelectorModel",
+    "VectorAssembler",
+    "VectorIndexer",
+    "VectorIndexerModel",
+    "VectorSizeHint",
+    "VectorSlicer",
+    "Word2Vec",
+    "Word2VecModel",
+]
 
 
 @inherit_doc
-class Binarizer(JavaTransformer, HasThreshold, HasThresholds, HasInputCol, HasOutputCol,
-                HasInputCols, HasOutputCols, JavaMLReadable, JavaMLWritable):
+class Binarizer(
+    JavaTransformer,
+    HasThreshold,
+    HasThresholds,
+    HasInputCol,
+    HasOutputCol,
+    HasInputCols,
+    HasOutputCols,
+    JavaMLReadable,
+    JavaMLWritable,
+):
     """
     Binarize a column of continuous features given a threshold. Since 3.0.0,
     :py:class:`Binarizer` can map multiple columns at once by setting the :py:attr:`inputCols`
@@ -112,21 +157,35 @@ class Binarizer(JavaTransformer, HasThreshold, HasThresholds, HasInputCol, HasOu
     ...
     """
 
-    threshold = Param(Params._dummy(), "threshold",
-                      "Param for threshold used to binarize continuous features. " +
-                      "The features greater than the threshold will be binarized to 1.0. " +
-                      "The features equal to or less than the threshold will be binarized to 0.0",
-                      typeConverter=TypeConverters.toFloat)
-    thresholds = Param(Params._dummy(), "thresholds",
-                       "Param for array of threshold used to binarize continuous features. " +
-                       "This is for multiple columns input. If transforming multiple columns " +
-                       "and thresholds is not set, but threshold is set, then threshold will " +
-                       "be applied across all columns.",
-                       typeConverter=TypeConverters.toListFloat)
+    threshold = Param(
+        Params._dummy(),
+        "threshold",
+        "Param for threshold used to binarize continuous features. "
+        + "The features greater than the threshold will be binarized to 1.0. "
+        + "The features equal to or less than the threshold will be binarized to 0.0",
+        typeConverter=TypeConverters.toFloat,
+    )
+    thresholds = Param(
+        Params._dummy(),
+        "thresholds",
+        "Param for array of threshold used to binarize continuous features. "
+        + "This is for multiple columns input. If transforming multiple columns "
+        + "and thresholds is not set, but threshold is set, then threshold will "
+        + "be applied across all columns.",
+        typeConverter=TypeConverters.toListFloat,
+    )
 
     @keyword_only
-    def __init__(self, *, threshold=0.0, inputCol=None, outputCol=None, thresholds=None,
-                 inputCols=None, outputCols=None):
+    def __init__(
+        self,
+        *,
+        threshold=0.0,
+        inputCol=None,
+        outputCol=None,
+        thresholds=None,
+        inputCols=None,
+        outputCols=None,
+    ):
         """
         __init__(self, \\*, threshold=0.0, inputCol=None, outputCol=None, thresholds=None, \
                  inputCols=None, outputCols=None)
@@ -139,8 +198,16 @@ class Binarizer(JavaTransformer, HasThreshold, HasThresholds, HasInputCol, HasOu
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, *, threshold=0.0, inputCol=None, outputCol=None, thresholds=None,
-                  inputCols=None, outputCols=None):
+    def setParams(
+        self,
+        *,
+        threshold=0.0,
+        inputCol=None,
+        outputCol=None,
+        thresholds=None,
+        inputCols=None,
+        outputCols=None,
+    ):
         """
         setParams(self, \\*, threshold=0.0, inputCol=None, outputCol=None, thresholds=None, \
                   inputCols=None, outputCols=None)
@@ -195,10 +262,14 @@ class _LSHParams(HasInputCol, HasOutputCol):
     Mixin for Locality Sensitive Hashing (LSH) algorithm parameters.
     """
 
-    numHashTables = Param(Params._dummy(), "numHashTables", "number of hash tables, where " +
-                          "increasing number of hash tables lowers the false negative rate, " +
-                          "and decreasing it improves the running performance.",
-                          typeConverter=TypeConverters.toInt)
+    numHashTables = Param(
+        Params._dummy(),
+        "numHashTables",
+        "number of hash tables, where "
+        + "increasing number of hash tables lowers the false negative rate, "
+        + "and decreasing it improves the running performance.",
+        typeConverter=TypeConverters.toInt,
+    )
 
     def __init__(self, *args):
         super(_LSHParams, self).__init__(*args)
@@ -281,8 +352,7 @@ class _LSHModel(JavaModel, _LSHParams):
             A dataset containing at most k items closest to the key. A column "distCol" is
             added to show the distance between each row and the key.
         """
-        return self._call_java("approxNearestNeighbors", dataset, key, numNearestNeighbors,
-                               distCol)
+        return self._call_java("approxNearestNeighbors", dataset, key, numNearestNeighbors, distCol)
 
     def approxSimilarityJoin(self, datasetA, datasetB, threshold, distCol="distCol"):
         """
@@ -314,7 +384,7 @@ class _LSHModel(JavaModel, _LSHParams):
         return self._call_java("approxSimilarityJoin", datasetA, datasetB, threshold, distCol)
 
 
-class _BucketedRandomProjectionLSHParams():
+class _BucketedRandomProjectionLSHParams:
     """
     Params for :py:class:`BucketedRandomProjectionLSH` and
     :py:class:`BucketedRandomProjectionLSHModel`.
@@ -322,9 +392,12 @@ class _BucketedRandomProjectionLSHParams():
     .. versionadded:: 3.0.0
     """
 
-    bucketLength = Param(Params._dummy(), "bucketLength", "the length of each hash bucket, " +
-                         "a larger bucket lowers the false negative rate.",
-                         typeConverter=TypeConverters.toFloat)
+    bucketLength = Param(
+        Params._dummy(),
+        "bucketLength",
+        "the length of each hash bucket, " + "a larger bucket lowers the false negative rate.",
+        typeConverter=TypeConverters.toFloat,
+    )
 
     @since("2.2.0")
     def getBucketLength(self):
@@ -335,8 +408,9 @@ class _BucketedRandomProjectionLSHParams():
 
 
 @inherit_doc
-class BucketedRandomProjectionLSH(_LSH, _BucketedRandomProjectionLSHParams,
-                                  HasSeed, JavaMLReadable, JavaMLWritable):
+class BucketedRandomProjectionLSH(
+    _LSH, _BucketedRandomProjectionLSHParams, HasSeed, JavaMLReadable, JavaMLWritable
+):
     """
     LSH class for Euclidean distance metrics.
     The input is dense or sparse vectors, each of which represents a point in the Euclidean
@@ -417,22 +491,25 @@ class BucketedRandomProjectionLSH(_LSH, _BucketedRandomProjectionLSHParams,
     """
 
     @keyword_only
-    def __init__(self, *, inputCol=None, outputCol=None, seed=None, numHashTables=1,
-                 bucketLength=None):
+    def __init__(
+        self, *, inputCol=None, outputCol=None, seed=None, numHashTables=1, bucketLength=None
+    ):
         """
         __init__(self, \\*, inputCol=None, outputCol=None, seed=None, numHashTables=1, \
                  bucketLength=None)
         """
         super(BucketedRandomProjectionLSH, self).__init__()
-        self._java_obj = \
-            self._new_java_obj("org.apache.spark.ml.feature.BucketedRandomProjectionLSH", self.uid)
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.feature.BucketedRandomProjectionLSH", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("2.2.0")
-    def setParams(self, *, inputCol=None, outputCol=None, seed=None, numHashTables=1,
-                  bucketLength=None):
+    def setParams(
+        self, *, inputCol=None, outputCol=None, seed=None, numHashTables=1, bucketLength=None
+    ):
         """
         setParams(self, \\*, inputCol=None, outputCol=None, seed=None, numHashTables=1, \
                   bucketLength=None)
@@ -458,8 +535,9 @@ class BucketedRandomProjectionLSH(_LSH, _BucketedRandomProjectionLSHParams,
         return BucketedRandomProjectionLSHModel(java_model)
 
 
-class BucketedRandomProjectionLSHModel(_LSHModel, _BucketedRandomProjectionLSHParams,
-                                       JavaMLReadable, JavaMLWritable):
+class BucketedRandomProjectionLSHModel(
+    _LSHModel, _BucketedRandomProjectionLSHParams, JavaMLReadable, JavaMLWritable
+):
     r"""
     Model fitted by :py:class:`BucketedRandomProjectionLSH`, where multiple random vectors are
     stored. The vectors are normalized to be unit vectors and each vector is used in a hash
@@ -472,8 +550,16 @@ class BucketedRandomProjectionLSHModel(_LSHModel, _BucketedRandomProjectionLSHPa
 
 
 @inherit_doc
-class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
-                 HasHandleInvalid, JavaMLReadable, JavaMLWritable):
+class Bucketizer(
+    JavaTransformer,
+    HasInputCol,
+    HasOutputCol,
+    HasInputCols,
+    HasOutputCols,
+    HasHandleInvalid,
+    JavaMLReadable,
+    JavaMLWritable,
+):
     """
     Maps a column of continuous features to a column of feature buckets. Since 3.0.0,
     :py:class:`Bucketizer` can map multiple columns at once by setting the :py:attr:`inputCols`
@@ -539,40 +625,59 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu
     ...
     """
 
-    splits = \
-        Param(Params._dummy(), "splits",
-              "Split points for mapping continuous features into buckets. With n+1 splits, " +
-              "there are n buckets. A bucket defined by splits x,y holds values in the " +
-              "range [x,y) except the last bucket, which also includes y. The splits " +
-              "should be of length >= 3 and strictly increasing. Values at -inf, inf must be " +
-              "explicitly provided to cover all Double values; otherwise, values outside the " +
-              "splits specified will be treated as errors.",
-              typeConverter=TypeConverters.toListFloat)
-
-    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries "
-                          "containing NaN values. Values outside the splits will always be treated "
-                          "as errors. Options are 'skip' (filter out rows with invalid values), " +
-                          "'error' (throw an error), or 'keep' (keep invalid values in a " +
-                          "special additional bucket). Note that in the multiple column " +
-                          "case, the invalid handling is applied to all columns. That said " +
-                          "for 'error' it will throw an error if any invalids are found in " +
-                          "any column, for 'skip' it will skip rows with any invalids in " +
-                          "any columns, etc.",
-                          typeConverter=TypeConverters.toString)
-
-    splitsArray = Param(Params._dummy(), "splitsArray", "The array of split points for mapping " +
-                        "continuous features into buckets for multiple columns. For each input " +
-                        "column, with n+1 splits, there are n buckets. A bucket defined by " +
-                        "splits x,y holds values in the range [x,y) except the last bucket, " +
-                        "which also includes y. The splits should be of length >= 3 and " +
-                        "strictly increasing. Values at -inf, inf must be explicitly provided " +
-                        "to cover all Double values; otherwise, values outside the splits " +
-                        "specified will be treated as errors.",
-                        typeConverter=TypeConverters.toListListFloat)
+    splits = Param(
+        Params._dummy(),
+        "splits",
+        "Split points for mapping continuous features into buckets. With n+1 splits, "
+        + "there are n buckets. A bucket defined by splits x,y holds values in the "
+        + "range [x,y) except the last bucket, which also includes y. The splits "
+        + "should be of length >= 3 and strictly increasing. Values at -inf, inf must be "
+        + "explicitly provided to cover all Double values; otherwise, values outside the "
+        + "splits specified will be treated as errors.",
+        typeConverter=TypeConverters.toListFloat,
+    )
+
+    handleInvalid = Param(
+        Params._dummy(),
+        "handleInvalid",
+        "how to handle invalid entries "
+        "containing NaN values. Values outside the splits will always be treated "
+        "as errors. Options are 'skip' (filter out rows with invalid values), "
+        + "'error' (throw an error), or 'keep' (keep invalid values in a "
+        + "special additional bucket). Note that in the multiple column "
+        + "case, the invalid handling is applied to all columns. That said "
+        + "for 'error' it will throw an error if any invalids are found in "
+        + "any column, for 'skip' it will skip rows with any invalids in "
+        + "any columns, etc.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    splitsArray = Param(
+        Params._dummy(),
+        "splitsArray",
+        "The array of split points for mapping "
+        + "continuous features into buckets for multiple columns. For each input "
+        + "column, with n+1 splits, there are n buckets. A bucket defined by "
+        + "splits x,y holds values in the range [x,y) except the last bucket, "
+        + "which also includes y. The splits should be of length >= 3 and "
+        + "strictly increasing. Values at -inf, inf must be explicitly provided "
+        + "to cover all Double values; otherwise, values outside the splits "
+        + "specified will be treated as errors.",
+        typeConverter=TypeConverters.toListListFloat,
+    )
 
     @keyword_only
-    def __init__(self, *, splits=None, inputCol=None, outputCol=None, handleInvalid="error",
-                 splitsArray=None, inputCols=None, outputCols=None):
+    def __init__(
+        self,
+        *,
+        splits=None,
+        inputCol=None,
+        outputCol=None,
+        handleInvalid="error",
+        splitsArray=None,
+        inputCols=None,
+        outputCols=None,
+    ):
         """
         __init__(self, \\*, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
                  splitsArray=None, inputCols=None, outputCols=None)
@@ -585,8 +690,17 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, *, splits=None, inputCol=None, outputCol=None, handleInvalid="error",
-                  splitsArray=None, inputCols=None, outputCols=None):
+    def setParams(
+        self,
+        *,
+        splits=None,
+        inputCol=None,
+        outputCol=None,
+        handleInvalid="error",
+        splitsArray=None,
+        inputCols=None,
+        outputCols=None,
+    ):
         """
         setParams(self, \\*, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
                   splitsArray=None, inputCols=None, outputCols=None)
@@ -662,35 +776,53 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol):
     """
 
     minTF = Param(
-        Params._dummy(), "minTF", "Filter to ignore rare words in" +
-        " a document. For each document, terms with frequency/count less than the given" +
-        " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
-        " times the term must appear in the document); if this is a double in [0,1), then this " +
-        "specifies a fraction (out of the document's token count). Note that the parameter is " +
-        "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
-        typeConverter=TypeConverters.toFloat)
+        Params._dummy(),
+        "minTF",
+        "Filter to ignore rare words in"
+        + " a document. For each document, terms with frequency/count less than the given"
+        + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of"
+        + " times the term must appear in the document); if this is a double in [0,1), then this "
+        + "specifies a fraction (out of the document's token count). Note that the parameter is "
+        + "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
+        typeConverter=TypeConverters.toFloat,
+    )
     minDF = Param(
-        Params._dummy(), "minDF", "Specifies the minimum number of" +
-        " different documents a term must appear in to be included in the vocabulary." +
-        " If this is an integer >= 1, this specifies the number of documents the term must" +
-        " appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
-        " Default 1.0", typeConverter=TypeConverters.toFloat)
+        Params._dummy(),
+        "minDF",
+        "Specifies the minimum number of"
+        + " different documents a term must appear in to be included in the vocabulary."
+        + " If this is an integer >= 1, this specifies the number of documents the term must"
+        + " appear in; if this is a double in [0,1), then this specifies the fraction of documents."
+        + " Default 1.0",
+        typeConverter=TypeConverters.toFloat,
+    )
     maxDF = Param(
-        Params._dummy(), "maxDF", "Specifies the maximum number of" +
-        " different documents a term could appear in to be included in the vocabulary." +
-        " A term that appears more than the threshold will be ignored. If this is an" +
-        " integer >= 1, this specifies the maximum number of documents the term could appear in;" +
-        " if this is a double in [0,1), then this specifies the maximum" +
-        " fraction of documents the term could appear in." +
-        " Default (2^63) - 1", typeConverter=TypeConverters.toFloat)
+        Params._dummy(),
+        "maxDF",
+        "Specifies the maximum number of"
+        + " different documents a term could appear in to be included in the vocabulary."
+        + " A term that appears more than the threshold will be ignored. If this is an"
+        + " integer >= 1, this specifies the maximum number of documents the term could appear in;"
+        + " if this is a double in [0,1), then this specifies the maximum"
+        + " fraction of documents the term could appear in."
+        + " Default (2^63) - 1",
+        typeConverter=TypeConverters.toFloat,
+    )
     vocabSize = Param(
-        Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
-        typeConverter=TypeConverters.toInt)
+        Params._dummy(),
+        "vocabSize",
+        "max size of the vocabulary. Default 1 << 18.",
+        typeConverter=TypeConverters.toInt,
+    )
     binary = Param(
-        Params._dummy(), "binary", "Binary toggle to control the output vector values." +
-        " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
-        " for discrete probabilistic models that model binary events rather than integer counts." +
-        " Default False", typeConverter=TypeConverters.toBoolean)
+        Params._dummy(),
+        "binary",
+        "Binary toggle to control the output vector values."
+        + " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful"
+        + " for discrete probabilistic models that model binary events rather than integer counts."
+        + " Default False",
+        typeConverter=TypeConverters.toBoolean,
+    )
 
     def __init__(self, *args):
         super(_CountVectorizerParams, self).__init__(*args)
@@ -791,22 +923,39 @@ class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, Jav
     """
 
     @keyword_only
-    def __init__(self, *, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18,
-                 binary=False, inputCol=None, outputCol=None):
+    def __init__(
+        self,
+        *,
+        minTF=1.0,
+        minDF=1.0,
+        maxDF=2 ** 63 - 1,
+        vocabSize=1 << 18,
+        binary=False,
+        inputCol=None,
+        outputCol=None,
+    ):
         """
         __init__(self, \\*, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18,\
                  binary=False, inputCol=None,outputCol=None)
         """
         super(CountVectorizer, self).__init__()
-        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer",
-                                            self.uid)
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid)
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.6.0")
-    def setParams(self, *, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18,
-                  binary=False, inputCol=None, outputCol=None):
+    def setParams(
+        self,
+        *,
+        minTF=1.0,
+        minDF=1.0,
+        maxDF=2 ** 63 - 1,
+        vocabSize=1 << 18,
+        binary=False,
+        inputCol=None,
+        outputCol=None,
+    ):
         """
         setParams(self, \\*, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18,\
                   binary=False, inputCol=None, outputCol=None)
@@ -899,7 +1048,8 @@ class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, Ja
         java_class = sc._gateway.jvm.java.lang.String
         jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class)
         model = CountVectorizerModel._create_from_java_class(
-            "org.apache.spark.ml.feature.CountVectorizerModel", jvocab)
+            "org.apache.spark.ml.feature.CountVectorizerModel", jvocab
+        )
         model.setInputCol(inputCol)
         if outputCol is not None:
             model.setOutputCol(outputCol)
@@ -975,8 +1125,12 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit
     False
     """
 
-    inverse = Param(Params._dummy(), "inverse", "Set transformer to perform inverse DCT, " +
-                    "default False.", typeConverter=TypeConverters.toBoolean)
+    inverse = Param(
+        Params._dummy(),
+        "inverse",
+        "Set transformer to perform inverse DCT, " + "default False.",
+        typeConverter=TypeConverters.toBoolean,
+    )
 
     @keyword_only
     def __init__(self, *, inverse=False, inputCol=None, outputCol=None):
@@ -1027,8 +1181,9 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit
 
 
 @inherit_doc
-class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
-                         JavaMLWritable):
+class ElementwiseProduct(
+    JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable
+):
     """
     Outputs the Hadamard product (i.e., the element-wise product) of each input vector
     with a provided "weight" vector. In other words, it scales each column of the dataset
@@ -1060,8 +1215,12 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada
     True
     """
 
-    scalingVec = Param(Params._dummy(), "scalingVec", "Vector for hadamard product.",
-                       typeConverter=TypeConverters.toVector)
+    scalingVec = Param(
+        Params._dummy(),
+        "scalingVec",
+        "Vector for hadamard product.",
+        typeConverter=TypeConverters.toVector,
+    )
 
     @keyword_only
     def __init__(self, *, scalingVec=None, inputCol=None, outputCol=None):
@@ -1069,8 +1228,9 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada
         __init__(self, \\*, scalingVec=None, inputCol=None, outputCol=None)
         """
         super(ElementwiseProduct, self).__init__()
-        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct",
-                                            self.uid)
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.feature.ElementwiseProduct", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
@@ -1112,8 +1272,9 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada
 
 
 @inherit_doc
-class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, JavaMLReadable,
-                    JavaMLWritable):
+class FeatureHasher(
+    JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, JavaMLReadable, JavaMLWritable
+):
     """
     Feature hashing projects a set of categorical or numerical features into a feature vector of
     specified dimension (typically substantially smaller than that of the original feature
@@ -1171,13 +1332,17 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
     True
     """
 
-    categoricalCols = Param(Params._dummy(), "categoricalCols",
-                            "numeric columns to treat as categorical",
-                            typeConverter=TypeConverters.toListString)
+    categoricalCols = Param(
+        Params._dummy(),
+        "categoricalCols",
+        "numeric columns to treat as categorical",
+        typeConverter=TypeConverters.toListString,
+    )
 
     @keyword_only
-    def __init__(self, *, numFeatures=1 << 18, inputCols=None, outputCol=None,
-                 categoricalCols=None):
+    def __init__(
+        self, *, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None
+    ):
         """
         __init__(self, \\*, numFeatures=1 << 18, inputCols=None, outputCol=None, \
                  categoricalCols=None)
@@ -1190,8 +1355,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
 
     @keyword_only
     @since("2.3.0")
-    def setParams(self, *, numFeatures=1 << 18, inputCols=None, outputCol=None,
-                  categoricalCols=None):
+    def setParams(
+        self, *, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None
+    ):
         """
         setParams(self, \\*, numFeatures=1 << 18, inputCols=None, outputCol=None, \
                   categoricalCols=None)
@@ -1234,8 +1400,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
 
 
 @inherit_doc
-class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable,
-                JavaMLWritable):
+class HashingTF(
+    JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable, JavaMLWritable
+):
     """
     Maps a sequence of terms to their term frequencies using the hashing trick.
     Currently we use Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32)
@@ -1270,10 +1437,14 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
     5
     """
 
-    binary = Param(Params._dummy(), "binary", "If True, all non zero counts are set to 1. " +
-                   "This is useful for discrete probabilistic models that model binary events " +
-                   "rather than integer counts. Default False.",
-                   typeConverter=TypeConverters.toBoolean)
+    binary = Param(
+        Params._dummy(),
+        "binary",
+        "If True, all non zero counts are set to 1. "
+        + "This is useful for discrete probabilistic models that model binary events "
+        + "rather than integer counts. Default False.",
+        typeConverter=TypeConverters.toBoolean,
+    )
 
     @keyword_only
     def __init__(self, *, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None):
@@ -1344,9 +1515,12 @@ class _IDFParams(HasInputCol, HasOutputCol):
     .. versionadded:: 3.0.0
     """
 
-    minDocFreq = Param(Params._dummy(), "minDocFreq",
-                       "minimum number of documents in which a term should appear for filtering",
-                       typeConverter=TypeConverters.toInt)
+    minDocFreq = Param(
+        Params._dummy(),
+        "minDocFreq",
+        "minimum number of documents in which a term should appear for filtering",
+        typeConverter=TypeConverters.toInt,
+    )
 
     @since("1.4.0")
     def getMinDocFreq(self):
@@ -1503,16 +1677,23 @@ class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, Has
     .. versionadded:: 3.0.0
     """
 
-    strategy = Param(Params._dummy(), "strategy",
-                     "strategy for imputation. If mean, then replace missing values using the mean "
-                     "value of the feature. If median, then replace missing values using the "
-                     "median value of the feature. If mode, then replace missing using the most "
-                     "frequent value of the feature.",
-                     typeConverter=TypeConverters.toString)
-
-    missingValue = Param(Params._dummy(), "missingValue",
-                         "The placeholder for the missing values. All occurrences of missingValue "
-                         "will be imputed.", typeConverter=TypeConverters.toFloat)
+    strategy = Param(
+        Params._dummy(),
+        "strategy",
+        "strategy for imputation. If mean, then replace missing values using the mean "
+        "value of the feature. If median, then replace missing values using the "
+        "median value of the feature. If mode, then replace missing using the most "
+        "frequent value of the feature.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    missingValue = Param(
+        Params._dummy(),
+        "missingValue",
+        "The placeholder for the missing values. All occurrences of missingValue "
+        "will be imputed.",
+        typeConverter=TypeConverters.toFloat,
+    )
 
     def __init__(self, *args):
         super(_ImputerParams, self).__init__(*args)
@@ -1649,8 +1830,17 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable):
     """
 
     @keyword_only
-    def __init__(self, *, strategy="mean", missingValue=float("nan"), inputCols=None,
-                 outputCols=None, inputCol=None, outputCol=None, relativeError=0.001):
+    def __init__(
+        self,
+        *,
+        strategy="mean",
+        missingValue=float("nan"),
+        inputCols=None,
+        outputCols=None,
+        inputCol=None,
+        outputCol=None,
+        relativeError=0.001,
+    ):
         """
         __init__(self, \\*, strategy="mean", missingValue=float("nan"), inputCols=None, \
                  outputCols=None, inputCol=None, outputCol=None, relativeError=0.001):
@@ -1662,8 +1852,17 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable):
 
     @keyword_only
     @since("2.2.0")
-    def setParams(self, *, strategy="mean", missingValue=float("nan"), inputCols=None,
-                  outputCols=None, inputCol=None, outputCol=None, relativeError=0.001):
+    def setParams(
+        self,
+        *,
+        strategy="mean",
+        missingValue=float("nan"),
+        inputCols=None,
+        outputCols=None,
+        inputCol=None,
+        outputCol=None,
+        relativeError=0.001,
+    ):
         """
         setParams(self, \\*, strategy="mean", missingValue=float("nan"), inputCols=None, \
                   outputCols=None, inputCol=None, outputCol=None, relativeError=0.001)
@@ -1849,6 +2048,7 @@ class _MaxAbsScalerParams(HasInputCol, HasOutputCol):
 
     .. versionadded:: 3.0.0
     """
+
     pass
 
 
@@ -2082,10 +2282,18 @@ class _MinMaxScalerParams(HasInputCol, HasOutputCol):
     .. versionadded:: 3.0.0
     """
 
-    min = Param(Params._dummy(), "min", "Lower bound of the output feature range",
-                typeConverter=TypeConverters.toFloat)
-    max = Param(Params._dummy(), "max", "Upper bound of the output feature range",
-                typeConverter=TypeConverters.toFloat)
+    min = Param(
+        Params._dummy(),
+        "min",
+        "Lower bound of the output feature range",
+        typeConverter=TypeConverters.toFloat,
+    )
+    max = Param(
+        Params._dummy(),
+        "max",
+        "Upper bound of the output feature range",
+        typeConverter=TypeConverters.toFloat,
+    )
 
     def __init__(self, *args):
         super(_MinMaxScalerParams, self).__init__(*args)
@@ -2311,8 +2519,12 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWr
     True
     """
 
-    n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)",
-              typeConverter=TypeConverters.toInt)
+    n = Param(
+        Params._dummy(),
+        "n",
+        "number of elements per n-gram (>=1)",
+        typeConverter=TypeConverters.toInt,
+    )
 
     @keyword_only
     def __init__(self, *, n=2, inputCol=None, outputCol=None):
@@ -2395,8 +2607,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
     True
     """
 
-    p = Param(Params._dummy(), "p", "the p norm value.",
-              typeConverter=TypeConverters.toFloat)
+    p = Param(Params._dummy(), "p", "the p norm value.", typeConverter=TypeConverters.toFloat)
 
     @keyword_only
     def __init__(self, *, p=2.0, inputCol=None, outputCol=None):
@@ -2446,23 +2657,32 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
         return self._set(outputCol=value)
 
 
-class _OneHotEncoderParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols,
-                           HasHandleInvalid):
+class _OneHotEncoderParams(
+    HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, HasHandleInvalid
+):
     """
     Params for :py:class:`OneHotEncoder` and :py:class:`OneHotEncoderModel`.
 
     .. versionadded:: 3.0.0
     """
 
-    handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data during " +
-                          "transform(). Options are 'keep' (invalid data presented as an extra " +
-                          "categorical feature) or error (throw an error). Note that this Param " +
-                          "is only used during transform; during fitting, invalid data will " +
-                          "result in an error.",
-                          typeConverter=TypeConverters.toString)
-
-    dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category",
-                     typeConverter=TypeConverters.toBoolean)
+    handleInvalid = Param(
+        Params._dummy(),
+        "handleInvalid",
+        "How to handle invalid data during "
+        + "transform(). Options are 'keep' (invalid data presented as an extra "
+        + "categorical feature) or error (throw an error). Note that this Param "
+        + "is only used during transform; during fitting, invalid data will "
+        + "result in an error.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    dropLast = Param(
+        Params._dummy(),
+        "dropLast",
+        "whether to drop the last category",
+        typeConverter=TypeConverters.toBoolean,
+    )
 
     def __init__(self, *args):
         super(_OneHotEncoderParams, self).__init__(*args)
@@ -2541,22 +2761,37 @@ class OneHotEncoder(JavaEstimator, _OneHotEncoderParams, JavaMLReadable, JavaMLW
     """
 
     @keyword_only
-    def __init__(self, *, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True,
-                 inputCol=None, outputCol=None):
+    def __init__(
+        self,
+        *,
+        inputCols=None,
+        outputCols=None,
+        handleInvalid="error",
+        dropLast=True,
+        inputCol=None,
+        outputCol=None,
+    ):
         """
         __init__(self, \\*, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, \
                  inputCol=None, outputCol=None)
         """
         super(OneHotEncoder, self).__init__()
-        self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.feature.OneHotEncoder", self.uid)
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid)
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("2.3.0")
-    def setParams(self, *, inputCols=None, outputCols=None, handleInvalid="error",
-                  dropLast=True, inputCol=None, outputCol=None):
+    def setParams(
+        self,
+        *,
+        inputCols=None,
+        outputCols=None,
+        handleInvalid="error",
+        dropLast=True,
+        inputCol=None,
+        outputCol=None,
+    ):
         """
         setParams(self, \\*, inputCols=None, outputCols=None, handleInvalid="error", \
                   dropLast=True, inputCol=None, outputCol=None)
@@ -2671,8 +2906,9 @@ class OneHotEncoderModel(JavaModel, _OneHotEncoderParams, JavaMLReadable, JavaML
 
 
 @inherit_doc
-class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
-                          JavaMLWritable):
+class PolynomialExpansion(
+    JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable
+):
     """
     Perform feature expansion in a polynomial space. As said in `wikipedia of Polynomial Expansion
     <http://en.wikipedia.org/wiki/Polynomial_expansion>`_, "In mathematics, an
@@ -2704,8 +2940,12 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead
     True
     """
 
-    degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)",
-                   typeConverter=TypeConverters.toInt)
+    degree = Param(
+        Params._dummy(),
+        "degree",
+        "the polynomial degree to expand (>= 1)",
+        typeConverter=TypeConverters.toInt,
+    )
 
     @keyword_only
     def __init__(self, *, degree=2, inputCol=None, outputCol=None):
@@ -2714,7 +2954,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead
         """
         super(PolynomialExpansion, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.feature.PolynomialExpansion", self.uid)
+            "org.apache.spark.ml.feature.PolynomialExpansion", self.uid
+        )
         self._setDefault(degree=2)
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
@@ -2757,8 +2998,17 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead
 
 
 @inherit_doc
-class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
-                          HasHandleInvalid, HasRelativeError, JavaMLReadable, JavaMLWritable):
+class QuantileDiscretizer(
+    JavaEstimator,
+    HasInputCol,
+    HasOutputCol,
+    HasInputCols,
+    HasOutputCols,
+    HasHandleInvalid,
+    HasRelativeError,
+    JavaMLReadable,
+    JavaMLWritable,
+):
     """
     :py:class:`QuantileDiscretizer` takes a column with continuous features and outputs a column
     with binned categorical features. The number of bins can be set using the :py:attr:`numBuckets`
@@ -2852,46 +3102,78 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasInputCols
     ...
     """
 
-    numBuckets = Param(Params._dummy(), "numBuckets",
-                       "Maximum number of buckets (quantiles, or " +
-                       "categories) into which data points are grouped. Must be >= 2.",
-                       typeConverter=TypeConverters.toInt)
-
-    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
-                          "Options are skip (filter out rows with invalid values), " +
-                          "error (throw an error), or keep (keep invalid values in a special " +
-                          "additional bucket). Note that in the multiple columns " +
-                          "case, the invalid handling is applied to all columns. That said " +
-                          "for 'error' it will throw an error if any invalids are found in " +
-                          "any columns, for 'skip' it will skip rows with any invalids in " +
-                          "any columns, etc.",
-                          typeConverter=TypeConverters.toString)
-
-    numBucketsArray = Param(Params._dummy(), "numBucketsArray", "Array of number of buckets " +
-                            "(quantiles, or categories) into which data points are grouped. " +
-                            "This is for multiple columns input. If transforming multiple " +
-                            "columns and numBucketsArray is not set, but numBuckets is set, " +
-                            "then numBuckets will be applied across all columns.",
-                            typeConverter=TypeConverters.toListInt)
+    numBuckets = Param(
+        Params._dummy(),
+        "numBuckets",
+        "Maximum number of buckets (quantiles, or "
+        + "categories) into which data points are grouped. Must be >= 2.",
+        typeConverter=TypeConverters.toInt,
+    )
+
+    handleInvalid = Param(
+        Params._dummy(),
+        "handleInvalid",
+        "how to handle invalid entries. "
+        + "Options are skip (filter out rows with invalid values), "
+        + "error (throw an error), or keep (keep invalid values in a special "
+        + "additional bucket). Note that in the multiple columns "
+        + "case, the invalid handling is applied to all columns. That said "
+        + "for 'error' it will throw an error if any invalids are found in "
+        + "any columns, for 'skip' it will skip rows with any invalids in "
+        + "any columns, etc.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    numBucketsArray = Param(
+        Params._dummy(),
+        "numBucketsArray",
+        "Array of number of buckets "
+        + "(quantiles, or categories) into which data points are grouped. "
+        + "This is for multiple columns input. If transforming multiple "
+        + "columns and numBucketsArray is not set, but numBuckets is set, "
+        + "then numBuckets will be applied across all columns.",
+        typeConverter=TypeConverters.toListInt,
+    )
 
     @keyword_only
-    def __init__(self, *, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
-                 handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None):
+    def __init__(
+        self,
+        *,
+        numBuckets=2,
+        inputCol=None,
+        outputCol=None,
+        relativeError=0.001,
+        handleInvalid="error",
+        numBucketsArray=None,
+        inputCols=None,
+        outputCols=None,
+    ):
         """
         __init__(self, \\*, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
                  handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None)
         """
         super(QuantileDiscretizer, self).__init__()
-        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer",
-                                            self.uid)
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.feature.QuantileDiscretizer", self.uid
+        )
         self._setDefault(numBuckets=2, relativeError=0.001, handleInvalid="error")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("2.0.0")
-    def setParams(self, *, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
-                  handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None):
+    def setParams(
+        self,
+        *,
+        numBuckets=2,
+        inputCol=None,
+        outputCol=None,
+        relativeError=0.001,
+        handleInvalid="error",
+        numBucketsArray=None,
+        inputCols=None,
+        outputCols=None,
+    ):
         """
         setParams(self, \\*, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
                   handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None)
@@ -2971,17 +3253,21 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasInputCols
         """
         Private method to convert the java_model to a Python model.
         """
-        if (self.isSet(self.inputCol)):
-            return Bucketizer(splits=list(java_model.getSplits()),
-                              inputCol=self.getInputCol(),
-                              outputCol=self.getOutputCol(),
-                              handleInvalid=self.getHandleInvalid())
+        if self.isSet(self.inputCol):
+            return Bucketizer(
+                splits=list(java_model.getSplits()),
+                inputCol=self.getInputCol(),
+                outputCol=self.getOutputCol(),
+                handleInvalid=self.getHandleInvalid(),
+            )
         else:
             splitsArrayList = [list(x) for x in list(java_model.getSplitsArray())]
-            return Bucketizer(splitsArray=splitsArrayList,
-                              inputCols=self.getInputCols(),
-                              outputCols=self.getOutputCols(),
-                              handleInvalid=self.getHandleInvalid())
+            return Bucketizer(
+                splitsArray=splitsArrayList,
+                inputCols=self.getInputCols(),
+                outputCols=self.getOutputCols(),
+                handleInvalid=self.getHandleInvalid(),
+            )
 
 
 class _RobustScalerParams(HasInputCol, HasOutputCol, HasRelativeError):
@@ -2991,19 +3277,36 @@ class _RobustScalerParams(HasInputCol, HasOutputCol, HasRelativeError):
     .. versionadded:: 3.0.0
     """
 
-    lower = Param(Params._dummy(), "lower", "Lower quantile to calculate quantile range",
-                  typeConverter=TypeConverters.toFloat)
-    upper = Param(Params._dummy(), "upper", "Upper quantile to calculate quantile range",
-                  typeConverter=TypeConverters.toFloat)
-    withCentering = Param(Params._dummy(), "withCentering", "Whether to center data with median",
-                          typeConverter=TypeConverters.toBoolean)
-    withScaling = Param(Params._dummy(), "withScaling", "Whether to scale the data to "
-                        "quantile range", typeConverter=TypeConverters.toBoolean)
+    lower = Param(
+        Params._dummy(),
+        "lower",
+        "Lower quantile to calculate quantile range",
+        typeConverter=TypeConverters.toFloat,
+    )
+    upper = Param(
+        Params._dummy(),
+        "upper",
+        "Upper quantile to calculate quantile range",
+        typeConverter=TypeConverters.toFloat,
+    )
+    withCentering = Param(
+        Params._dummy(),
+        "withCentering",
+        "Whether to center data with median",
+        typeConverter=TypeConverters.toBoolean,
+    )
+    withScaling = Param(
+        Params._dummy(),
+        "withScaling",
+        "Whether to scale the data to " "quantile range",
+        typeConverter=TypeConverters.toBoolean,
+    )
 
     def __init__(self, *args):
         super(_RobustScalerParams, self).__init__(*args)
-        self._setDefault(lower=0.25, upper=0.75, withCentering=False, withScaling=True,
-                         relativeError=0.001)
+        self._setDefault(
+            lower=0.25, upper=0.75, withCentering=False, withScaling=True, relativeError=0.001
+        )
 
     @since("3.0.0")
     def getLower(self):
@@ -3089,8 +3392,17 @@ class RobustScaler(JavaEstimator, _RobustScalerParams, JavaMLReadable, JavaMLWri
     """
 
     @keyword_only
-    def __init__(self, *, lower=0.25, upper=0.75, withCentering=False, withScaling=True,
-                 inputCol=None, outputCol=None, relativeError=0.001):
+    def __init__(
+        self,
+        *,
+        lower=0.25,
+        upper=0.75,
+        withCentering=False,
+        withScaling=True,
+        inputCol=None,
+        outputCol=None,
+        relativeError=0.001,
+    ):
         """
         __init__(self, \\*, lower=0.25, upper=0.75, withCentering=False, withScaling=True, \
                  inputCol=None, outputCol=None, relativeError=0.001)
@@ -3102,8 +3414,17 @@ class RobustScaler(JavaEstimator, _RobustScalerParams, JavaMLReadable, JavaMLWri
 
     @keyword_only
     @since("3.0.0")
-    def setParams(self, *, lower=0.25, upper=0.75, withCentering=False, withScaling=True,
-                  inputCol=None, outputCol=None, relativeError=0.001):
+    def setParams(
+        self,
+        *,
+        lower=0.25,
+        upper=0.75,
+        withCentering=False,
+        withScaling=True,
+        inputCol=None,
+        outputCol=None,
+        relativeError=0.001,
+    ):
         """
         setParams(self, \\*, lower=0.25, upper=0.75, withCentering=False, withScaling=True, \
                   inputCol=None, outputCol=None, relativeError=0.001)
@@ -3249,18 +3570,41 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
     True
     """
 
-    minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)",
-                           typeConverter=TypeConverters.toInt)
-    gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens " +
-                 "(False)")
-    pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing",
-                    typeConverter=TypeConverters.toString)
-    toLowercase = Param(Params._dummy(), "toLowercase", "whether to convert all characters to " +
-                        "lowercase before tokenizing", typeConverter=TypeConverters.toBoolean)
+    minTokenLength = Param(
+        Params._dummy(),
+        "minTokenLength",
+        "minimum token length (>= 0)",
+        typeConverter=TypeConverters.toInt,
+    )
+    gaps = Param(
+        Params._dummy(),
+        "gaps",
+        "whether regex splits on gaps (True) or matches tokens " + "(False)",
+    )
+    pattern = Param(
+        Params._dummy(),
+        "pattern",
+        "regex pattern (Java dialect) used for tokenizing",
+        typeConverter=TypeConverters.toString,
+    )
+    toLowercase = Param(
+        Params._dummy(),
+        "toLowercase",
+        "whether to convert all characters to " + "lowercase before tokenizing",
+        typeConverter=TypeConverters.toBoolean,
+    )
 
     @keyword_only
-    def __init__(self, *, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None,
-                 outputCol=None, toLowercase=True):
+    def __init__(
+        self,
+        *,
+        minTokenLength=1,
+        gaps=True,
+        pattern="\\s+",
+        inputCol=None,
+        outputCol=None,
+        toLowercase=True,
+    ):
         """
         __init__(self, \\*, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, \
                  outputCol=None, toLowercase=True)
@@ -3273,8 +3617,16 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, *, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None,
-                  outputCol=None, toLowercase=True):
+    def setParams(
+        self,
+        *,
+        minTokenLength=1,
+        gaps=True,
+        pattern="\\s+",
+        inputCol=None,
+        outputCol=None,
+        toLowercase=True,
+    ):
         """
         setParams(self, \\*, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, \
                   outputCol=None, toLowercase=True)
@@ -3377,8 +3729,9 @@ class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable):
     True
     """
 
-    statement = Param(Params._dummy(), "statement", "SQL statement",
-                      typeConverter=TypeConverters.toString)
+    statement = Param(
+        Params._dummy(), "statement", "SQL statement", typeConverter=TypeConverters.toString
+    )
 
     @keyword_only
     def __init__(self, *, statement=None):
@@ -3422,10 +3775,15 @@ class _StandardScalerParams(HasInputCol, HasOutputCol):
     .. versionadded:: 3.0.0
     """
 
-    withMean = Param(Params._dummy(), "withMean", "Center data with mean",
-                     typeConverter=TypeConverters.toBoolean)
-    withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation",
-                    typeConverter=TypeConverters.toBoolean)
+    withMean = Param(
+        Params._dummy(), "withMean", "Center data with mean", typeConverter=TypeConverters.toBoolean
+    )
+    withStd = Param(
+        Params._dummy(),
+        "withStd",
+        "Scale to unit standard deviation",
+        typeConverter=TypeConverters.toBoolean,
+    )
 
     def __init__(self, *args):
         super(_StandardScalerParams, self).__init__(*args)
@@ -3582,27 +3940,35 @@ class StandardScalerModel(JavaModel, _StandardScalerParams, JavaMLReadable, Java
         return self._call_java("mean")
 
 
-class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol,
-                           HasInputCols, HasOutputCols):
+class _StringIndexerParams(
+    JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols
+):
     """
     Params for :py:class:`StringIndexer` and :py:class:`StringIndexerModel`.
     """
 
-    stringOrderType = Param(Params._dummy(), "stringOrderType",
-                            "How to order labels of string column. The first label after " +
-                            "ordering is assigned an index of 0. Supported options: " +
-                            "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. " +
-                            "Default is frequencyDesc. In case of equal frequency when " +
-                            "under frequencyDesc/Asc, the strings are further sorted " +
-                            "alphabetically",
-                            typeConverter=TypeConverters.toString)
-
-    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
-                          "or NULL values) in features and label column of string type. " +
-                          "Options are 'skip' (filter out rows with invalid data), " +
-                          "error (throw an error), or 'keep' (put invalid data " +
-                          "in a special additional bucket, at index numLabels).",
-                          typeConverter=TypeConverters.toString)
+    stringOrderType = Param(
+        Params._dummy(),
+        "stringOrderType",
+        "How to order labels of string column. The first label after "
+        + "ordering is assigned an index of 0. Supported options: "
+        + "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. "
+        + "Default is frequencyDesc. In case of equal frequency when "
+        + "under frequencyDesc/Asc, the strings are further sorted "
+        + "alphabetically",
+        typeConverter=TypeConverters.toString,
+    )
+
+    handleInvalid = Param(
+        Params._dummy(),
+        "handleInvalid",
+        "how to handle invalid data (unseen "
+        + "or NULL values) in features and label column of string type. "
+        + "Options are 'skip' (filter out rows with invalid data), "
+        + "error (throw an error), or 'keep' (put invalid data "
+        + "in a special additional bucket, at index numLabels).",
+        typeConverter=TypeConverters.toString,
+    )
 
     def __init__(self, *args):
         super(_StringIndexerParams, self).__init__(*args)
@@ -3701,8 +4067,16 @@ class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLW
     """
 
     @keyword_only
-    def __init__(self, *, inputCol=None, outputCol=None, inputCols=None, outputCols=None,
-                 handleInvalid="error", stringOrderType="frequencyDesc"):
+    def __init__(
+        self,
+        *,
+        inputCol=None,
+        outputCol=None,
+        inputCols=None,
+        outputCols=None,
+        handleInvalid="error",
+        stringOrderType="frequencyDesc",
+    ):
         """
         __init__(self, \\*, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \
                  handleInvalid="error", stringOrderType="frequencyDesc")
@@ -3714,8 +4088,16 @@ class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLW
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, *, inputCol=None, outputCol=None, inputCols=None, outputCols=None,
-                  handleInvalid="error", stringOrderType="frequencyDesc"):
+    def setParams(
+        self,
+        *,
+        inputCol=None,
+        outputCol=None,
+        inputCols=None,
+        outputCols=None,
+        handleInvalid="error",
+        stringOrderType="frequencyDesc",
+    ):
         """
         setParams(self, \\*, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \
                   handleInvalid="error", stringOrderType="frequencyDesc")
@@ -3818,7 +4200,8 @@ class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaML
         java_class = sc._gateway.jvm.java.lang.String
         jlabels = StringIndexerModel._new_java_array(labels, java_class)
         model = StringIndexerModel._create_from_java_class(
-            "org.apache.spark.ml.feature.StringIndexerModel", jlabels)
+            "org.apache.spark.ml.feature.StringIndexerModel", jlabels
+        )
         model.setInputCol(inputCol)
         if outputCol is not None:
             model.setOutputCol(outputCol)
@@ -3828,8 +4211,7 @@ class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaML
 
     @classmethod
     @since("3.0.0")
-    def from_arrays_of_labels(cls, arrayOfLabels, inputCols, outputCols=None,
-                              handleInvalid=None):
+    def from_arrays_of_labels(cls, arrayOfLabels, inputCols, outputCols=None, handleInvalid=None):
         """
         Construct the model directly from an array of array of label strings,
         requires an active SparkContext.
@@ -3838,7 +4220,8 @@ class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaML
         java_class = sc._gateway.jvm.java.lang.String
         jlabels = StringIndexerModel._new_java_array(arrayOfLabels, java_class)
         model = StringIndexerModel._create_from_java_class(
-            "org.apache.spark.ml.feature.StringIndexerModel", jlabels)
+            "org.apache.spark.ml.feature.StringIndexerModel", jlabels
+        )
         model.setInputCols(inputCols)
         if outputCols is not None:
             model.setOutputCols(outputCols)
@@ -3882,10 +4265,13 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
     StringIndexer : for converting categorical values into category indices
     """
 
-    labels = Param(Params._dummy(), "labels",
-                   "Optional array of labels specifying index-string mapping." +
-                   " If not provided or if empty, then metadata from inputCol is used instead.",
-                   typeConverter=TypeConverters.toListString)
+    labels = Param(
+        Params._dummy(),
+        "labels",
+        "Optional array of labels specifying index-string mapping."
+        + " If not provided or if empty, then metadata from inputCol is used instead.",
+        typeConverter=TypeConverters.toListString,
+    )
 
     @keyword_only
     def __init__(self, *, inputCol=None, outputCol=None, labels=None):
@@ -3893,8 +4279,7 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
         __init__(self, \\*, inputCol=None, outputCol=None, labels=None)
         """
         super(IndexToString, self).__init__()
-        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString",
-                                            self.uid)
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", self.uid)
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
@@ -3935,8 +4320,15 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
         return self._set(outputCol=value)
 
 
-class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
-                       JavaMLReadable, JavaMLWritable):
+class StopWordsRemover(
+    JavaTransformer,
+    HasInputCol,
+    HasOutputCol,
+    HasInputCols,
+    HasOutputCols,
+    JavaMLReadable,
+    JavaMLWritable,
+):
     """
     A feature transformer that filters out stop words from input.
     Since 3.0.0, :py:class:`StopWordsRemover` can filter out multiple columns at once by setting
@@ -3981,32 +4373,66 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols,
     ...
     """
 
-    stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out",
-                      typeConverter=TypeConverters.toListString)
-    caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
-                          "comparison over the stop words", typeConverter=TypeConverters.toBoolean)
-    locale = Param(Params._dummy(), "locale", "locale of the input. ignored when case sensitive " +
-                   "is true", typeConverter=TypeConverters.toString)
+    stopWords = Param(
+        Params._dummy(),
+        "stopWords",
+        "The words to be filtered out",
+        typeConverter=TypeConverters.toListString,
+    )
+    caseSensitive = Param(
+        Params._dummy(),
+        "caseSensitive",
+        "whether to do a case sensitive " + "comparison over the stop words",
+        typeConverter=TypeConverters.toBoolean,
+    )
+    locale = Param(
+        Params._dummy(),
+        "locale",
+        "locale of the input. ignored when case sensitive " + "is true",
+        typeConverter=TypeConverters.toString,
+    )
 
     @keyword_only
-    def __init__(self, *, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
-                 locale=None, inputCols=None, outputCols=None):
+    def __init__(
+        self,
+        *,
+        inputCol=None,
+        outputCol=None,
+        stopWords=None,
+        caseSensitive=False,
+        locale=None,
+        inputCols=None,
+        outputCols=None,
+    ):
         """
         __init__(self, \\*, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
                  locale=None, inputCols=None, outputCols=None)
         """
         super(StopWordsRemover, self).__init__()
-        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
-                                            self.uid)
-        self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"),
-                         caseSensitive=False, locale=self._java_obj.getLocale())
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.feature.StopWordsRemover", self.uid
+        )
+        self._setDefault(
+            stopWords=StopWordsRemover.loadDefaultStopWords("english"),
+            caseSensitive=False,
+            locale=self._java_obj.getLocale(),
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.6.0")
-    def setParams(self, *, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
-                  locale=None, inputCols=None, outputCols=None):
+    def setParams(
+        self,
+        *,
+        inputCol=None,
+        outputCol=None,
+        stopWords=None,
+        caseSensitive=False,
+        locale=None,
+        inputCols=None,
+        outputCols=None,
+    ):
         """
         setParams(self, \\*, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
                   locale=None, inputCols=None, outputCols=None)
@@ -4165,8 +4591,9 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java
 
 
 @inherit_doc
-class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable,
-                      JavaMLWritable):
+class VectorAssembler(
+    JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable, JavaMLWritable
+):
     """
     A feature transformer that merges multiple columns into a vector column.
 
@@ -4212,15 +4639,19 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInva
     ...
     """
 
-    handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data (NULL " +
-                          "and NaN values). Options are 'skip' (filter out rows with invalid " +
-                          "data), 'error' (throw an error), or 'keep' (return relevant number " +
-                          "of NaN in the output). Column lengths are taken from the size of ML " +
-                          "Attribute Group, which can be set using `VectorSizeHint` in a " +
-                          "pipeline before `VectorAssembler`. Column lengths can also be " +
-                          "inferred from first rows of the data since it is safe to do so but " +
-                          "only in case of 'error' or 'skip').",
-                          typeConverter=TypeConverters.toString)
+    handleInvalid = Param(
+        Params._dummy(),
+        "handleInvalid",
+        "How to handle invalid data (NULL "
+        + "and NaN values). Options are 'skip' (filter out rows with invalid "
+        + "data), 'error' (throw an error), or 'keep' (return relevant number "
+        + "of NaN in the output). Column lengths are taken from the size of ML "
+        + "Attribute Group, which can be set using `VectorSizeHint` in a "
+        + "pipeline before `VectorAssembler`. Column lengths can also be "
+        + "inferred from first rows of the data since it is safe to do so but "
+        + "only in case of 'error' or 'skip').",
+        typeConverter=TypeConverters.toString,
+    )
 
     @keyword_only
     def __init__(self, *, inputCols=None, outputCol=None, handleInvalid="error"):
@@ -4269,17 +4700,25 @@ class _VectorIndexerParams(HasInputCol, HasOutputCol, HasHandleInvalid):
     .. versionadded:: 3.0.0
     """
 
-    maxCategories = Param(Params._dummy(), "maxCategories",
-                          "Threshold for the number of values a categorical feature can take " +
-                          "(>= 2). If a feature is found to have > maxCategories values, then " +
-                          "it is declared continuous.", typeConverter=TypeConverters.toInt)
-
-    handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data " +
-                          "(unseen labels or NULL values). Options are 'skip' (filter out " +
-                          "rows with invalid data), 'error' (throw an error), or 'keep' (put " +
-                          "invalid data in a special additional bucket, at index of the number " +
-                          "of categories of the feature).",
-                          typeConverter=TypeConverters.toString)
+    maxCategories = Param(
+        Params._dummy(),
+        "maxCategories",
+        "Threshold for the number of values a categorical feature can take "
+        + "(>= 2). If a feature is found to have > maxCategories values, then "
+        + "it is declared continuous.",
+        typeConverter=TypeConverters.toInt,
+    )
+
+    handleInvalid = Param(
+        Params._dummy(),
+        "handleInvalid",
+        "How to handle invalid data "
+        + "(unseen labels or NULL values). Options are 'skip' (filter out "
+        + "rows with invalid data), 'error' (throw an error), or 'keep' (put "
+        + "invalid data in a special additional bucket, at index of the number "
+        + "of categories of the feature).",
+        typeConverter=TypeConverters.toString,
+    )
 
     def __init__(self, *args):
         super(_VectorIndexerParams, self).__init__(*args)
@@ -4519,13 +4958,22 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, J
     True
     """
 
-    indices = Param(Params._dummy(), "indices", "An array of indices to select features from " +
-                    "a vector column. There can be no overlap with names.",
-                    typeConverter=TypeConverters.toListInt)
-    names = Param(Params._dummy(), "names", "An array of feature names to select features from " +
-                  "a vector column. These names must be specified by ML " +
-                  "org.apache.spark.ml.attribute.Attribute. There can be no overlap with " +
-                  "indices.", typeConverter=TypeConverters.toListString)
+    indices = Param(
+        Params._dummy(),
+        "indices",
+        "An array of indices to select features from "
+        + "a vector column. There can be no overlap with names.",
+        typeConverter=TypeConverters.toListInt,
+    )
+    names = Param(
+        Params._dummy(),
+        "names",
+        "An array of feature names to select features from "
+        + "a vector column. These names must be specified by ML "
+        + "org.apache.spark.ml.attribute.Attribute. There can be no overlap with "
+        + "indices.",
+        typeConverter=TypeConverters.toListString,
+    )
 
     @keyword_only
     def __init__(self, *, inputCol=None, outputCol=None, indices=None, names=None):
@@ -4596,28 +5044,51 @@ class _Word2VecParams(HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCo
     .. versionadded:: 3.0.0
     """
 
-    vectorSize = Param(Params._dummy(), "vectorSize",
-                       "the dimension of codes after transforming from words",
-                       typeConverter=TypeConverters.toInt)
-    numPartitions = Param(Params._dummy(), "numPartitions",
-                          "number of partitions for sentences of words",
-                          typeConverter=TypeConverters.toInt)
-    minCount = Param(Params._dummy(), "minCount",
-                     "the minimum number of times a token must appear to be included in the " +
-                     "word2vec model's vocabulary", typeConverter=TypeConverters.toInt)
-    windowSize = Param(Params._dummy(), "windowSize",
-                       "the window size (context words from [-window, window]). Default value is 5",
-                       typeConverter=TypeConverters.toInt)
-    maxSentenceLength = Param(Params._dummy(), "maxSentenceLength",
-                              "Maximum length (in words) of each sentence in the input data. " +
-                              "Any sentence longer than this threshold will " +
-                              "be divided into chunks up to the size.",
-                              typeConverter=TypeConverters.toInt)
+    vectorSize = Param(
+        Params._dummy(),
+        "vectorSize",
+        "the dimension of codes after transforming from words",
+        typeConverter=TypeConverters.toInt,
+    )
+    numPartitions = Param(
+        Params._dummy(),
+        "numPartitions",
+        "number of partitions for sentences of words",
+        typeConverter=TypeConverters.toInt,
+    )
+    minCount = Param(
+        Params._dummy(),
+        "minCount",
+        "the minimum number of times a token must appear to be included in the "
+        + "word2vec model's vocabulary",
+        typeConverter=TypeConverters.toInt,
+    )
+    windowSize = Param(
+        Params._dummy(),
+        "windowSize",
+        "the window size (context words from [-window, window]). Default value is 5",
+        typeConverter=TypeConverters.toInt,
+    )
+    maxSentenceLength = Param(
+        Params._dummy(),
+        "maxSentenceLength",
+        "Maximum length (in words) of each sentence in the input data. "
+        + "Any sentence longer than this threshold will "
+        + "be divided into chunks up to the size.",
+        typeConverter=TypeConverters.toInt,
+    )
 
     def __init__(self, *args):
         super(_Word2VecParams, self).__init__(*args)
-        self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
-                         windowSize=5, maxSentenceLength=1000)
+        self._setDefault(
+            vectorSize=100,
+            minCount=5,
+            numPartitions=1,
+            stepSize=0.025,
+            maxIter=1,
+            windowSize=5,
+            maxSentenceLength=1000,
+        )
 
     @since("1.4.0")
     def getVectorSize(self):
@@ -4721,9 +5192,20 @@ class Word2Vec(JavaEstimator, _Word2VecParams, JavaMLReadable, JavaMLWritable):
     """
 
     @keyword_only
-    def __init__(self, *, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025,
-                 maxIter=1, seed=None, inputCol=None, outputCol=None, windowSize=5,
-                 maxSentenceLength=1000):
+    def __init__(
+        self,
+        *,
+        vectorSize=100,
+        minCount=5,
+        numPartitions=1,
+        stepSize=0.025,
+        maxIter=1,
+        seed=None,
+        inputCol=None,
+        outputCol=None,
+        windowSize=5,
+        maxSentenceLength=1000,
+    ):
         """
         __init__(self, \\*, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, \
                  maxIter=1, seed=None, inputCol=None, outputCol=None, windowSize=5, \
@@ -4736,9 +5218,20 @@ class Word2Vec(JavaEstimator, _Word2VecParams, JavaMLReadable, JavaMLWritable):
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, *, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025,
-                  maxIter=1, seed=None, inputCol=None, outputCol=None, windowSize=5,
-                  maxSentenceLength=1000):
+    def setParams(
+        self,
+        *,
+        vectorSize=100,
+        minCount=5,
+        numPartitions=1,
+        stepSize=0.025,
+        maxIter=1,
+        seed=None,
+        inputCol=None,
+        outputCol=None,
+        windowSize=5,
+        maxSentenceLength=1000,
+    ):
         """
         setParams(self, \\*, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, \
                   seed=None, inputCol=None, outputCol=None, windowSize=5, \
@@ -4878,8 +5371,12 @@ class _PCAParams(HasInputCol, HasOutputCol):
     .. versionadded:: 3.0.0
     """
 
-    k = Param(Params._dummy(), "k", "the number of principal components",
-              typeConverter=TypeConverters.toInt)
+    k = Param(
+        Params._dummy(),
+        "k",
+        "the number of principal components",
+        typeConverter=TypeConverters.toInt,
+    )
 
     @since("1.5.0")
     def getK(self):
@@ -5022,32 +5519,44 @@ class _RFormulaParams(HasFeaturesCol, HasLabelCol, HasHandleInvalid):
     .. versionadded:: 3.0.0
     """
 
-    formula = Param(Params._dummy(), "formula", "R model formula",
-                    typeConverter=TypeConverters.toString)
-
-    forceIndexLabel = Param(Params._dummy(), "forceIndexLabel",
-                            "Force to index label whether it is numeric or string",
-                            typeConverter=TypeConverters.toBoolean)
-
-    stringIndexerOrderType = Param(Params._dummy(), "stringIndexerOrderType",
-                                   "How to order categories of a string feature column used by " +
-                                   "StringIndexer. The last category after ordering is dropped " +
-                                   "when encoding strings. Supported options: frequencyDesc, " +
-                                   "frequencyAsc, alphabetDesc, alphabetAsc. The default value " +
-                                   "is frequencyDesc. When the ordering is set to alphabetDesc, " +
-                                   "RFormula drops the same category as R when encoding strings.",
-                                   typeConverter=TypeConverters.toString)
-
-    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
-                          "Options are 'skip' (filter out rows with invalid values), " +
-                          "'error' (throw an error), or 'keep' (put invalid data in a special " +
-                          "additional bucket, at index numLabels).",
-                          typeConverter=TypeConverters.toString)
+    formula = Param(
+        Params._dummy(), "formula", "R model formula", typeConverter=TypeConverters.toString
+    )
+
+    forceIndexLabel = Param(
+        Params._dummy(),
+        "forceIndexLabel",
+        "Force to index label whether it is numeric or string",
+        typeConverter=TypeConverters.toBoolean,
+    )
+
+    stringIndexerOrderType = Param(
+        Params._dummy(),
+        "stringIndexerOrderType",
+        "How to order categories of a string feature column used by "
+        + "StringIndexer. The last category after ordering is dropped "
+        + "when encoding strings. Supported options: frequencyDesc, "
+        + "frequencyAsc, alphabetDesc, alphabetAsc. The default value "
+        + "is frequencyDesc. When the ordering is set to alphabetDesc, "
+        + "RFormula drops the same category as R when encoding strings.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    handleInvalid = Param(
+        Params._dummy(),
+        "handleInvalid",
+        "how to handle invalid entries. "
+        + "Options are 'skip' (filter out rows with invalid values), "
+        + "'error' (throw an error), or 'keep' (put invalid data in a special "
+        + "additional bucket, at index numLabels).",
+        typeConverter=TypeConverters.toString,
+    )
 
     def __init__(self, *args):
         super(_RFormulaParams, self).__init__(*args)
-        self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
-                         handleInvalid="error")
+        self._setDefault(
+            forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", handleInvalid="error"
+        )
 
     @since("1.5.0")
     def getFormula(self):
@@ -5146,9 +5655,16 @@ class RFormula(JavaEstimator, _RFormulaParams, JavaMLReadable, JavaMLWritable):
     """
 
     @keyword_only
-    def __init__(self, *, formula=None, featuresCol="features", labelCol="label",
-                 forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
-                 handleInvalid="error"):
+    def __init__(
+        self,
+        *,
+        formula=None,
+        featuresCol="features",
+        labelCol="label",
+        forceIndexLabel=False,
+        stringIndexerOrderType="frequencyDesc",
+        handleInvalid="error",
+    ):
         """
         __init__(self, \\*, formula=None, featuresCol="features", labelCol="label", \
                  forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \
@@ -5161,9 +5677,16 @@ class RFormula(JavaEstimator, _RFormulaParams, JavaMLReadable, JavaMLWritable):
 
     @keyword_only
     @since("1.5.0")
-    def setParams(self, *, formula=None, featuresCol="features", labelCol="label",
-                  forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
-                  handleInvalid="error"):
+    def setParams(
+        self,
+        *,
+        formula=None,
+        featuresCol="features",
+        labelCol="label",
+        forceIndexLabel=False,
+        stringIndexerOrderType="frequencyDesc",
+        handleInvalid="error",
+    ):
         """
         setParams(self, \\*, formula=None, featuresCol="features", labelCol="label", \
                   forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \
@@ -5240,34 +5763,61 @@ class _SelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol):
     .. versionadded:: 3.1.0
     """
 
-    selectorType = Param(Params._dummy(), "selectorType",
-                         "The selector type. " +
-                         "Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.",
-                         typeConverter=TypeConverters.toString)
-
-    numTopFeatures = \
-        Param(Params._dummy(), "numTopFeatures",
-              "Number of features that selector will select, ordered by ascending p-value. " +
-              "If the number of features is < numTopFeatures, then this will select " +
-              "all features.", typeConverter=TypeConverters.toInt)
-
-    percentile = Param(Params._dummy(), "percentile", "Percentile of features that selector " +
-                       "will select, ordered by ascending p-value.",
-                       typeConverter=TypeConverters.toFloat)
-
-    fpr = Param(Params._dummy(), "fpr", "The highest p-value for features to be kept.",
-                typeConverter=TypeConverters.toFloat)
-
-    fdr = Param(Params._dummy(), "fdr", "The upper bound of the expected false discovery rate.",
-                typeConverter=TypeConverters.toFloat)
-
-    fwe = Param(Params._dummy(), "fwe", "The upper bound of the expected family-wise error rate.",
-                typeConverter=TypeConverters.toFloat)
+    selectorType = Param(
+        Params._dummy(),
+        "selectorType",
+        "The selector type. "
+        + "Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    numTopFeatures = Param(
+        Params._dummy(),
+        "numTopFeatures",
+        "Number of features that selector will select, ordered by ascending p-value. "
+        + "If the number of features is < numTopFeatures, then this will select "
+        + "all features.",
+        typeConverter=TypeConverters.toInt,
+    )
+
+    percentile = Param(
+        Params._dummy(),
+        "percentile",
+        "Percentile of features that selector " + "will select, ordered by ascending p-value.",
+        typeConverter=TypeConverters.toFloat,
+    )
+
+    fpr = Param(
+        Params._dummy(),
+        "fpr",
+        "The highest p-value for features to be kept.",
+        typeConverter=TypeConverters.toFloat,
+    )
+
+    fdr = Param(
+        Params._dummy(),
+        "fdr",
+        "The upper bound of the expected false discovery rate.",
+        typeConverter=TypeConverters.toFloat,
+    )
+
+    fwe = Param(
+        Params._dummy(),
+        "fwe",
+        "The upper bound of the expected family-wise error rate.",
+        typeConverter=TypeConverters.toFloat,
+    )
 
     def __init__(self, *args):
         super(_SelectorParams, self).__init__(*args)
-        self._setDefault(numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1,
-                         fpr=0.05, fdr=0.05, fwe=0.05)
+        self._setDefault(
+            numTopFeatures=50,
+            selectorType="numTopFeatures",
+            percentile=0.1,
+            fpr=0.05,
+            fdr=0.05,
+            fwe=0.05,
+        )
 
     @since("2.1.0")
     def getSelectorType(self):
@@ -5475,9 +6025,19 @@ class ChiSqSelector(_Selector, JavaMLReadable, JavaMLWritable):
     """
 
     @keyword_only
-    def __init__(self, *, numTopFeatures=50, featuresCol="features", outputCol=None,
-                 labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05,
-                 fdr=0.05, fwe=0.05):
+    def __init__(
+        self,
+        *,
+        numTopFeatures=50,
+        featuresCol="features",
+        outputCol=None,
+        labelCol="label",
+        selectorType="numTopFeatures",
+        percentile=0.1,
+        fpr=0.05,
+        fdr=0.05,
+        fwe=0.05,
+    ):
         """
         __init__(self, \\*, numTopFeatures=50, featuresCol="features", outputCol=None, \
                  labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, \
@@ -5490,9 +6050,19 @@ class ChiSqSelector(_Selector, JavaMLReadable, JavaMLWritable):
 
     @keyword_only
     @since("2.0.0")
-    def setParams(self, *, numTopFeatures=50, featuresCol="features", outputCol=None,
-                  labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05,
-                  fdr=0.05, fwe=0.05):
+    def setParams(
+        self,
+        *,
+        numTopFeatures=50,
+        featuresCol="features",
+        outputCol=None,
+        labelCol="labels",
+        selectorType="numTopFeatures",
+        percentile=0.1,
+        fpr=0.05,
+        fdr=0.05,
+        fwe=0.05,
+    ):
         """
         setParams(self, \\*, numTopFeatures=50, featuresCol="features", outputCol=None, \
                   labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, \
@@ -5515,8 +6085,9 @@ class ChiSqSelectorModel(_SelectorModel, JavaMLReadable, JavaMLWritable):
 
 
 @inherit_doc
-class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReadable,
-                     JavaMLWritable):
+class VectorSizeHint(
+    JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReadable, JavaMLWritable
+):
     """
     A feature transformer that adds size information to the metadata of a vector column.
     VectorAssembler needs size information for its input columns and cannot be used on streaming
@@ -5551,16 +6122,20 @@ class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReada
     True
     """
 
-    size = Param(Params._dummy(), "size", "Size of vectors in column.",
-                 typeConverter=TypeConverters.toInt)
+    size = Param(
+        Params._dummy(), "size", "Size of vectors in column.", typeConverter=TypeConverters.toInt
+    )
 
-    handleInvalid = Param(Params._dummy(), "handleInvalid",
-                          "How to handle invalid vectors in inputCol. Invalid vectors include "
-                          "nulls and vectors with the wrong size. The options are `skip` (filter "
-                          "out rows with invalid vectors), `error` (throw an error) and "
-                          "`optimistic` (do not check the vector size, and keep all rows). "
-                          "`error` by default.",
-                          TypeConverters.toString)
+    handleInvalid = Param(
+        Params._dummy(),
+        "handleInvalid",
+        "How to handle invalid vectors in inputCol. Invalid vectors include "
+        "nulls and vectors with the wrong size. The options are `skip` (filter "
+        "out rows with invalid vectors), `error` (throw an error) and "
+        "`optimistic` (do not check the vector size, and keep all rows). "
+        "`error` by default.",
+        TypeConverters.toString,
+    )
 
     @keyword_only
     def __init__(self, *, inputCol=None, size=None, handleInvalid="error"):
@@ -5584,12 +6159,12 @@ class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReada
 
     @since("2.3.0")
     def getSize(self):
-        """ Gets size param, the size of vectors in `inputCol`."""
+        """Gets size param, the size of vectors in `inputCol`."""
         return self.getOrDefault(self.size)
 
     @since("2.3.0")
     def setSize(self, value):
-        """ Sets size param, the size of vectors in `inputCol`."""
+        """Sets size param, the size of vectors in `inputCol`."""
         return self._set(size=value)
 
     def setInputCol(self, value):
@@ -5613,10 +6188,14 @@ class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol):
     .. versionadded:: 3.1.0
     """
 
-    varianceThreshold = Param(Params._dummy(), "varianceThreshold",
-                              "Param for variance threshold. Features with a variance not " +
-                              "greater than this threshold will be removed. The default value " +
-                              "is 0.0.", typeConverter=TypeConverters.toFloat)
+    varianceThreshold = Param(
+        Params._dummy(),
+        "varianceThreshold",
+        "Param for variance threshold. Features with a variance not "
+        + "greater than this threshold will be removed. The default value "
+        + "is 0.0.",
+        typeConverter=TypeConverters.toFloat,
+    )
 
     @since("3.1.0")
     def getVarianceThreshold(self):
@@ -5627,8 +6206,9 @@ class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol):
 
 
 @inherit_doc
-class VarianceThresholdSelector(JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable,
-                                JavaMLWritable):
+class VarianceThresholdSelector(
+    JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
+):
     """
     Feature selector that removes all low-variance features. Features with a
     variance not greater than the threshold will be removed. The default is to keep
@@ -5679,7 +6259,8 @@ class VarianceThresholdSelector(JavaEstimator, _VarianceThresholdSelectorParams,
         """
         super(VarianceThresholdSelector, self).__init__()
         self._java_obj = self._new_java_obj(
-            "org.apache.spark.ml.feature.VarianceThresholdSelector", self.uid)
+            "org.apache.spark.ml.feature.VarianceThresholdSelector", self.uid
+        )
         self._setDefault(varianceThreshold=0.0)
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
@@ -5719,8 +6300,9 @@ class VarianceThresholdSelector(JavaEstimator, _VarianceThresholdSelectorParams,
         return VarianceThresholdSelectorModel(java_model)
 
 
-class VarianceThresholdSelectorModel(JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable,
-                                     JavaMLWritable):
+class VarianceThresholdSelectorModel(
+    JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
+):
     """
     Model fitted by :py:class:`VarianceThresholdSelector`.
 
@@ -5758,25 +6340,35 @@ class _UnivariateFeatureSelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol
     .. versionadded:: 3.1.0
     """
 
-    featureType = Param(Params._dummy(), "featureType",
-                        "The feature type. " +
-                        "Supported options: categorical, continuous.",
-                        typeConverter=TypeConverters.toString)
-
-    labelType = Param(Params._dummy(), "labelType",
-                      "The label type. " +
-                      "Supported options: categorical, continuous.",
-                      typeConverter=TypeConverters.toString)
-
-    selectionMode = Param(Params._dummy(), "selectionMode",
-                          "The selection mode. " +
-                          "Supported options: numTopFeatures (default), percentile, fpr, " +
-                          "fdr, fwe.",
-                          typeConverter=TypeConverters.toString)
-
-    selectionThreshold = Param(Params._dummy(), "selectionThreshold", "The upper bound of the " +
-                               "features that selector will select.",
-                               typeConverter=TypeConverters.toFloat)
+    featureType = Param(
+        Params._dummy(),
+        "featureType",
+        "The feature type. " + "Supported options: categorical, continuous.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    labelType = Param(
+        Params._dummy(),
+        "labelType",
+        "The label type. " + "Supported options: categorical, continuous.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    selectionMode = Param(
+        Params._dummy(),
+        "selectionMode",
+        "The selection mode. "
+        + "Supported options: numTopFeatures (default), percentile, fpr, "
+        + "fdr, fwe.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    selectionThreshold = Param(
+        Params._dummy(),
+        "selectionThreshold",
+        "The upper bound of the " + "features that selector will select.",
+        typeConverter=TypeConverters.toFloat,
+    )
 
     def __init__(self, *args):
         super(_UnivariateFeatureSelectorParams, self).__init__(*args)
@@ -5812,8 +6404,9 @@ class _UnivariateFeatureSelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol
 
 
 @inherit_doc
-class UnivariateFeatureSelector(JavaEstimator, _UnivariateFeatureSelectorParams, JavaMLReadable,
-                                JavaMLWritable):
+class UnivariateFeatureSelector(
+    JavaEstimator, _UnivariateFeatureSelectorParams, JavaMLReadable, JavaMLWritable
+):
     """
     UnivariateFeatureSelector
     Feature selector based on univariate statistical tests against labels. Currently, Spark
@@ -5887,22 +6480,35 @@ class UnivariateFeatureSelector(JavaEstimator, _UnivariateFeatureSelectorParams,
     """
 
     @keyword_only
-    def __init__(self, *, featuresCol="features", outputCol=None,
-                 labelCol="label", selectionMode="numTopFeatures"):
+    def __init__(
+        self,
+        *,
+        featuresCol="features",
+        outputCol=None,
+        labelCol="label",
+        selectionMode="numTopFeatures",
+    ):
         """
         __init__(self, \\*, featuresCol="features", outputCol=None, \
                  labelCol="label", selectionMode="numTopFeatures")
         """
         super(UnivariateFeatureSelector, self).__init__()
-        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.UnivariateFeatureSelector",
-                                            self.uid)
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.feature.UnivariateFeatureSelector", self.uid
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("3.1.1")
-    def setParams(self, *, featuresCol="features", outputCol=None,
-                  labelCol="labels", selectionMode="numTopFeatures"):
+    def setParams(
+        self,
+        *,
+        featuresCol="features",
+        outputCol=None,
+        labelCol="labels",
+        selectionMode="numTopFeatures",
+    ):
         """
         setParams(self, \\*, featuresCol="features", outputCol=None, \
                   labelCol="labels", selectionMode="numTopFeatures")
@@ -5961,8 +6567,9 @@ class UnivariateFeatureSelector(JavaEstimator, _UnivariateFeatureSelectorParams,
         return UnivariateFeatureSelectorModel(java_model)
 
 
-class UnivariateFeatureSelectorModel(JavaModel, _UnivariateFeatureSelectorParams, JavaMLReadable,
-                                     JavaMLWritable):
+class UnivariateFeatureSelectorModel(
+    JavaModel, _UnivariateFeatureSelectorParams, JavaMLReadable, JavaMLWritable
+):
     """
     Model fitted by :py:class:`UnivariateFeatureSelector`.
 
@@ -6006,24 +6613,30 @@ if __name__ == "__main__":
 
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    spark = SparkSession.builder\
-        .master("local[2]")\
-        .appName("ml.feature tests")\
-        .getOrCreate()
+    spark = SparkSession.builder.master("local[2]").appName("ml.feature tests").getOrCreate()
     sc = spark.sparkContext
-    globs['sc'] = sc
-    globs['spark'] = spark
-    testData = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="b"),
-                               Row(id=2, label="c"), Row(id=3, label="a"),
-                               Row(id=4, label="a"), Row(id=5, label="c")], 2)
-    globs['stringIndDf'] = spark.createDataFrame(testData)
+    globs["sc"] = sc
+    globs["spark"] = spark
+    testData = sc.parallelize(
+        [
+            Row(id=0, label="a"),
+            Row(id=1, label="b"),
+            Row(id=2, label="c"),
+            Row(id=3, label="a"),
+            Row(id=4, label="a"),
+            Row(id=5, label="c"),
+        ],
+        2,
+    )
+    globs["stringIndDf"] = spark.createDataFrame(testData)
     temp_path = tempfile.mkdtemp()
-    globs['temp_path'] = temp_path
+    globs["temp_path"] = temp_path
     try:
         (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
         spark.stop()
     finally:
         from shutil import rmtree
+
         try:
             rmtree(temp_path)
         except OSError:
diff --git a/python/pyspark/ml/feature.pyi b/python/pyspark/ml/feature.pyi
index 33e4691..6efc304 100644
--- a/python/pyspark/ml/feature.pyi
+++ b/python/pyspark/ml/feature.pyi
@@ -61,7 +61,7 @@ class Binarizer(
         *,
         threshold: float = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> None: ...
     @overload
     def __init__(
@@ -69,7 +69,7 @@ class Binarizer(
         *,
         thresholds: Optional[List[float]] = ...,
         inputCols: Optional[List[str]] = ...,
-        outputCols: Optional[List[str]] = ...
+        outputCols: Optional[List[str]] = ...,
     ) -> None: ...
     @overload
     def setParams(
@@ -77,7 +77,7 @@ class Binarizer(
         *,
         threshold: float = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> Binarizer: ...
     @overload
     def setParams(
@@ -85,7 +85,7 @@ class Binarizer(
         *,
         thresholds: Optional[List[float]] = ...,
         inputCols: Optional[List[str]] = ...,
-        outputCols: Optional[List[str]] = ...
+        outputCols: Optional[List[str]] = ...,
     ) -> Binarizer: ...
     def setThreshold(self, value: float) -> Binarizer: ...
     def setThresholds(self, value: List[float]) -> Binarizer: ...
@@ -140,7 +140,7 @@ class BucketedRandomProjectionLSH(
         outputCol: Optional[str] = ...,
         seed: Optional[int] = ...,
         numHashTables: int = ...,
-        bucketLength: Optional[float] = ...
+        bucketLength: Optional[float] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -149,7 +149,7 @@ class BucketedRandomProjectionLSH(
         outputCol: Optional[str] = ...,
         seed: Optional[int] = ...,
         numHashTables: int = ...,
-        bucketLength: Optional[float] = ...
+        bucketLength: Optional[float] = ...,
     ) -> BucketedRandomProjectionLSH: ...
     def setBucketLength(self, value: float) -> BucketedRandomProjectionLSH: ...
     def setSeed(self, value: int) -> BucketedRandomProjectionLSH: ...
@@ -181,7 +181,7 @@ class Bucketizer(
         splits: Optional[List[float]] = ...,
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
-        handleInvalid: str = ...
+        handleInvalid: str = ...,
     ) -> None: ...
     @overload
     def __init__(
@@ -190,7 +190,7 @@ class Bucketizer(
         handleInvalid: str = ...,
         splitsArray: Optional[List[List[float]]] = ...,
         inputCols: Optional[List[str]] = ...,
-        outputCols: Optional[List[str]] = ...
+        outputCols: Optional[List[str]] = ...,
     ) -> None: ...
     @overload
     def setParams(
@@ -199,7 +199,7 @@ class Bucketizer(
         splits: Optional[List[float]] = ...,
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
-        handleInvalid: str = ...
+        handleInvalid: str = ...,
     ) -> Bucketizer: ...
     @overload
     def setParams(
@@ -208,7 +208,7 @@ class Bucketizer(
         handleInvalid: str = ...,
         splitsArray: Optional[List[List[float]]] = ...,
         inputCols: Optional[List[str]] = ...,
-        outputCols: Optional[List[str]] = ...
+        outputCols: Optional[List[str]] = ...,
     ) -> Bucketizer: ...
     def setSplits(self, value: List[float]) -> Bucketizer: ...
     def getSplits(self) -> List[float]: ...
@@ -248,7 +248,7 @@ class CountVectorizer(
         vocabSize: int = ...,
         binary: bool = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -259,7 +259,7 @@ class CountVectorizer(
         vocabSize: int = ...,
         binary: bool = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> CountVectorizer: ...
     def setMinTF(self, value: float) -> CountVectorizer: ...
     def setMinDF(self, value: float) -> CountVectorizer: ...
@@ -269,9 +269,7 @@ class CountVectorizer(
     def setInputCol(self, value: str) -> CountVectorizer: ...
     def setOutputCol(self, value: str) -> CountVectorizer: ...
 
-class CountVectorizerModel(
-    JavaModel, JavaMLReadable[CountVectorizerModel], JavaMLWritable
-):
+class CountVectorizerModel(JavaModel, JavaMLReadable[CountVectorizerModel], JavaMLWritable):
     def setInputCol(self, value: str) -> CountVectorizerModel: ...
     def setOutputCol(self, value: str) -> CountVectorizerModel: ...
     def setMinTF(self, value: float) -> CountVectorizerModel: ...
@@ -288,23 +286,13 @@ class CountVectorizerModel(
     @property
     def vocabulary(self) -> List[str]: ...
 
-class DCT(
-    JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable[DCT], JavaMLWritable
-):
+class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable[DCT], JavaMLWritable):
     inverse: Param[bool]
     def __init__(
-        self,
-        *,
-        inverse: bool = ...,
-        inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        self, *, inverse: bool = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
     ) -> None: ...
     def setParams(
-        self,
-        *,
-        inverse: bool = ...,
-        inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        self, *, inverse: bool = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
     ) -> DCT: ...
     def setInverse(self, value: bool) -> DCT: ...
     def getInverse(self) -> bool: ...
@@ -324,14 +312,14 @@ class ElementwiseProduct(
         *,
         scalingVec: Optional[Vector] = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
         *,
         scalingVec: Optional[Vector] = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> ElementwiseProduct: ...
     def setScalingVec(self, value: Vector) -> ElementwiseProduct: ...
     def getScalingVec(self) -> Vector: ...
@@ -353,7 +341,7 @@ class FeatureHasher(
         numFeatures: int = ...,
         inputCols: Optional[List[str]] = ...,
         outputCol: Optional[str] = ...,
-        categoricalCols: Optional[List[str]] = ...
+        categoricalCols: Optional[List[str]] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -361,7 +349,7 @@ class FeatureHasher(
         numFeatures: int = ...,
         inputCols: Optional[List[str]] = ...,
         outputCol: Optional[str] = ...,
-        categoricalCols: Optional[List[str]] = ...
+        categoricalCols: Optional[List[str]] = ...,
     ) -> FeatureHasher: ...
     def setCategoricalCols(self, value: List[str]) -> FeatureHasher: ...
     def getCategoricalCols(self) -> List[str]: ...
@@ -384,7 +372,7 @@ class HashingTF(
         numFeatures: int = ...,
         binary: bool = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -392,7 +380,7 @@ class HashingTF(
         numFeatures: int = ...,
         binary: bool = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> HashingTF: ...
     def setBinary(self, value: bool) -> HashingTF: ...
     def getBinary(self) -> bool: ...
@@ -412,14 +400,14 @@ class IDF(JavaEstimator[IDFModel], _IDFParams, JavaMLReadable[IDF], JavaMLWritab
         *,
         minDocFreq: int = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
         *,
         minDocFreq: int = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> IDF: ...
     def setMinDocFreq(self, value: int) -> IDF: ...
     def setInputCol(self, value: str) -> IDF: ...
@@ -435,17 +423,13 @@ class IDFModel(JavaModel, _IDFParams, JavaMLReadable[IDFModel], JavaMLWritable):
     @property
     def numDocs(self) -> int: ...
 
-class _ImputerParams(
-    HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, HasRelativeError
-):
+class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, HasRelativeError):
     strategy: Param[str]
     missingValue: Param[float]
     def getStrategy(self) -> str: ...
     def getMissingValue(self) -> float: ...
 
-class Imputer(
-    JavaEstimator[ImputerModel], _ImputerParams, JavaMLReadable[Imputer], JavaMLWritable
-):
+class Imputer(JavaEstimator[ImputerModel], _ImputerParams, JavaMLReadable[Imputer], JavaMLWritable):
     @overload
     def __init__(
         self,
@@ -454,7 +438,7 @@ class Imputer(
         missingValue: float = ...,
         inputCols: Optional[List[str]] = ...,
         outputCols: Optional[List[str]] = ...,
-        relativeError: float = ...
+        relativeError: float = ...,
     ) -> None: ...
     @overload
     def __init__(
@@ -464,7 +448,7 @@ class Imputer(
         missingValue: float = ...,
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
-        relativeError: float = ...
+        relativeError: float = ...,
     ) -> None: ...
     @overload
     def setParams(
@@ -474,7 +458,7 @@ class Imputer(
         missingValue: float = ...,
         inputCols: Optional[List[str]] = ...,
         outputCols: Optional[List[str]] = ...,
-        relativeError: float = ...
+        relativeError: float = ...,
     ) -> Imputer: ...
     @overload
     def setParams(
@@ -484,7 +468,7 @@ class Imputer(
         missingValue: float = ...,
         inputCol: Optional[str] = ...,
         outputCols: Optional[str] = ...,
-        relativeError: float = ...
+        relativeError: float = ...,
     ) -> Imputer: ...
     def setStrategy(self, value: str) -> Imputer: ...
     def setMissingValue(self, value: float) -> Imputer: ...
@@ -494,9 +478,7 @@ class Imputer(
     def setOutputCol(self, value: str) -> Imputer: ...
     def setRelativeError(self, value: float) -> Imputer: ...
 
-class ImputerModel(
-    JavaModel, _ImputerParams, JavaMLReadable[ImputerModel], JavaMLWritable
-):
+class ImputerModel(JavaModel, _ImputerParams, JavaMLReadable[ImputerModel], JavaMLWritable):
     def setInputCols(self, value: List[str]) -> ImputerModel: ...
     def setOutputCols(self, value: List[str]) -> ImputerModel: ...
     def setInputCol(self, value: str) -> ImputerModel: ...
@@ -559,7 +541,7 @@ class MinHashLSH(
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
         seed: Optional[int] = ...,
-        numHashTables: int = ...
+        numHashTables: int = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -567,7 +549,7 @@ class MinHashLSH(
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
         seed: Optional[int] = ...,
-        numHashTables: int = ...
+        numHashTables: int = ...,
     ) -> MinHashLSH: ...
     def setSeed(self, value: int) -> MinHashLSH: ...
 
@@ -592,7 +574,7 @@ class MinMaxScaler(
         min: float = ...,
         max: float = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -600,7 +582,7 @@ class MinMaxScaler(
         min: float = ...,
         max: float = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> MinMaxScaler: ...
     def setMin(self, value: float) -> MinMaxScaler: ...
     def setMax(self, value: float) -> MinMaxScaler: ...
@@ -619,23 +601,13 @@ class MinMaxScalerModel(
     @property
     def originalMax(self) -> Vector: ...
 
-class NGram(
-    JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable[NGram], JavaMLWritable
-):
+class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable[NGram], JavaMLWritable):
     n: Param[int]
     def __init__(
-        self,
-        *,
-        n: int = ...,
-        inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        self, *, n: int = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
     ) -> None: ...
     def setParams(
-        self,
-        *,
-        n: int = ...,
-        inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        self, *, n: int = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
     ) -> NGram: ...
     def setN(self, value: int) -> NGram: ...
     def getN(self) -> int: ...
@@ -651,18 +623,10 @@ class Normalizer(
 ):
     p: Param[float]
     def __init__(
-        self,
-        *,
-        p: float = ...,
-        inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        self, *, p: float = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
     ) -> None: ...
     def setParams(
-        self,
-        *,
-        p: float = ...,
-        inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        self, *, p: float = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
     ) -> Normalizer: ...
     def setP(self, value: float) -> Normalizer: ...
     def getP(self) -> float: ...
@@ -688,7 +652,7 @@ class OneHotEncoder(
         inputCols: Optional[List[str]] = ...,
         outputCols: Optional[List[str]] = ...,
         handleInvalid: str = ...,
-        dropLast: bool = ...
+        dropLast: bool = ...,
     ) -> None: ...
     @overload
     def __init__(
@@ -697,7 +661,7 @@ class OneHotEncoder(
         handleInvalid: str = ...,
         dropLast: bool = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> None: ...
     @overload
     def setParams(
@@ -706,7 +670,7 @@ class OneHotEncoder(
         inputCols: Optional[List[str]] = ...,
         outputCols: Optional[List[str]] = ...,
         handleInvalid: str = ...,
-        dropLast: bool = ...
+        dropLast: bool = ...,
     ) -> OneHotEncoder: ...
     @overload
     def setParams(
@@ -715,7 +679,7 @@ class OneHotEncoder(
         handleInvalid: str = ...,
         dropLast: bool = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> OneHotEncoder: ...
     def setDropLast(self, value: bool) -> OneHotEncoder: ...
     def setInputCols(self, value: List[str]) -> OneHotEncoder: ...
@@ -745,18 +709,10 @@ class PolynomialExpansion(
 ):
     degree: Param[int]
     def __init__(
-        self,
-        *,
-        degree: int = ...,
-        inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        self, *, degree: int = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
     ) -> None: ...
     def setParams(
-        self,
-        *,
-        degree: int = ...,
-        inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        self, *, degree: int = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
     ) -> PolynomialExpansion: ...
     def setDegree(self, value: int) -> PolynomialExpansion: ...
     def getDegree(self) -> int: ...
@@ -785,7 +741,7 @@ class QuantileDiscretizer(
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
         relativeError: float = ...,
-        handleInvalid: str = ...
+        handleInvalid: str = ...,
     ) -> None: ...
     @overload
     def __init__(
@@ -795,7 +751,7 @@ class QuantileDiscretizer(
         handleInvalid: str = ...,
         numBucketsArray: Optional[List[int]] = ...,
         inputCols: Optional[List[str]] = ...,
-        outputCols: Optional[List[str]] = ...
+        outputCols: Optional[List[str]] = ...,
     ) -> None: ...
     @overload
     def setParams(
@@ -805,7 +761,7 @@ class QuantileDiscretizer(
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
         relativeError: float = ...,
-        handleInvalid: str = ...
+        handleInvalid: str = ...,
     ) -> QuantileDiscretizer: ...
     @overload
     def setParams(
@@ -815,7 +771,7 @@ class QuantileDiscretizer(
         handleInvalid: str = ...,
         numBucketsArray: Optional[List[int]] = ...,
         inputCols: Optional[List[str]] = ...,
-        outputCols: Optional[List[str]] = ...
+        outputCols: Optional[List[str]] = ...,
     ) -> QuantileDiscretizer: ...
     def setNumBuckets(self, value: int) -> QuantileDiscretizer: ...
     def getNumBuckets(self) -> int: ...
@@ -851,7 +807,7 @@ class RobustScaler(
         withScaling: bool = ...,
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
-        relativeError: float = ...
+        relativeError: float = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -862,7 +818,7 @@ class RobustScaler(
         withScaling: bool = ...,
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
-        relativeError: float = ...
+        relativeError: float = ...,
     ) -> RobustScaler: ...
     def setLower(self, value: float) -> RobustScaler: ...
     def setUpper(self, value: float) -> RobustScaler: ...
@@ -901,7 +857,7 @@ class RegexTokenizer(
         pattern: str = ...,
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
-        toLowercase: bool = ...
+        toLowercase: bool = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -911,7 +867,7 @@ class RegexTokenizer(
         pattern: str = ...,
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
-        toLowercase: bool = ...
+        toLowercase: bool = ...,
     ) -> RegexTokenizer: ...
     def setMinTokenLength(self, value: int) -> RegexTokenizer: ...
     def getMinTokenLength(self) -> int: ...
@@ -950,7 +906,7 @@ class StandardScaler(
         withMean: bool = ...,
         withStd: bool = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -958,7 +914,7 @@ class StandardScaler(
         withMean: bool = ...,
         withStd: bool = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> StandardScaler: ...
     def setWithMean(self, value: bool) -> StandardScaler: ...
     def setWithStd(self, value: bool) -> StandardScaler: ...
@@ -999,7 +955,7 @@ class StringIndexer(
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
         handleInvalid: str = ...,
-        stringOrderType: str = ...
+        stringOrderType: str = ...,
     ) -> None: ...
     @overload
     def __init__(
@@ -1008,7 +964,7 @@ class StringIndexer(
         inputCols: Optional[List[str]] = ...,
         outputCols: Optional[List[str]] = ...,
         handleInvalid: str = ...,
-        stringOrderType: str = ...
+        stringOrderType: str = ...,
     ) -> None: ...
     @overload
     def setParams(
@@ -1017,7 +973,7 @@ class StringIndexer(
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
         handleInvalid: str = ...,
-        stringOrderType: str = ...
+        stringOrderType: str = ...,
     ) -> StringIndexer: ...
     @overload
     def setParams(
@@ -1026,7 +982,7 @@ class StringIndexer(
         inputCols: Optional[List[str]] = ...,
         outputCols: Optional[List[str]] = ...,
         handleInvalid: str = ...,
-        stringOrderType: str = ...
+        stringOrderType: str = ...,
     ) -> StringIndexer: ...
     def setStringOrderType(self, value: str) -> StringIndexer: ...
     def setInputCol(self, value: str) -> StringIndexer: ...
@@ -1075,14 +1031,14 @@ class IndexToString(
         *,
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
-        labels: Optional[List[str]] = ...
+        labels: Optional[List[str]] = ...,
     ) -> None: ...
     def setParams(
         self,
         *,
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
-        labels: Optional[List[str]] = ...
+        labels: Optional[List[str]] = ...,
     ) -> IndexToString: ...
     def setLabels(self, value: List[str]) -> IndexToString: ...
     def getLabels(self) -> List[str]: ...
@@ -1109,7 +1065,7 @@ class StopWordsRemover(
         outputCol: Optional[str] = ...,
         stopWords: Optional[List[str]] = ...,
         caseSensitive: bool = ...,
-        locale: Optional[str] = ...
+        locale: Optional[str] = ...,
     ) -> None: ...
     @overload
     def __init__(
@@ -1119,7 +1075,7 @@ class StopWordsRemover(
         caseSensitive: bool = ...,
         locale: Optional[str] = ...,
         inputCols: Optional[List[str]] = ...,
-        outputCols: Optional[List[str]] = ...
+        outputCols: Optional[List[str]] = ...,
     ) -> None: ...
     @overload
     def setParams(
@@ -1129,7 +1085,7 @@ class StopWordsRemover(
         outputCol: Optional[str] = ...,
         stopWords: Optional[List[str]] = ...,
         caseSensitive: bool = ...,
-        locale: Optional[str] = ...
+        locale: Optional[str] = ...,
     ) -> StopWordsRemover: ...
     @overload
     def setParams(
@@ -1139,7 +1095,7 @@ class StopWordsRemover(
         caseSensitive: bool = ...,
         locale: Optional[str] = ...,
         inputCols: Optional[List[str]] = ...,
-        outputCols: Optional[List[str]] = ...
+        outputCols: Optional[List[str]] = ...,
     ) -> StopWordsRemover: ...
     def setStopWords(self, value: List[str]) -> StopWordsRemover: ...
     def getStopWords(self) -> List[str]: ...
@@ -1184,14 +1140,14 @@ class VectorAssembler(
         *,
         inputCols: Optional[List[str]] = ...,
         outputCol: Optional[str] = ...,
-        handleInvalid: str = ...
+        handleInvalid: str = ...,
     ) -> None: ...
     def setParams(
         self,
         *,
         inputCols: Optional[List[str]] = ...,
         outputCol: Optional[str] = ...,
-        handleInvalid: str = ...
+        handleInvalid: str = ...,
     ) -> VectorAssembler: ...
     def setInputCols(self, value: List[str]) -> VectorAssembler: ...
     def setOutputCol(self, value: str) -> VectorAssembler: ...
@@ -1216,7 +1172,7 @@ class VectorIndexer(
         maxCategories: int = ...,
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
-        handleInvalid: str = ...
+        handleInvalid: str = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -1224,7 +1180,7 @@ class VectorIndexer(
         maxCategories: int = ...,
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
-        handleInvalid: str = ...
+        handleInvalid: str = ...,
     ) -> VectorIndexer: ...
     def setMaxCategories(self, value: int) -> VectorIndexer: ...
     def setInputCol(self, value: str) -> VectorIndexer: ...
@@ -1256,7 +1212,7 @@ class VectorSlicer(
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
         indices: Optional[List[int]] = ...,
-        names: Optional[List[str]] = ...
+        names: Optional[List[str]] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -1264,7 +1220,7 @@ class VectorSlicer(
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
         indices: Optional[List[int]] = ...,
-        names: Optional[List[str]] = ...
+        names: Optional[List[str]] = ...,
     ) -> VectorSlicer: ...
     def setIndices(self, value: List[int]) -> VectorSlicer: ...
     def getIndices(self) -> List[int]: ...
@@ -1304,7 +1260,7 @@ class Word2Vec(
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
         windowSize: int = ...,
-        maxSentenceLength: int = ...
+        maxSentenceLength: int = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -1318,7 +1274,7 @@ class Word2Vec(
         inputCol: Optional[str] = ...,
         outputCol: Optional[str] = ...,
         windowSize: int = ...,
-        maxSentenceLength: int = ...
+        maxSentenceLength: int = ...,
     ) -> Word2Vec: ...
     def setVectorSize(self, value: int) -> Word2Vec: ...
     def setNumPartitions(self, value: int) -> Word2Vec: ...
@@ -1331,9 +1287,7 @@ class Word2Vec(
     def setSeed(self, value: int) -> Word2Vec: ...
     def setStepSize(self, value: float) -> Word2Vec: ...
 
-class Word2VecModel(
-    JavaModel, _Word2VecParams, JavaMLReadable[Word2VecModel], JavaMLWritable
-):
+class Word2VecModel(JavaModel, _Word2VecParams, JavaMLReadable[Word2VecModel], JavaMLWritable):
     def getVectors(self) -> DataFrame: ...
     def setInputCol(self, value: str) -> Word2VecModel: ...
     def setOutputCol(self, value: str) -> Word2VecModel: ...
@@ -1356,14 +1310,14 @@ class PCA(JavaEstimator[PCAModel], _PCAParams, JavaMLReadable[PCA], JavaMLWritab
         *,
         k: Optional[int] = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> None: ...
     def setParams(
         self,
         *,
         k: Optional[int] = ...,
         inputCol: Optional[str] = ...,
-        outputCol: Optional[str] = ...
+        outputCol: Optional[str] = ...,
     ) -> PCA: ...
     def setK(self, value: int) -> PCA: ...
     def setInputCol(self, value: str) -> PCA: ...
@@ -1401,7 +1355,7 @@ class RFormula(
         labelCol: str = ...,
         forceIndexLabel: bool = ...,
         stringIndexerOrderType: str = ...,
-        handleInvalid: str = ...
+        handleInvalid: str = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -1411,7 +1365,7 @@ class RFormula(
         labelCol: str = ...,
         forceIndexLabel: bool = ...,
         stringIndexerOrderType: str = ...,
-        handleInvalid: str = ...
+        handleInvalid: str = ...,
     ) -> RFormula: ...
     def setFormula(self, value: str) -> RFormula: ...
     def setForceIndexLabel(self, value: bool) -> RFormula: ...
@@ -1420,9 +1374,7 @@ class RFormula(
     def setLabelCol(self, value: str) -> RFormula: ...
     def setHandleInvalid(self, value: str) -> RFormula: ...
 
-class RFormulaModel(
-    JavaModel, _RFormulaParams, JavaMLReadable[RFormulaModel], JavaMLWritable
-): ...
+class RFormulaModel(JavaModel, _RFormulaParams, JavaMLReadable[RFormulaModel], JavaMLWritable): ...
 
 class _SelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol):
     selectorType: Param[str]
@@ -1472,7 +1424,7 @@ class ChiSqSelector(
         percentile: float = ...,
         fpr: float = ...,
         fdr: float = ...,
-        fwe: float = ...
+        fwe: float = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -1485,7 +1437,7 @@ class ChiSqSelector(
         percentile: float = ...,
         fpr: float = ...,
         fdr: float = ...,
-        fwe: float = ...
+        fwe: float = ...,
     ) -> ChiSqSelector: ...
     def setSelectorType(self, value: str) -> ChiSqSelector: ...
     def setNumTopFeatures(self, value: int) -> ChiSqSelector: ...
@@ -1497,9 +1449,7 @@ class ChiSqSelector(
     def setOutputCol(self, value: str) -> ChiSqSelector: ...
     def setLabelCol(self, value: str) -> ChiSqSelector: ...
 
-class ChiSqSelectorModel(
-    _SelectorModel, JavaMLReadable[ChiSqSelectorModel], JavaMLWritable
-):
+class ChiSqSelectorModel(_SelectorModel, JavaMLReadable[ChiSqSelectorModel], JavaMLWritable):
     def setFeaturesCol(self, value: str) -> ChiSqSelectorModel: ...
     def setOutputCol(self, value: str) -> ChiSqSelectorModel: ...
     @property
@@ -1515,18 +1465,10 @@ class VectorSizeHint(
     size: Param[int]
     handleInvalid: Param[str]
     def __init__(
-        self,
-        *,
-        inputCol: Optional[str] = ...,
-        size: Optional[int] = ...,
-        handleInvalid: str = ...
+        self, *, inputCol: Optional[str] = ..., size: Optional[int] = ..., handleInvalid: str = ...
     ) -> None: ...
     def setParams(
-        self,
-        *,
-        inputCol: Optional[str] = ...,
-        size: Optional[int] = ...,
-        handleInvalid: str = ...
+        self, *, inputCol: Optional[str] = ..., size: Optional[int] = ..., handleInvalid: str = ...
     ) -> VectorSizeHint: ...
     def setSize(self, value: int) -> VectorSizeHint: ...
     def getSize(self) -> int: ...
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 37dac48..9cfd3af 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -33,34 +33,39 @@ class _FPGrowthParams(HasPredictionCol):
     .. versionadded:: 3.0.0
     """
 
-    itemsCol = Param(Params._dummy(), "itemsCol",
-                     "items column name", typeConverter=TypeConverters.toString)
+    itemsCol = Param(
+        Params._dummy(), "itemsCol", "items column name", typeConverter=TypeConverters.toString
+    )
     minSupport = Param(
         Params._dummy(),
         "minSupport",
-        "Minimal support level of the frequent pattern. [0.0, 1.0]. " +
-        "Any pattern that appears more than (minSupport * size-of-the-dataset) " +
-        "times will be output in the frequent itemsets.",
-        typeConverter=TypeConverters.toFloat)
+        "Minimal support level of the frequent pattern. [0.0, 1.0]. "
+        + "Any pattern that appears more than (minSupport * size-of-the-dataset) "
+        + "times will be output in the frequent itemsets.",
+        typeConverter=TypeConverters.toFloat,
+    )
     numPartitions = Param(
         Params._dummy(),
         "numPartitions",
-        "Number of partitions (at least 1) used by parallel FP-growth. " +
-        "By default the param is not set, " +
-        "and partition number of the input dataset is used.",
-        typeConverter=TypeConverters.toInt)
+        "Number of partitions (at least 1) used by parallel FP-growth. "
+        + "By default the param is not set, "
+        + "and partition number of the input dataset is used.",
+        typeConverter=TypeConverters.toInt,
+    )
     minConfidence = Param(
         Params._dummy(),
         "minConfidence",
-        "Minimal confidence for generating Association Rule. [0.0, 1.0]. " +
-        "minConfidence will not affect the mining for frequent itemsets, " +
-        "but will affect the association rules generation.",
-        typeConverter=TypeConverters.toFloat)
+        "Minimal confidence for generating Association Rule. [0.0, 1.0]. "
+        + "minConfidence will not affect the mining for frequent itemsets, "
+        + "but will affect the association rules generation.",
+        typeConverter=TypeConverters.toFloat,
+    )
 
     def __init__(self, *args):
         super(_FPGrowthParams, self).__init__(*args)
-        self._setDefault(minSupport=0.3, minConfidence=0.8,
-                         itemsCol="items", predictionCol="prediction")
+        self._setDefault(
+            minSupport=0.3, minConfidence=0.8, itemsCol="items", predictionCol="prediction"
+        )
 
     def getItemsCol(self):
         """
@@ -224,9 +229,17 @@ class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable):
     >>> fpm.transform(data).take(1) == model2.transform(data).take(1)
     True
     """
+
     @keyword_only
-    def __init__(self, *, minSupport=0.3, minConfidence=0.8, itemsCol="items",
-                 predictionCol="prediction", numPartitions=None):
+    def __init__(
+        self,
+        *,
+        minSupport=0.3,
+        minConfidence=0.8,
+        itemsCol="items",
+        predictionCol="prediction",
+        numPartitions=None,
+    ):
         """
         __init__(self, \\*, minSupport=0.3, minConfidence=0.8, itemsCol="items", \
                  predictionCol="prediction", numPartitions=None)
@@ -238,8 +251,15 @@ class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable):
 
     @keyword_only
     @since("2.2.0")
-    def setParams(self, *, minSupport=0.3, minConfidence=0.8, itemsCol="items",
-                  predictionCol="prediction", numPartitions=None):
+    def setParams(
+        self,
+        *,
+        minSupport=0.3,
+        minConfidence=0.8,
+        itemsCol="items",
+        predictionCol="prediction",
+        numPartitions=None,
+    ):
         """
         setParams(self, \\*, minSupport=0.3, minConfidence=0.8, itemsCol="items", \
                   predictionCol="prediction", numPartitions=None)
@@ -327,45 +347,72 @@ class PrefixSpan(JavaParams):
     ...
     """
 
-    minSupport = Param(Params._dummy(), "minSupport", "The minimal support level of the " +
-                       "sequential pattern. Sequential pattern that appears more than " +
-                       "(minSupport * size-of-the-dataset) times will be output. Must be >= 0.",
-                       typeConverter=TypeConverters.toFloat)
-
-    maxPatternLength = Param(Params._dummy(), "maxPatternLength",
-                             "The maximal length of the sequential pattern. Must be > 0.",
-                             typeConverter=TypeConverters.toInt)
+    minSupport = Param(
+        Params._dummy(),
+        "minSupport",
+        "The minimal support level of the "
+        + "sequential pattern. Sequential pattern that appears more than "
+        + "(minSupport * size-of-the-dataset) times will be output. Must be >= 0.",
+        typeConverter=TypeConverters.toFloat,
+    )
 
-    maxLocalProjDBSize = Param(Params._dummy(), "maxLocalProjDBSize",
-                               "The maximum number of items (including delimiters used in the " +
-                               "internal storage format) allowed in a projected database before " +
-                               "local processing. If a projected database exceeds this size, " +
-                               "another iteration of distributed prefix growth is run. " +
-                               "Must be > 0.",
-                               typeConverter=TypeConverters.toInt)
+    maxPatternLength = Param(
+        Params._dummy(),
+        "maxPatternLength",
+        "The maximal length of the sequential pattern. Must be > 0.",
+        typeConverter=TypeConverters.toInt,
+    )
 
-    sequenceCol = Param(Params._dummy(), "sequenceCol", "The name of the sequence column in " +
-                        "dataset, rows with nulls in this column are ignored.",
-                        typeConverter=TypeConverters.toString)
+    maxLocalProjDBSize = Param(
+        Params._dummy(),
+        "maxLocalProjDBSize",
+        "The maximum number of items (including delimiters used in the "
+        + "internal storage format) allowed in a projected database before "
+        + "local processing. If a projected database exceeds this size, "
+        + "another iteration of distributed prefix growth is run. "
+        + "Must be > 0.",
+        typeConverter=TypeConverters.toInt,
+    )
+
+    sequenceCol = Param(
+        Params._dummy(),
+        "sequenceCol",
+        "The name of the sequence column in "
+        + "dataset, rows with nulls in this column are ignored.",
+        typeConverter=TypeConverters.toString,
+    )
 
     @keyword_only
-    def __init__(self, *, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
-                 sequenceCol="sequence"):
+    def __init__(
+        self,
+        *,
+        minSupport=0.1,
+        maxPatternLength=10,
+        maxLocalProjDBSize=32000000,
+        sequenceCol="sequence",
+    ):
         """
         __init__(self, \\*, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \
                  sequenceCol="sequence")
         """
         super(PrefixSpan, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid)
-        self._setDefault(minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
-                         sequenceCol="sequence")
+        self._setDefault(
+            minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, sequenceCol="sequence"
+        )
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("2.4.0")
-    def setParams(self, *, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
-                  sequenceCol="sequence"):
+    def setParams(
+        self,
+        *,
+        minSupport=0.1,
+        maxPatternLength=10,
+        maxLocalProjDBSize=32000000,
+        sequenceCol="sequence",
+    ):
         """
         setParams(self, \\*, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \
                   sequenceCol="sequence")
@@ -460,24 +507,24 @@ if __name__ == "__main__":
     import doctest
     import pyspark.ml.fpm
     from pyspark.sql import SparkSession
+
     globs = pyspark.ml.fpm.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
-    spark = SparkSession.builder\
-        .master("local[2]")\
-        .appName("ml.fpm tests")\
-        .getOrCreate()
+    spark = SparkSession.builder.master("local[2]").appName("ml.fpm tests").getOrCreate()
     sc = spark.sparkContext
-    globs['sc'] = sc
-    globs['spark'] = spark
+    globs["sc"] = sc
+    globs["spark"] = spark
     import tempfile
+
     temp_path = tempfile.mkdtemp()
-    globs['temp_path'] = temp_path
+    globs["temp_path"] = temp_path
     try:
         (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
         spark.stop()
     finally:
         from shutil import rmtree
+
         try:
             rmtree(temp_path)
         except OSError:
diff --git a/python/pyspark/ml/fpm.pyi b/python/pyspark/ml/fpm.pyi
index 7cc304a..609bc44 100644
--- a/python/pyspark/ml/fpm.pyi
+++ b/python/pyspark/ml/fpm.pyi
@@ -36,9 +36,7 @@ class _FPGrowthParams(HasPredictionCol):
     def getNumPartitions(self) -> int: ...
     def getMinConfidence(self) -> float: ...
 
-class FPGrowthModel(
-    JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable[FPGrowthModel]
-):
+class FPGrowthModel(JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable[FPGrowthModel]):
     def setItemsCol(self, value: str) -> FPGrowthModel: ...
     def setMinConfidence(self, value: float) -> FPGrowthModel: ...
     def setPredictionCol(self, value: str) -> FPGrowthModel: ...
@@ -60,7 +58,7 @@ class FPGrowth(
         minConfidence: float = ...,
         itemsCol: str = ...,
         predictionCol: str = ...,
-        numPartitions: Optional[int] = ...
+        numPartitions: Optional[int] = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -69,7 +67,7 @@ class FPGrowth(
         minConfidence: float = ...,
         itemsCol: str = ...,
         predictionCol: str = ...,
-        numPartitions: Optional[int] = ...
+        numPartitions: Optional[int] = ...,
     ) -> FPGrowth: ...
     def setItemsCol(self, value: str) -> FPGrowth: ...
     def setMinSupport(self, value: float) -> FPGrowth: ...
@@ -88,7 +86,7 @@ class PrefixSpan(JavaParams):
         minSupport: float = ...,
         maxPatternLength: int = ...,
         maxLocalProjDBSize: int = ...,
-        sequenceCol: str = ...
+        sequenceCol: str = ...,
     ) -> None: ...
     def setParams(
         self,
@@ -96,7 +94,7 @@ class PrefixSpan(JavaParams):
         minSupport: float = ...,
         maxPatternLength: int = ...,
         maxLocalProjDBSize: int = ...,
-        sequenceCol: str = ...
+        sequenceCol: str = ...,
     ) -> PrefixSpan: ...
     def setMinSupport(self, value: float) -> PrefixSpan: ...
     def getMinSupport(self) -> float: ...
diff --git a/python/pyspark/ml/functions.py b/python/pyspark/ml/functions.py
index 1eadbd6..64b5948 100644
--- a/python/pyspark/ml/functions.py
+++ b/python/pyspark/ml/functions.py
@@ -66,7 +66,8 @@ def vector_to_array(col, dtype="float64"):
     """
     sc = SparkContext._active_spark_context
     return Column(
-        sc._jvm.org.apache.spark.ml.functions.vector_to_array(_to_java_column(col), dtype))
+        sc._jvm.org.apache.spark.ml.functions.vector_to_array(_to_java_column(col), dtype)
+    )
 
 
 def array_to_vector(col):
@@ -100,8 +101,7 @@ def array_to_vector(col):
     [Row(vec1=DenseVector([1.0, 3.0]))]
     """
     sc = SparkContext._active_spark_context
-    return Column(
-        sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
+    return Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
 
 
 def _test():
@@ -109,18 +109,18 @@ def _test():
     from pyspark.sql import SparkSession
     import pyspark.ml.functions
     import sys
+
     globs = pyspark.ml.functions.__dict__.copy()
-    spark = SparkSession.builder \
-        .master("local[2]") \
-        .appName("ml.functions tests") \
-        .getOrCreate()
+    spark = SparkSession.builder.master("local[2]").appName("ml.functions tests").getOrCreate()
     sc = spark.sparkContext
-    globs['sc'] = sc
-    globs['spark'] = spark
+    globs["sc"] = sc
+    globs["spark"] = spark
 
     (failure_count, test_count) = doctest.testmod(
-        pyspark.ml.functions, globs=globs,
-        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+        pyspark.ml.functions,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
+    )
     spark.stop()
     if failure_count:
         sys.exit(-1)
diff --git a/python/pyspark/ml/functions.pyi b/python/pyspark/ml/functions.pyi
index 12b44fc..cb08398 100644
--- a/python/pyspark/ml/functions.pyi
+++ b/python/pyspark/ml/functions.pyi
@@ -20,5 +20,4 @@ from pyspark import SparkContext as SparkContext, since as since  # noqa: F401
 from pyspark.sql.column import Column as Column
 
 def vector_to_array(col: Column) -> Column: ...
-
 def array_to_vector(col: Column) -> Column: ...
diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py
index 728e9a3..bdd1c9d 100644
--- a/python/pyspark/ml/image.py
+++ b/python/pyspark/ml/image.py
@@ -136,8 +136,9 @@ class _ImageSchema(object):
 
         if self._undefinedImageType is None:
             ctx = SparkContext._active_spark_context
-            self._undefinedImageType = \
+            self._undefinedImageType = (
                 ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType()
+            )
         return self._undefinedImageType
 
     def toNDArray(self, image):
@@ -161,12 +162,14 @@ class _ImageSchema(object):
         if not isinstance(image, Row):
             raise TypeError(
                 "image argument should be pyspark.sql.types.Row; however, "
-                "it got [%s]." % type(image))
+                "it got [%s]." % type(image)
+            )
 
         if any(not hasattr(image, f) for f in self.imageFields):
             raise ValueError(
                 "image argument should have attributes specified in "
-                "ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields))
+                "ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields)
+            )
 
         height = image.height
         width = image.width
@@ -175,7 +178,8 @@ class _ImageSchema(object):
             shape=(height, width, nChannels),
             dtype=np.uint8,
             buffer=image.data,
-            strides=(width * nChannels, nChannels, 1))
+            strides=(width * nChannels, nChannels, 1),
+        )
 
     def toImage(self, array, origin=""):
         """
@@ -198,7 +202,8 @@ class _ImageSchema(object):
 
         if not isinstance(array, np.ndarray):
             raise TypeError(
-                "array argument should be numpy.ndarray; however, it got [%s]." % type(array))
+                "array argument should be numpy.ndarray; however, it got [%s]." % type(array)
+            )
 
         if array.ndim != 3:
             raise ValueError("Invalid array shape")
@@ -217,7 +222,7 @@ class _ImageSchema(object):
         # Running `bytearray(numpy.array([1]))` fails in specific Python versions
         # with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3.
         # Here, it avoids it by converting it to bytes.
-        if LooseVersion(np.__version__) >= LooseVersion('1.9'):
+        if LooseVersion(np.__version__) >= LooseVersion("1.9"):
             data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
         else:
             # Numpy prior to 1.9 don't have `tobytes` method.
@@ -226,8 +231,7 @@ class _ImageSchema(object):
         # Creating new Row with _create_row(), because Row(name = value, ... )
         # orders fields by name, which conflicts with expected schema order
         # when the new DataFrame is created by UDF
-        return _create_row(self.imageFields,
-                           [origin, height, width, nChannels, mode, data])
+        return _create_row(self.imageFields, [origin, height, width, nChannels, mode, data])
 
 
 ImageSchema = _ImageSchema()
@@ -236,22 +240,22 @@ ImageSchema = _ImageSchema()
 # Monkey patch to disallow instantiation of this class.
 def _disallow_instance(_):
     raise RuntimeError("Creating instance of _ImageSchema class is disallowed.")
+
+
 _ImageSchema.__init__ = _disallow_instance
 
 
 def _test():
     import doctest
     import pyspark.ml.image
+
     globs = pyspark.ml.image.__dict__.copy()
-    spark = SparkSession.builder\
-        .master("local[2]")\
-        .appName("ml.image tests")\
-        .getOrCreate()
-    globs['spark'] = spark
+    spark = SparkSession.builder.master("local[2]").appName("ml.image tests").getOrCreate()
+    globs["spark"] = spark
 
     (failure_count, test_count) = doctest.testmod(
-        pyspark.ml.image, globs=globs,
-        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+        pyspark.ml.image, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
+    )
     spark.stop()
     if failure_count:
         sys.exit(-1)
diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py
index 2f50bea..46a8b97 100644
--- a/python/pyspark/ml/linalg/__init__.py
+++ b/python/pyspark/ml/linalg/__init__.py
@@ -29,12 +29,28 @@ import struct
 
 import numpy as np
 
-from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
-    IntegerType, ByteType, BooleanType
-
-
-__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors',
-           'Matrix', 'DenseMatrix', 'SparseMatrix', 'Matrices']
+from pyspark.sql.types import (
+    UserDefinedType,
+    StructField,
+    StructType,
+    ArrayType,
+    DoubleType,
+    IntegerType,
+    ByteType,
+    BooleanType,
+)
+
+
+__all__ = [
+    "Vector",
+    "DenseVector",
+    "SparseVector",
+    "Vectors",
+    "Matrix",
+    "DenseMatrix",
+    "SparseMatrix",
+    "Matrices",
+]
 
 
 # Check whether we have SciPy. MLlib works without it too, but if we have it, some methods,
@@ -42,6 +58,7 @@ __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors',
 
 try:
     import scipy.sparse
+
     _have_scipy = True
 except:
     # No SciPy in environment, but that's okay
@@ -103,8 +120,8 @@ def _vector_size(v):
 
 def _format_float(f, digits=4):
     s = str(round(f, digits))
-    if '.' in s:
-        s = s[:s.index('.') + 1 + digits]
+    if "." in s:
+        s = s[: s.index(".") + 1 + digits]
     return s
 
 
@@ -114,9 +131,9 @@ def _format_float_list(l):
 
 def _double_to_long_bits(value):
     if np.isnan(value):
-        value = float('nan')
+        value = float("nan")
     # pack double into 64 bits, then unpack as long int
-    return struct.unpack('Q', struct.pack('d', value))[0]
+    return struct.unpack("Q", struct.pack("d", value))[0]
 
 
 class VectorUDT(UserDefinedType):
@@ -126,11 +143,14 @@ class VectorUDT(UserDefinedType):
 
     @classmethod
     def sqlType(cls):
-        return StructType([
-            StructField("type", ByteType(), False),
-            StructField("size", IntegerType(), True),
-            StructField("indices", ArrayType(IntegerType(), False), True),
-            StructField("values", ArrayType(DoubleType(), False), True)])
+        return StructType(
+            [
+                StructField("type", ByteType(), False),
+                StructField("size", IntegerType(), True),
+                StructField("indices", ArrayType(IntegerType(), False), True),
+                StructField("values", ArrayType(DoubleType(), False), True),
+            ]
+        )
 
     @classmethod
     def module(cls):
@@ -152,8 +172,9 @@ class VectorUDT(UserDefinedType):
             raise TypeError("cannot serialize %r of type %r" % (obj, type(obj)))
 
     def deserialize(self, datum):
-        assert len(datum) == 4, \
-            "VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
+        assert (
+            len(datum) == 4
+        ), "VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
         tpe = datum[0]
         if tpe == 0:
             return SparseVector(datum[1], datum[2], datum[3])
@@ -173,14 +194,17 @@ class MatrixUDT(UserDefinedType):
 
     @classmethod
     def sqlType(cls):
-        return StructType([
-            StructField("type", ByteType(), False),
-            StructField("numRows", IntegerType(), False),
-            StructField("numCols", IntegerType(), False),
-            StructField("colPtrs", ArrayType(IntegerType(), False), True),
-            StructField("rowIndices", ArrayType(IntegerType(), False), True),
-            StructField("values", ArrayType(DoubleType(), False), True),
-            StructField("isTransposed", BooleanType(), False)])
+        return StructType(
+            [
+                StructField("type", ByteType(), False),
+                StructField("numRows", IntegerType(), False),
+                StructField("numCols", IntegerType(), False),
+                StructField("colPtrs", ArrayType(IntegerType(), False), True),
+                StructField("rowIndices", ArrayType(IntegerType(), False), True),
+                StructField("values", ArrayType(DoubleType(), False), True),
+                StructField("isTransposed", BooleanType(), False),
+            ]
+        )
 
     @classmethod
     def module(cls):
@@ -195,18 +219,25 @@ class MatrixUDT(UserDefinedType):
             colPtrs = [int(i) for i in obj.colPtrs]
             rowIndices = [int(i) for i in obj.rowIndices]
             values = [float(v) for v in obj.values]
-            return (0, obj.numRows, obj.numCols, colPtrs,
-                    rowIndices, values, bool(obj.isTransposed))
+            return (
+                0,
+                obj.numRows,
+                obj.numCols,
+                colPtrs,
+                rowIndices,
+                values,
+                bool(obj.isTransposed),
+            )
         elif isinstance(obj, DenseMatrix):
             values = [float(v) for v in obj.values]
-            return (1, obj.numRows, obj.numCols, None, None, values,
-                    bool(obj.isTransposed))
+            return (1, obj.numRows, obj.numCols, None, None, values, bool(obj.isTransposed))
         else:
             raise TypeError("cannot serialize type %r" % (type(obj)))
 
     def deserialize(self, datum):
-        assert len(datum) == 7, \
-            "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum)
+        assert (
+            len(datum) == 7
+        ), "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum)
         tpe = datum[0]
         if tpe == 0:
             return SparseMatrix(*datum[1:])
@@ -226,6 +257,7 @@ class Vector(object):
     """
     Abstract class for DenseVector and SparseVector
     """
+
     def toArray(self):
         """
         Convert the vector into an numpy.ndarray
@@ -260,6 +292,7 @@ class DenseVector(Vector):
     >>> -v
     DenseVector([-1.0, -2.0])
     """
+
     def __init__(self, ar):
         if isinstance(ar, bytes):
             ar = np.frombuffer(ar, dtype=np.float64)
@@ -400,7 +433,7 @@ class DenseVector(Vector):
         return "[" + ",".join([str(v) for v in self.array]) + "]"
 
     def __repr__(self):
-        return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array))
+        return "DenseVector([%s])" % (", ".join(_format_float(i) for i in self.array))
 
     def __eq__(self, other):
         if isinstance(other, DenseVector):
@@ -439,6 +472,7 @@ class DenseVector(Vector):
             if isinstance(other, DenseVector):
                 other = other.array
             return DenseVector(getattr(self.array, op)(other))
+
         return func
 
     __add__ = _delegate("__add__")
@@ -460,6 +494,7 @@ class SparseVector(Vector):
     A simple sparse vector class for passing data to MLlib. Users may
     alternatively pass SciPy's {scipy.sparse} data types.
     """
+
     def __init__(self, size, *args):
         """
         Create a sparse vector, using either a dictionary, a list of
@@ -523,14 +558,17 @@ class SparseVector(Vector):
                 if self.indices[i] >= self.indices[i + 1]:
                     raise TypeError(
                         "Indices %s and %s are not strictly increasing"
-                        % (self.indices[i], self.indices[i + 1]))
+                        % (self.indices[i], self.indices[i + 1])
+                    )
 
         if self.indices.size > 0:
-            assert np.max(self.indices) < self.size, \
-                "Index %d is out of the size of vector with size=%d" \
-                % (np.max(self.indices), self.size)
-            assert np.min(self.indices) >= 0, \
-                "Contains negative index %d" % (np.min(self.indices))
+            assert (
+                np.max(self.indices) < self.size
+            ), "Index %d is out of the size of vector with size=%d" % (
+                np.max(self.indices),
+                self.size,
+            )
+            assert np.min(self.indices) >= 0, "Contains negative index %d" % (np.min(self.indices))
 
     def numNonzeros(self):
         """
@@ -553,9 +591,7 @@ class SparseVector(Vector):
         return np.linalg.norm(self.values, p)
 
     def __reduce__(self):
-        return (
-            SparseVector,
-            (self.size, self.indices.tostring(), self.values.tostring()))
+        return (SparseVector, (self.size, self.indices.tostring(), self.values.tostring()))
 
     def dot(self, other):
         """
@@ -646,8 +682,9 @@ class SparseVector(Vector):
 
         if isinstance(other, np.ndarray) or isinstance(other, DenseVector):
             if isinstance(other, np.ndarray) and other.ndim != 1:
-                raise ValueError("Cannot call squared_distance with %d-dimensional array" %
-                                 other.ndim)
+                raise ValueError(
+                    "Cannot call squared_distance with %d-dimensional array" % other.ndim
+                )
             if isinstance(other, DenseVector):
                 other = other.array
             sparse_ind = np.zeros(other.size, dtype=bool)
@@ -703,14 +740,18 @@ class SparseVector(Vector):
     def __repr__(self):
         inds = self.indices
         vals = self.values
-        entries = ", ".join(["{0}: {1}".format(inds[i], _format_float(vals[i]))
-                             for i in range(len(inds))])
+        entries = ", ".join(
+            ["{0}: {1}".format(inds[i], _format_float(vals[i])) for i in range(len(inds))]
+        )
         return "SparseVector({0}, {{{1}}})".format(self.size, entries)
 
     def __eq__(self, other):
         if isinstance(other, SparseVector):
-            return other.size == self.size and np.array_equal(other.indices, self.indices) \
+            return (
+                other.size == self.size
+                and np.array_equal(other.indices, self.indices)
                 and np.array_equal(other.values, self.values)
+            )
         elif isinstance(other, DenseVector):
             if self.size != len(other):
                 return False
@@ -721,8 +762,7 @@ class SparseVector(Vector):
         inds = self.indices
         vals = self.values
         if not isinstance(index, int):
-            raise TypeError(
-                "Indices must be of type integer, got type %s" % type(index))
+            raise TypeError("Indices must be of type integer, got type %s" % type(index))
 
         if index >= self.size or index < -self.size:
             raise IndexError("Index %d out of bounds." % index)
@@ -730,13 +770,13 @@ class SparseVector(Vector):
             index += self.size
 
         if (inds.size == 0) or (index > inds.item(-1)):
-            return 0.
+            return 0.0
 
         insert_index = np.searchsorted(inds, index)
         row_ind = inds[insert_index]
         if row_ind == index:
             return vals[insert_index]
-        return 0.
+        return 0.0
 
     def __ne__(self, other):
         return not self.__eq__(other)
@@ -872,6 +912,7 @@ class Matrix(object):
     """
     Represents a local matrix.
     """
+
     def __init__(self, numRows, numCols, isTransposed=False):
         self.numRows = numRows
         self.numCols = numCols
@@ -897,6 +938,7 @@ class DenseMatrix(Matrix):
     """
     Column-major dense matrix.
     """
+
     def __init__(self, numRows, numCols, values, isTransposed=False):
         Matrix.__init__(self, numRows, numCols, isTransposed)
         values = self._convert_to_array(values, np.float64)
@@ -905,8 +947,11 @@ class DenseMatrix(Matrix):
 
     def __reduce__(self):
         return DenseMatrix, (
-            self.numRows, self.numCols, self.values.tostring(),
-            int(self.isTransposed))
+            self.numRows,
+            self.numCols,
+            self.values.tostring(),
+            int(self.isTransposed),
+        )
 
     def __str__(self):
         """
@@ -928,7 +973,7 @@ class DenseMatrix(Matrix):
 
         # We need to adjust six spaces which is the difference in number
         # of letters between "DenseMatrix" and "array"
-        x = '\n'.join([(" " * 6 + line) for line in array_lines[1:]])
+        x = "\n".join([(" " * 6 + line) for line in array_lines[1:]])
         return array_lines[0].replace("array", "DenseMatrix") + "\n" + x
 
     def __repr__(self):
@@ -947,14 +992,13 @@ class DenseMatrix(Matrix):
             entries = _format_float_list(self.values)
         else:
             entries = (
-                _format_float_list(self.values[:8]) +
-                ["..."] +
-                _format_float_list(self.values[-8:])
+                _format_float_list(self.values[:8]) + ["..."] + _format_float_list(self.values[-8:])
             )
 
         entries = ", ".join(entries)
         return "DenseMatrix({0}, {1}, [{2}], {3})".format(
-            self.numRows, self.numCols, entries, self.isTransposed)
+            self.numRows, self.numCols, entries, self.isTransposed
+        )
 
     def toArray(self):
         """
@@ -968,21 +1012,19 @@ class DenseMatrix(Matrix):
                [ 1.,  3.]])
         """
         if self.isTransposed:
-            return np.asfortranarray(
-                self.values.reshape((self.numRows, self.numCols)))
+            return np.asfortranarray(self.values.reshape((self.numRows, self.numCols)))
         else:
-            return self.values.reshape((self.numRows, self.numCols), order='F')
+            return self.values.reshape((self.numRows, self.numCols), order="F")
 
     def toSparse(self):
         """Convert to SparseMatrix"""
         if self.isTransposed:
-            values = np.ravel(self.toArray(), order='F')
+            values = np.ravel(self.toArray(), order="F")
         else:
             values = self.values
         indices = np.nonzero(values)[0]
         colCounts = np.bincount(indices // self.numRows)
-        colPtrs = np.cumsum(np.hstack(
-            (0, colCounts, np.zeros(self.numCols - colCounts.size))))
+        colPtrs = np.cumsum(np.hstack((0, colCounts, np.zeros(self.numCols - colCounts.size))))
         values = values[indices]
         rowIndices = indices % self.numRows
 
@@ -991,11 +1033,9 @@ class DenseMatrix(Matrix):
     def __getitem__(self, indices):
         i, j = indices
         if i < 0 or i >= self.numRows:
-            raise IndexError("Row index %d is out of range [0, %d)"
-                             % (i, self.numRows))
+            raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows))
         if j >= self.numCols or j < 0:
-            raise IndexError("Column index %d is out of range [0, %d)"
-                             % (j, self.numCols))
+            raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols))
 
         if self.isTransposed:
             return self.values[i * self.numCols + j]
@@ -1003,20 +1043,20 @@ class DenseMatrix(Matrix):
             return self.values[i + j * self.numRows]
 
     def __eq__(self, other):
-        if (self.numRows != other.numRows or self.numCols != other.numCols):
+        if self.numRows != other.numRows or self.numCols != other.numCols:
             return False
         if isinstance(other, SparseMatrix):
             return np.all(self.toArray() == other.toArray())
 
-        self_values = np.ravel(self.toArray(), order='F')
-        other_values = np.ravel(other.toArray(), order='F')
+        self_values = np.ravel(self.toArray(), order="F")
+        other_values = np.ravel(other.toArray(), order="F")
         return np.all(self_values == other_values)
 
 
 class SparseMatrix(Matrix):
     """Sparse Matrix stored in CSC format."""
-    def __init__(self, numRows, numCols, colPtrs, rowIndices, values,
-                 isTransposed=False):
+
+    def __init__(self, numRows, numCols, colPtrs, rowIndices, values, isTransposed=False):
         Matrix.__init__(self, numRows, numCols, isTransposed)
         self.colPtrs = self._convert_to_array(colPtrs, np.int32)
         self.rowIndices = self._convert_to_array(rowIndices, np.int32)
@@ -1024,15 +1064,19 @@ class SparseMatrix(Matrix):
 
         if self.isTransposed:
             if self.colPtrs.size != numRows + 1:
-                raise ValueError("Expected colPtrs of size %d, got %d."
-                                 % (numRows + 1, self.colPtrs.size))
+                raise ValueError(
+                    "Expected colPtrs of size %d, got %d." % (numRows + 1, self.colPtrs.size)
+                )
         else:
             if self.colPtrs.size != numCols + 1:
-                raise ValueError("Expected colPtrs of size %d, got %d."
-                                 % (numCols + 1, self.colPtrs.size))
+                raise ValueError(
+                    "Expected colPtrs of size %d, got %d." % (numCols + 1, self.colPtrs.size)
+                )
         if self.rowIndices.size != self.values.size:
-            raise ValueError("Expected rowIndices of length %d, got %d."
-                             % (self.rowIndices.size, self.values.size))
+            raise ValueError(
+                "Expected rowIndices of length %d, got %d."
+                % (self.rowIndices.size, self.values.size)
+            )
 
     def __str__(self):
         """
@@ -1071,11 +1115,9 @@ class SparseMatrix(Matrix):
             if self.colPtrs[cur_col + 1] <= i:
                 cur_col += 1
             if self.isTransposed:
-                smlist.append('({0},{1}) {2}'.format(
-                    cur_col, rowInd, _format_float(value)))
+                smlist.append("({0},{1}) {2}".format(cur_col, rowInd, _format_float(value)))
             else:
-                smlist.append('({0},{1}) {2}'.format(
-                    rowInd, cur_col, _format_float(value)))
+                smlist.append("({0},{1}) {2}".format(rowInd, cur_col, _format_float(value)))
         spstr += "\n".join(smlist)
 
         if len(self.values) > 16:
@@ -1100,9 +1142,7 @@ class SparseMatrix(Matrix):
 
         else:
             values = (
-                _format_float_list(self.values[:8]) +
-                ["..."] +
-                _format_float_list(self.values[-8:])
+                _format_float_list(self.values[:8]) + ["..."] + _format_float_list(self.values[-8:])
             )
             rowIndices = rowIndices[:8] + ["..."] + rowIndices[-8:]
 
@@ -1113,23 +1153,25 @@ class SparseMatrix(Matrix):
         rowIndices = ", ".join([str(ind) for ind in rowIndices])
         colPtrs = ", ".join([str(ptr) for ptr in colPtrs])
         return "SparseMatrix({0}, {1}, [{2}], [{3}], [{4}], {5})".format(
-            self.numRows, self.numCols, colPtrs, rowIndices,
-            values, self.isTransposed)
+            self.numRows, self.numCols, colPtrs, rowIndices, values, self.isTransposed
+        )
 
     def __reduce__(self):
         return SparseMatrix, (
-            self.numRows, self.numCols, self.colPtrs.tostring(),
-            self.rowIndices.tostring(), self.values.tostring(),
-            int(self.isTransposed))
+            self.numRows,
+            self.numCols,
+            self.colPtrs.tostring(),
+            self.rowIndices.tostring(),
+            self.values.tostring(),
+            int(self.isTransposed),
+        )
 
     def __getitem__(self, indices):
         i, j = indices
         if i < 0 or i >= self.numRows:
-            raise IndexError("Row index %d is out of range [0, %d)"
-                             % (i, self.numRows))
+            raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows))
         if j < 0 or j >= self.numCols:
-            raise IndexError("Column index %d is out of range [0, %d)"
-                             % (j, self.numCols))
+            raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols))
 
         # If a CSR matrix is given, then the row index should be searched
         # for in ColPtrs, and the column index should be searched for in the
@@ -1139,7 +1181,7 @@ class SparseMatrix(Matrix):
 
         colStart = self.colPtrs[j]
         colEnd = self.colPtrs[j + 1]
-        nz = self.rowIndices[colStart: colEnd]
+        nz = self.rowIndices[colStart:colEnd]
         ind = np.searchsorted(nz, i) + colStart
         if ind < colEnd and self.rowIndices[ind] == i:
             return self.values[ind]
@@ -1150,7 +1192,7 @@ class SparseMatrix(Matrix):
         """
         Return a numpy.ndarray
         """
-        A = np.zeros((self.numRows, self.numCols), dtype=np.float64, order='F')
+        A = np.zeros((self.numRows, self.numCols), dtype=np.float64, order="F")
         for k in range(self.colPtrs.size - 1):
             startptr = self.colPtrs[k]
             endptr = self.colPtrs[k + 1]
@@ -1161,7 +1203,7 @@ class SparseMatrix(Matrix):
         return A
 
     def toDense(self):
-        densevals = np.ravel(self.toArray(), order='F')
+        densevals = np.ravel(self.toArray(), order="F")
         return DenseMatrix(self.numRows, self.numCols, densevals)
 
     # TODO: More efficient implementation:
@@ -1187,14 +1229,16 @@ class Matrices(object):
 
 def _test():
     import doctest
+
     try:
         # Numpy 1.14+ changed it's string format.
-        np.set_printoptions(legacy='1.13')
+        np.set_printoptions(legacy="1.13")
     except TypeError:
         pass
     (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
     if failure_count:
         sys.exit(-1)
 
+
 if __name__ == "__main__":
     _test()
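
The SparseMatrix hunks above only reflow the CSC lookup; the logic is unchanged. For readers who find the intent easier to follow outside a diff, here is a minimal pure-NumPy sketch of that lookup (illustrative names, not PySpark API): colPtrs[j] and colPtrs[j + 1] bound the stored entries of column j, and a binary search over the row indices in that slice either finds the entry or falls back to an implicit zero.

```python
import numpy as np

# Illustrative stand-alone version of the lookup in SparseMatrix.__getitem__;
# csc_get and its arguments mirror the names in the diff, they are not public API.
def csc_get(colPtrs, rowIndices, values, i, j):
    """Return entry (i, j) of a CSC-stored matrix, or 0.0 if it is not stored."""
    colStart = colPtrs[j]        # first stored entry of column j
    colEnd = colPtrs[j + 1]      # one past the last stored entry of column j
    nz = rowIndices[colStart:colEnd]
    ind = np.searchsorted(nz, i) + colStart
    if ind < colEnd and rowIndices[ind] == i:
        return values[ind]
    return 0.0

# CSC encoding of [[1, 0, 4], [0, 0, 5], [2, 3, 6]]
colPtrs = np.array([0, 2, 3, 6])
rowIndices = np.array([0, 2, 2, 0, 1, 2])
values = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

assert csc_get(colPtrs, rowIndices, values, 1, 2) == 5.0  # stored entry
assert csc_get(colPtrs, rowIndices, values, 1, 0) == 0.0  # implicit zero
```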
diff --git a/python/pyspark/ml/linalg/__init__.pyi b/python/pyspark/ml/linalg/__init__.pyi
index 46bd812..bb09397 100644
--- a/python/pyspark/ml/linalg/__init__.pyi
+++ b/python/pyspark/ml/linalg/__init__.pyi
@@ -46,9 +46,7 @@ class MatrixUDT(UserDefinedType):
     def scalaUDT(cls) -> str: ...
     def serialize(
         self, obj: Matrix
-    ) -> Tuple[
-        int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool
-    ]: ...
+    ) -> Tuple[int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool]: ...
     def deserialize(self, datum: Any) -> Matrix: ...
     def simpleString(self) -> str: ...
 
@@ -101,9 +99,7 @@ class SparseVector(Vector):
     @overload
     def __init__(self, size: int, __indices: bytes, __values: bytes) -> None: ...
     @overload
-    def __init__(
-        self, size: int, __indices: Iterable[int], __values: Iterable[float]
-    ) -> None: ...
+    def __init__(self, size: int, __indices: Iterable[int], __values: Iterable[float]) -> None: ...
     @overload
     def __init__(self, size: int, __pairs: Iterable[Tuple[int, float]]) -> None: ...
     @overload
@@ -129,9 +125,7 @@ class Vectors:
     def sparse(size: int, __indices: bytes, __values: bytes) -> SparseVector: ...
     @overload
     @staticmethod
-    def sparse(
-        size: int, __indices: Iterable[int], __values: Iterable[float]
-    ) -> SparseVector: ...
+    def sparse(size: int, __indices: Iterable[int], __values: Iterable[float]) -> SparseVector: ...
     @overload
     @staticmethod
     def sparse(size: int, __pairs: Iterable[Tuple[int, float]]) -> SparseVector: ...
@@ -161,9 +155,7 @@ class Matrix:
     numRows: int
     numCols: int
     isTransposed: bool
-    def __init__(
-        self, numRows: int, numCols: int, isTransposed: bool = ...
-    ) -> None: ...
+    def __init__(self, numRows: int, numCols: int, isTransposed: bool = ...) -> None: ...
     def toArray(self) -> ndarray: ...
 
 class DenseMatrix(Matrix):
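
The __init__.pyi hunks above are signature reflows only, but they make the Vectors.sparse overloads easier to read. A short usage sketch of the two list-based overloads shown there (assuming a working PySpark installation; runtime behaviour is untouched by this commit):

```python
from pyspark.ml.linalg import Vectors

# Separate index and value sequences, and (index, value) pairs: both overloads
# appear in the stub above and should build the same size-4 SparseVector.
v1 = Vectors.sparse(4, [1, 3], [3.0, 4.0])
v2 = Vectors.sparse(4, [(1, 3.0), (3, 4.0)])
assert v1 == v2
```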
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index ab3491c..e011c50 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -25,7 +25,7 @@ from pyspark.ml.linalg import DenseVector, Vector, Matrix
 from pyspark.ml.util import Identifiable
 
 
-__all__ = ['Param', 'Params', 'TypeConverters']
+__all__ = ["Param", "Params", "TypeConverters"]
 
 
 class Param(object):
@@ -78,7 +78,7 @@ class TypeConverters(object):
     @staticmethod
     def _is_numeric(value):
         vtype = type(value)
-        return vtype in [int, float, np.float64, np.int64] or vtype.__name__ == 'long'
+        return vtype in [int, float, np.float64, np.int64] or vtype.__name__ == "long"
 
     @staticmethod
     def _is_integer(value):
@@ -263,9 +263,16 @@ class Params(Identifiable, metaclass=ABCMeta):
         :py:class:`Param`.
         """
         if self._params is None:
-            self._params = list(filter(lambda attr: isinstance(attr, Param),
-                                       [getattr(self, x) for x in dir(self) if x != "params" and
-                                        not isinstance(getattr(type(self), x, None), property)]))
+            self._params = list(
+                filter(
+                    lambda attr: isinstance(attr, Param),
+                    [
+                        getattr(self, x)
+                        for x in dir(self)
+                        if x != "params" and not isinstance(getattr(type(self), x, None), property)
+                    ],
+                )
+            )
         return self._params
 
     def explainParam(self, param):
@@ -484,8 +491,9 @@ class Params(Identifiable, metaclass=ABCMeta):
                 try:
                     value = p.typeConverter(value)
                 except TypeError as e:
-                    raise TypeError('Invalid default param value given for param "%s". %s'
-                                    % (p.name, e))
+                    raise TypeError(
+                        'Invalid default param value given for param "%s". %s' % (p.name, e)
+                    )
             self._defaultParamMap[p] = value
         return self
 
@@ -512,11 +520,13 @@ class Params(Identifiable, metaclass=ABCMeta):
                 if isinstance(param, Param):
                     paramMap[param] = value
                 else:
-                    raise TypeError("Expecting a valid instance of Param, but received: {}"
-                                    .format(param))
+                    raise TypeError(
+                        "Expecting a valid instance of Param, but received: {}".format(param)
+                    )
         elif extra is not None:
-            raise TypeError("Expecting a dict, but received an object of type {}."
-                            .format(type(extra)))
+            raise TypeError(
+                "Expecting a dict, but received an object of type {}.".format(type(extra))
+            )
         for param in self.params:
             # copy default params
             if param in self._defaultParamMap and to.hasParam(param.name):
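
The pyspark/ml/param/__init__.py changes above are likewise formatting-only; in particular, the reflowed `params` property still just reflects over the instance to collect every Param attribute. A quick usage sketch of that machinery (LogisticRegression is only an example here; any Params subclass behaves the same way):

```python
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression()

# `params` is the reflective property reformatted above: all Param attributes.
print([p.name for p in lr.params])

# explainParam renders the same metadata: name, doc, default and current value.
print(lr.explainParam("maxIter"))
```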
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index bcab51f..f8b3d1f 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -54,18 +54,19 @@ def _gen_param_header(name, doc, defaultValueStr, typeConverter):
         super(Has$Name, self).__init__()'''
 
     if defaultValueStr is not None:
-        template += '''
-        self._setDefault($name=$defaultValueStr)'''
+        template += """
+        self._setDefault($name=$defaultValueStr)"""
 
     Name = name[0].upper() + name[1:]
     if typeConverter is None:
         typeConverter = str(None)
-    return template \
-        .replace("$name", name) \
-        .replace("$Name", Name) \
-        .replace("$doc", doc) \
-        .replace("$defaultValueStr", str(defaultValueStr)) \
+    return (
+        template.replace("$name", name)
+        .replace("$Name", Name)
+        .replace("$doc", doc)
+        .replace("$defaultValueStr", str(defaultValueStr))
         .replace("$typeConverter", typeConverter)
+    )
 
 
 def _gen_param_code(name, doc, defaultValueStr):
@@ -86,11 +87,13 @@ def _gen_param_code(name, doc, defaultValueStr):
         return self.getOrDefault(self.$name)'''
 
     Name = name[0].upper() + name[1:]
-    return template \
-        .replace("$name", name) \
-        .replace("$Name", Name) \
-        .replace("$doc", doc) \
+    return (
+        template.replace("$name", name)
+        .replace("$Name", Name)
+        .replace("$doc", doc)
         .replace("$defaultValueStr", str(defaultValueStr))
+    )
+
 
 if __name__ == "__main__":
     print(header)
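
In _gen_param_code the backslash-continued chain of str.replace calls becomes a parenthesized chain; the generated source stays the same. A toy reproduction of that substitution, using a deliberately simplified template rather than the real one from this module:

```python
# Simplified stand-in for the module's template; only the getter fragment
# visible in the hunk above is reproduced here.
template = '''
    def get$Name(self):
        return self.getOrDefault(self.$name)'''

name = "maxIter"
generated = (
    template.replace("$name", name)
    .replace("$Name", name[0].upper() + name[1:])
)
print(generated)  # prints a getMaxIter method body
```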
@@ -102,74 +105,169 @@ if __name__ == "__main__":
         ("featuresCol", "features column name.", "'features'", "TypeConverters.toString"),
         ("labelCol", "label column name.", "'label'", "TypeConverters.toString"),
         ("predictionCol", "prediction column name.", "'prediction'", "TypeConverters.toString"),
-        ("probabilityCol", "Column name for predicted class conditional probabilities. " +
-         "Note: Not all models output well-calibrated probability estimates! These probabilities " +
-         "should be treated as confidences, not precise probabilities.", "'probability'",
-         "TypeConverters.toString"),
... 39109 lines suppressed ...
