Posted to commits@spark.apache.org by gu...@apache.org on 2021/11/10 09:40:30 UTC
[spark] branch master updated: [SPARK-37022][PYTHON] Use black as a formatter for PySpark
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2fe9af8 [SPARK-37022][PYTHON] Use black as a formatter for PySpark
2fe9af8 is described below
commit 2fe9af8b2b91d0a46782dd6fff57eca8609be105
Author: zero323 <ms...@gmail.com>
AuthorDate: Wed Nov 10 18:39:06 2021 +0900
[SPARK-37022][PYTHON] Use black as a formatter for PySpark
### What changes were proposed in this pull request?
This PR applies `black` (21.5b2) formatting to the whole `python/pyspark` source tree.
Additionally, the following changes were made:
- Disabled E501 (line too long) in pycodestyle config ‒ black allows lines to exceed `line-length` when they end in inline comments. There are 15 such cases, all listed below (see the short sketch after the list)
```
pycodestyle checks failed:
./python/pyspark/sql/catalog.py:349:101: E501 line too long (103 > 100 characters)
./python/pyspark/sql/session.py:652:101: E501 line too long (106 > 100 characters)
./python/pyspark/sql/utils.py:50:101: E501 line too long (108 > 100 characters)
./python/pyspark/sql/streaming.py:1063:101: E501 line too long (128 > 100 characters)
./python/pyspark/sql/streaming.py:1071:101: E501 line too long (112 > 100 characters)
./python/pyspark/sql/streaming.py:1080:101: E501 line too long (124 > 100 characters)
./python/pyspark/sql/streaming.py:1259:101: E501 line too long (134 > 100 characters)
./python/pyspark/sql/pandas/conversion.py:136:101: E501 line too long (106 > 100 characters)
./python/pyspark/ml/param/_shared_params_code_gen.py:111:101: E501 line too long (103 > 100 characters)
./python/pyspark/ml/param/_shared_params_code_gen.py:136:101: E501 line too long (105 > 100 characters)
./python/pyspark/ml/param/_shared_params_code_gen.py:163:101: E501 line too long (101 > 100 characters)
./python/pyspark/ml/param/_shared_params_code_gen.py:233:101: E501 line too long (101 > 100 characters)
./python/pyspark/ml/param/_shared_params_code_gen.py:265:101: E501 line too long (101 > 100 characters)
./python/pyspark/tests/test_readwrite.py:235:101: E501 line too long (114 > 100 characters)
./python/pyspark/tests/test_readwrite.py:336:101: E501 line too long (114 > 100 characters)
```
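For context, here is a minimal hypothetical sketch (not taken from the Spark sources) of the pattern behind these E501 hits: black reformats the code itself to fit within `line-length`, but it never splits or relocates a trailing comment, so the physical line can still exceed 100 characters. That is why E501 is now delegated to black and ignored by pycodestyle.
```python
# Hypothetical illustration; the helper and the comment below are placeholders.
def configured(option_name, default_value):
    # Stand-in for a real configuration lookup.
    return default_value

# black keeps the call itself within line-length, but the trailing comment is left
# untouched, so the full physical line still exceeds the 100-character limit (E501).
timeout = configured("spark.network.timeout", 120)  # long explanatory inline comment that pushes this physical line well past one hundred characters
```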
- After reformatting, minor typing changes were made:
- Realign certain `type: ignore` comments with the code they apply to.
- Apply explicit `cast` calls to `unittest.skipIf` messages (a short illustration of why follows the snippets below). The following
```python
unittest.skipIf(
not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message
) # type: ignore[arg-type]
```
was replaced with
```python
unittest.skipIf(
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
```
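As a rough illustration of why the `cast` is needed, assuming stand-in definitions for the real `pyspark.testing` flags and messages: the requirement messages are `Optional[str]` (they are `None` when the dependency is available), while `unittest.skipIf` declares its `reason` parameter as `str`, so the un-cast form fails type checking once the relocated `type: ignore` comment is dropped.
```python
import unittest
from typing import Optional, cast

# Hypothetical stand-ins for the real pyspark.testing helpers referenced above.
have_pandas: bool = False
pandas_requirement_message: Optional[str] = None if have_pandas else "pandas is not installed"


# The message is Optional[str], but skipIf's `reason` parameter is typed as str;
# cast() satisfies the type checker and is a no-op at runtime.
@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message))
class PandasDependentTest(unittest.TestCase):
    def test_placeholder(self) -> None:
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```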
### Why are the changes needed?
Consistency and reduced maintenance overhead.
### Does this PR introduce _any_ user-facing change?
Yes.
### How was this patch tested?
Existing linters and tests.
Closes #34297 from zero323/SPARK-37022.
Authored-by: zero323 <ms...@gmail.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
dev/lint-python | 2 +-
dev/pyproject.toml | 6 +
dev/reformat-python | 3 +-
dev/tox.ini | 2 +-
python/pyspark/__init__.py | 41 +-
python/pyspark/_globals.py | 7 +-
python/pyspark/accumulators.py | 20 +-
python/pyspark/accumulators.pyi | 4 +-
python/pyspark/broadcast.py | 20 +-
python/pyspark/conf.py | 22 +-
python/pyspark/context.py | 309 ++--
python/pyspark/daemon.py | 13 +-
python/pyspark/files.py | 5 +-
python/pyspark/find_spark_home.py | 15 +-
python/pyspark/install.py | 38 +-
python/pyspark/java_gateway.py | 31 +-
python/pyspark/join.py | 5 +
python/pyspark/ml/__init__.py | 52 +-
python/pyspark/ml/_typing.pyi | 4 +-
python/pyspark/ml/base.py | 34 +-
python/pyspark/ml/base.pyi | 8 +-
python/pyspark/ml/classification.py | 1259 ++++++++++-----
python/pyspark/ml/classification.pyi | 67 +-
python/pyspark/ml/clustering.py | 535 +++++--
python/pyspark/ml/clustering.pyi | 28 +-
python/pyspark/ml/common.py | 29 +-
python/pyspark/ml/evaluation.py | 394 +++--
python/pyspark/ml/evaluation.pyi | 36 +-
python/pyspark/ml/feature.py | 1633 ++++++++++++++------
python/pyspark/ml/feature.pyi | 228 +--
python/pyspark/ml/fpm.py | 149 +-
python/pyspark/ml/fpm.pyi | 12 +-
python/pyspark/ml/functions.py | 22 +-
python/pyspark/ml/functions.pyi | 1 -
python/pyspark/ml/image.py | 34 +-
python/pyspark/ml/linalg/__init__.py | 240 +--
python/pyspark/ml/linalg/__init__.pyi | 16 +-
python/pyspark/ml/param/__init__.py | 32 +-
python/pyspark/ml/param/_shared_params_code_gen.py | 242 ++-
python/pyspark/ml/param/shared.py | 238 ++-
python/pyspark/ml/pipeline.py | 52 +-
python/pyspark/ml/pipeline.pyi | 4 +-
python/pyspark/ml/recommendation.py | 206 ++-
python/pyspark/ml/recommendation.pyi | 16 +-
python/pyspark/ml/regression.py | 956 +++++++++---
python/pyspark/ml/regression.pyi | 44 +-
python/pyspark/ml/stat.py | 35 +-
python/pyspark/ml/stat.pyi | 8 +-
python/pyspark/ml/tests/test_algorithms.py | 245 +--
python/pyspark/ml/tests/test_base.py | 26 +-
python/pyspark/ml/tests/test_evaluation.py | 20 +-
python/pyspark/ml/tests/test_feature.py | 254 ++-
python/pyspark/ml/tests/test_image.py | 32 +-
python/pyspark/ml/tests/test_linalg.py | 145 +-
python/pyspark/ml/tests/test_param.py | 129 +-
python/pyspark/ml/tests/test_persistence.py | 214 ++-
python/pyspark/ml/tests/test_pipeline.py | 10 +-
python/pyspark/ml/tests/test_stat.py | 16 +-
python/pyspark/ml/tests/test_training_summary.py | 152 +-
python/pyspark/ml/tests/test_tuning.py | 650 ++++----
python/pyspark/ml/tests/test_util.py | 43 +-
python/pyspark/ml/tests/test_wrapper.py | 33 +-
python/pyspark/ml/tree.py | 229 ++-
python/pyspark/ml/tree.pyi | 4 +-
python/pyspark/ml/tuning.py | 475 +++---
python/pyspark/ml/tuning.pyi | 12 +-
python/pyspark/ml/util.py | 95 +-
python/pyspark/ml/wrapper.py | 13 +-
python/pyspark/ml/wrapper.pyi | 4 +-
python/pyspark/mllib/__init__.py | 17 +-
python/pyspark/mllib/classification.py | 226 ++-
python/pyspark/mllib/clustering.py | 233 ++-
python/pyspark/mllib/clustering.pyi | 20 +-
python/pyspark/mllib/common.py | 36 +-
python/pyspark/mllib/evaluation.py | 125 +-
python/pyspark/mllib/evaluation.pyi | 4 +-
python/pyspark/mllib/feature.py | 137 +-
python/pyspark/mllib/feature.pyi | 4 +-
python/pyspark/mllib/fpm.py | 18 +-
python/pyspark/mllib/fpm.pyi | 4 +-
python/pyspark/mllib/linalg/__init__.py | 289 ++--
python/pyspark/mllib/linalg/__init__.pyi | 16 +-
python/pyspark/mllib/linalg/distributed.py | 130 +-
python/pyspark/mllib/linalg/distributed.pyi | 12 +-
python/pyspark/mllib/random.py | 57 +-
python/pyspark/mllib/recommendation.py | 65 +-
python/pyspark/mllib/recommendation.pyi | 8 +-
python/pyspark/mllib/regression.py | 194 ++-
python/pyspark/mllib/regression.pyi | 12 +-
python/pyspark/mllib/stat/KernelDensity.py | 4 +-
python/pyspark/mllib/stat/__init__.py | 11 +-
python/pyspark/mllib/stat/_statistics.py | 18 +-
python/pyspark/mllib/stat/_statistics.pyi | 12 +-
python/pyspark/mllib/stat/distribution.py | 3 +-
python/pyspark/mllib/tests/test_algorithms.py | 144 +-
python/pyspark/mllib/tests/test_feature.py | 70 +-
python/pyspark/mllib/tests/test_linalg.py | 185 ++-
python/pyspark/mllib/tests/test_stat.py | 45 +-
.../mllib/tests/test_streaming_algorithms.py | 117 +-
python/pyspark/mllib/tests/test_util.py | 24 +-
python/pyspark/mllib/tree.py | 275 +++-
python/pyspark/mllib/tree.pyi | 4 +-
python/pyspark/mllib/util.py | 42 +-
python/pyspark/pandas/base.py | 14 +-
python/pyspark/pandas/config.py | 2 +-
python/pyspark/pandas/data_type_ops/base.py | 4 +-
.../pandas/data_type_ops/categorical_ops.py | 2 +-
python/pyspark/pandas/frame.py | 12 +-
python/pyspark/pandas/generic.py | 6 +-
python/pyspark/pandas/groupby.py | 2 +-
python/pyspark/pandas/indexes/base.py | 4 +-
python/pyspark/pandas/namespace.py | 24 +-
python/pyspark/pandas/plot/matplotlib.py | 6 +-
python/pyspark/pandas/series.py | 4 +-
python/pyspark/pandas/sql_processor.py | 2 +-
python/pyspark/pandas/utils.py | 6 +-
python/pyspark/profiler.py | 20 +-
python/pyspark/profiler.pyi | 4 +-
python/pyspark/rdd.py | 342 ++--
python/pyspark/rdd.pyi | 40 +-
python/pyspark/rddsampler.py | 4 -
python/pyspark/resource/__init__.py | 16 +-
python/pyspark/resource/profile.py | 70 +-
python/pyspark/resource/requests.py | 65 +-
python/pyspark/resource/tests/test_resources.py | 7 +-
python/pyspark/serializers.py | 65 +-
python/pyspark/shell.py | 20 +-
python/pyspark/shuffle.py | 118 +-
python/pyspark/sql/__init__.py | 21 +-
python/pyspark/sql/avro/__init__.py | 2 +-
python/pyspark/sql/avro/functions.py | 26 +-
python/pyspark/sql/catalog.py | 106 +-
python/pyspark/sql/column.py | 122 +-
python/pyspark/sql/conf.py | 15 +-
python/pyspark/sql/context.py | 120 +-
python/pyspark/sql/dataframe.py | 338 ++--
python/pyspark/sql/functions.py | 235 ++-
python/pyspark/sql/group.py | 58 +-
python/pyspark/sql/observation.py | 14 +-
python/pyspark/sql/pandas/_typing/__init__.pyi | 8 +-
.../pyspark/sql/pandas/_typing/protocols/frame.pyi | 38 +-
.../sql/pandas/_typing/protocols/series.pyi | 38 +-
python/pyspark/sql/pandas/conversion.py | 192 ++-
python/pyspark/sql/pandas/functions.py | 69 +-
python/pyspark/sql/pandas/functions.pyi | 16 +-
python/pyspark/sql/pandas/group_ops.py | 39 +-
python/pyspark/sql/pandas/map_ops.py | 19 +-
python/pyspark/sql/pandas/serializers.py | 81 +-
python/pyspark/sql/pandas/typehints.py | 123 +-
python/pyspark/sql/pandas/types.py | 94 +-
python/pyspark/sql/pandas/utils.py | 39 +-
python/pyspark/sql/readwriter.py | 260 ++--
python/pyspark/sql/session.py | 192 ++-
python/pyspark/sql/streaming.py | 191 ++-
python/pyspark/sql/tests/test_arrow.py | 258 +++-
python/pyspark/sql/tests/test_catalog.py | 227 +--
python/pyspark/sql/tests/test_column.py | 111 +-
python/pyspark/sql/tests/test_conf.py | 4 +-
python/pyspark/sql/tests/test_context.py | 75 +-
python/pyspark/sql/tests/test_dataframe.py | 659 +++++---
python/pyspark/sql/tests/test_datasources.py | 147 +-
python/pyspark/sql/tests/test_functions.py | 482 +++---
python/pyspark/sql/tests/test_group.py | 12 +-
.../pyspark/sql/tests/test_pandas_cogrouped_map.py | 215 +--
.../pyspark/sql/tests/test_pandas_grouped_map.py | 567 ++++---
python/pyspark/sql/tests/test_pandas_map.py | 47 +-
python/pyspark/sql/tests/test_pandas_udf.py | 136 +-
.../sql/tests/test_pandas_udf_grouped_agg.py | 462 +++---
python/pyspark/sql/tests/test_pandas_udf_scalar.py | 860 ++++++-----
.../pyspark/sql/tests/test_pandas_udf_typehints.py | 141 +-
python/pyspark/sql/tests/test_pandas_udf_window.py | 244 +--
python/pyspark/sql/tests/test_readwriter.py | 66 +-
python/pyspark/sql/tests/test_serde.py | 28 +-
python/pyspark/sql/tests/test_session.py | 111 +-
python/pyspark/sql/tests/test_streaming.py | 200 +--
python/pyspark/sql/tests/test_types.py | 690 +++++----
python/pyspark/sql/tests/test_udf.py | 281 ++--
python/pyspark/sql/tests/test_utils.py | 33 +-
python/pyspark/sql/types.py | 474 +++---
python/pyspark/sql/udf.py | 125 +-
python/pyspark/sql/utils.py | 84 +-
python/pyspark/sql/window.py | 11 +-
python/pyspark/statcounter.py | 30 +-
python/pyspark/status.py | 3 +
python/pyspark/storagelevel.py | 8 +-
python/pyspark/streaming/__init__.py | 2 +-
python/pyspark/streaming/context.py | 33 +-
python/pyspark/streaming/context.pyi | 8 +-
python/pyspark/streaming/dstream.py | 99 +-
python/pyspark/streaming/dstream.pyi | 16 +-
python/pyspark/streaming/kinesis.py | 18 +-
python/pyspark/streaming/listener.py | 1 -
python/pyspark/streaming/tests/test_context.py | 13 +-
python/pyspark/streaming/tests/test_dstream.py | 241 +--
python/pyspark/streaming/tests/test_kinesis.py | 57 +-
python/pyspark/streaming/tests/test_listener.py | 7 +-
python/pyspark/streaming/util.py | 18 +-
python/pyspark/taskcontext.py | 25 +-
python/pyspark/testing/mllibutils.py | 2 +-
python/pyspark/testing/mlutils.py | 85 +-
python/pyspark/testing/pandasutils.py | 22 +-
python/pyspark/testing/sqlutils.py | 16 +-
python/pyspark/testing/streamingutils.py | 18 +-
python/pyspark/testing/utils.py | 22 +-
python/pyspark/tests/test_appsubmit.py | 168 +-
python/pyspark/tests/test_broadcast.py | 8 +-
python/pyspark/tests/test_conf.py | 3 +-
python/pyspark/tests/test_context.py | 66 +-
python/pyspark/tests/test_daemon.py | 7 +-
python/pyspark/tests/test_install_spark.py | 59 +-
python/pyspark/tests/test_join.py | 4 +-
python/pyspark/tests/test_pin_thread.py | 15 +-
python/pyspark/tests/test_profiler.py | 12 +-
python/pyspark/tests/test_rdd.py | 172 ++-
python/pyspark/tests/test_rddbarrier.py | 4 +-
python/pyspark/tests/test_readwrite.py | 286 ++--
python/pyspark/tests/test_serializers.py | 86 +-
python/pyspark/tests/test_shuffle.py | 40 +-
python/pyspark/tests/test_taskcontext.py | 37 +-
python/pyspark/tests/test_util.py | 6 +-
python/pyspark/tests/test_worker.py | 31 +-
python/pyspark/traceback_utils.py | 5 +-
python/pyspark/util.py | 56 +-
python/pyspark/worker.py | 193 ++-
224 files changed, 15237 insertions(+), 9349 deletions(-)
diff --git a/dev/lint-python b/dev/lint-python
index d77d645..30b1816 100755
--- a/dev/lint-python
+++ b/dev/lint-python
@@ -256,7 +256,7 @@ function black_test {
echo "starting black test..."
# Black is only applied for pandas API on Spark for now.
- BLACK_REPORT=$( ($BLACK_BUILD python/pyspark/pandas --line-length 100 --check ) 2>&1)
+ BLACK_REPORT=$( ($BLACK_BUILD --config dev/pyproject.toml --check python/pyspark) 2>&1)
BLACK_STATUS=$?
if [ "$BLACK_STATUS" -ne 0 ]; then
diff --git a/dev/pyproject.toml b/dev/pyproject.toml
index 286b728..b90257f 100644
--- a/dev/pyproject.toml
+++ b/dev/pyproject.toml
@@ -23,3 +23,9 @@ testpaths = [
"pyspark/sql/tests/typing",
"pyspark/ml/typing",
]
+
+[tool.black]
+line-length = 100
+target-version = ['py37']
+include = '\.pyi?$'
+extend-exclude = 'cloudpickle'
diff --git a/dev/reformat-python b/dev/reformat-python
index 1b5ee65..0b712d5 100755
--- a/dev/reformat-python
+++ b/dev/reformat-python
@@ -28,5 +28,4 @@ if [ $? -ne 0 ]; then
exit 1
fi
-# This script is only applied for pandas API on Spark for now.
-$BLACK_BUILD python/pyspark/pandas --line-length 100
+$BLACK_BUILD --config python/pyproject.toml python/pyspark
diff --git a/dev/tox.ini b/dev/tox.ini
index e1a4cf5..bd69a3f 100644
--- a/dev/tox.ini
+++ b/dev/tox.ini
@@ -14,7 +14,7 @@
# limitations under the License.
[pycodestyle]
-ignore=E203,E226,E241,E305,E402,E722,E731,E741,W503,W504
+ignore=E203,E226,E241,E305,E402,E722,E731,E741,W503,W504,E501
max-line-length=100
exclude=*/target/*,python/pyspark/cloudpickle/*.py,shared.py,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*,dev/ansible-for-test-node/*
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 9651b53..70392fb 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -69,13 +69,15 @@ def since(version):
A decorator that annotates a function to append the version of Spark the function was added.
"""
import re
- indent_p = re.compile(r'\n( +)')
+
+ indent_p = re.compile(r"\n( +)")
def deco(f):
indents = indent_p.findall(f.__doc__)
- indent = ' ' * (min(len(m) for m in indents) if indents else 0)
+ indent = " " * (min(len(m) for m in indents) if indents else 0)
f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
return f
+
return deco
@@ -86,8 +88,9 @@ def copy_func(f, name=None, sinceversion=None, doc=None):
"""
# See
# http://stackoverflow.com/questions/6527633/how-can-i-make-a-deepcopy-of-a-function-in-python
- fn = types.FunctionType(f.__code__, f.__globals__, name or f.__name__, f.__defaults__,
- f.__closure__)
+ fn = types.FunctionType(
+ f.__code__, f.__globals__, name or f.__name__, f.__defaults__, f.__closure__
+ )
# in case f was given attrs (note this dict is a shallow copy):
fn.__dict__.update(f.__dict__)
if doc is not None:
@@ -106,14 +109,17 @@ def keyword_only(func):
-----
Should only be used to wrap a method where first arg is `self`
"""
+
@wraps(func)
def wrapper(self, *args, **kwargs):
if len(args) > 0:
raise TypeError("Method %s forces keyword arguments." % func.__name__)
self._input_kwargs = kwargs
return func(self, **kwargs)
+
return wrapper
+
# To avoid circular dependencies
from pyspark.context import SparkContext
@@ -121,9 +127,26 @@ from pyspark.context import SparkContext
from pyspark.sql import SQLContext, HiveContext, Row # noqa: F401
__all__ = [
- "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
- "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
- "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext",
- "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo", "InheritableThread",
- "inheritable_thread_target", "__version__",
+ "SparkConf",
+ "SparkContext",
+ "SparkFiles",
+ "RDD",
+ "StorageLevel",
+ "Broadcast",
+ "Accumulator",
+ "AccumulatorParam",
+ "MarshalSerializer",
+ "PickleSerializer",
+ "StatusTracker",
+ "SparkJobInfo",
+ "SparkStageInfo",
+ "Profiler",
+ "BasicProfiler",
+ "TaskContext",
+ "RDDBarrier",
+ "BarrierTaskContext",
+ "BarrierTaskInfo",
+ "InheritableThread",
+ "inheritable_thread_target",
+ "__version__",
]
diff --git a/python/pyspark/_globals.py b/python/pyspark/_globals.py
index 8e6099d..a635972 100644
--- a/python/pyspark/_globals.py
+++ b/python/pyspark/_globals.py
@@ -32,13 +32,13 @@ See gh-7844 for a discussion of the reload problem that motivated this module.
Note that this approach is taken after from NumPy.
"""
-__ALL__ = ['_NoValue']
+__ALL__ = ["_NoValue"]
# Disallow reloading this module so as to preserve the identities of the
# classes defined here.
-if '_is_loaded' in globals():
- raise RuntimeError('Reloading pyspark._globals is not allowed')
+if "_is_loaded" in globals():
+ raise RuntimeError("Reloading pyspark._globals is not allowed")
_is_loaded = True
@@ -51,6 +51,7 @@ class _NoValueType(object):
This class was copied from NumPy.
"""
+
__instance = None
def __new__(cls):
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index e6106b3..c43ebe4 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -23,7 +23,7 @@ import threading
from pyspark.serializers import read_int, PickleSerializer
-__all__ = ['Accumulator', 'AccumulatorParam']
+__all__ = ["Accumulator", "AccumulatorParam"]
pickleSer = PickleSerializer()
@@ -35,6 +35,7 @@ _accumulatorRegistry = {}
def _deserialize_accumulator(aid, zero_value, accum_param):
from pyspark.accumulators import _accumulatorRegistry
+
# If this certain accumulator was deserialized, don't overwrite it.
if aid in _accumulatorRegistry:
return _accumulatorRegistry[aid]
@@ -108,6 +109,7 @@ class Accumulator(object):
def __init__(self, aid, value, accum_param):
"""Create a new Accumulator with a given initial value and AccumulatorParam object"""
from pyspark.accumulators import _accumulatorRegistry
+
self.aid = aid
self.accum_param = accum_param
self._value = value
@@ -225,6 +227,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
def handle(self):
from pyspark.accumulators import _accumulatorRegistry
+
auth_token = self.server.auth_token
def poll(func):
@@ -248,13 +251,14 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
received_token = self.rfile.read(len(auth_token))
if isinstance(received_token, bytes):
received_token = received_token.decode("utf-8")
- if (received_token == auth_token):
+ if received_token == auth_token:
accum_updates()
# we've authenticated, we can break out of the first loop now
return True
else:
raise ValueError(
- "The value of the provided token to the AccumulatorServer is not correct.")
+ "The value of the provided token to the AccumulatorServer is not correct."
+ )
# first we keep polling till we've received the authentication token
poll(authenticate_and_accum_updates)
@@ -263,7 +267,6 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
class AccumulatorServer(SocketServer.TCPServer):
-
def __init__(self, server_address, RequestHandlerClass, auth_token):
SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
self.auth_token = auth_token
@@ -288,16 +291,17 @@ def _start_update_server(auth_token):
thread.start()
return server
+
if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
+
globs = globals().copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
- globs['sc'] = SparkContext('local', 'test')
- (failure_count, test_count) = doctest.testmod(
- globs=globs, optionflags=doctest.ELLIPSIS)
- globs['sc'].stop()
+ globs["sc"] = SparkContext("local", "test")
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ globs["sc"].stop()
if failure_count:
sys.exit(-1)
diff --git a/python/pyspark/accumulators.pyi b/python/pyspark/accumulators.pyi
index 13a1792cd..3159792 100644
--- a/python/pyspark/accumulators.pyi
+++ b/python/pyspark/accumulators.pyi
@@ -32,9 +32,7 @@ _accumulatorRegistry: Dict[int, Accumulator]
class Accumulator(Generic[T]):
aid: int
accum_param: AccumulatorParam[T]
- def __init__(
- self, aid: int, value: T, accum_param: AccumulatorParam[T]
- ) -> None: ...
+ def __init__(self, aid: int, value: T, accum_param: AccumulatorParam[T]) -> None: ...
def __reduce__(
self,
) -> Tuple[
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index aecd71f..995da33 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -27,7 +27,7 @@ from pyspark.serializers import ChunkedStream, pickle_protocol
from pyspark.util import print_exec
-__all__ = ['Broadcast']
+__all__ = ["Broadcast"]
# Holds broadcasted data received from Java, keyed by its id.
@@ -36,6 +36,7 @@ _broadcastRegistry = {}
def _from_id(bid):
from pyspark.broadcast import _broadcastRegistry
+
if bid not in _broadcastRegistry:
raise RuntimeError("Broadcast variable '%s' not loaded!" % bid)
return _broadcastRegistry[bid]
@@ -61,8 +62,7 @@ class Broadcast(object):
>>> large_broadcast = sc.broadcast(range(10000))
"""
- def __init__(self, sc=None, value=None, pickle_registry=None, path=None,
- sock_file=None):
+ def __init__(self, sc=None, value=None, pickle_registry=None, path=None, sock_file=None):
"""
Should not be called directly by users -- use :meth:`SparkContext.broadcast`
instead.
@@ -99,7 +99,7 @@ class Broadcast(object):
else:
# the jvm just dumps the pickled data in path -- we'll unpickle lazily when
# the value is requested
- assert(path is not None)
+ assert path is not None
self._path = path
def dump(self, value, f):
@@ -108,14 +108,13 @@ class Broadcast(object):
except pickle.PickleError:
raise
except Exception as e:
- msg = "Could not serialize broadcast: %s: %s" \
- % (e.__class__.__name__, str(e))
+ msg = "Could not serialize broadcast: %s: %s" % (e.__class__.__name__, str(e))
print_exec(sys.stderr)
raise pickle.PicklingError(msg)
f.close()
def load_from_path(self, path):
- with open(path, 'rb', 1 << 20) as f:
+ with open(path, "rb", 1 << 20) as f:
return self.load(f)
def load(self, file):
@@ -128,8 +127,7 @@ class Broadcast(object):
@property
def value(self):
- """ Return the broadcasted value
- """
+ """Return the broadcasted value"""
if not hasattr(self, "_value") and self._path is not None:
# we only need to decrypt it here when encryption is enabled and
# if its on the driver, since executor decryption is handled already
@@ -185,8 +183,7 @@ class Broadcast(object):
class BroadcastPickleRegistry(threading.local):
- """ Thread-local registry for broadcast variables that have been pickled
- """
+ """Thread-local registry for broadcast variables that have been pickled"""
def __init__(self):
self.__dict__.setdefault("_registry", set())
@@ -204,6 +201,7 @@ class BroadcastPickleRegistry(threading.local):
if __name__ == "__main__":
import doctest
+
(failure_count, test_count) = doctest.testmod()
if failure_count:
sys.exit(-1)
diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py
index a6b2784..47ea8b6 100644
--- a/python/pyspark/conf.py
+++ b/python/pyspark/conf.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
-__all__ = ['SparkConf']
+__all__ = ["SparkConf"]
import sys
from typing import Dict, List, Optional, Tuple, cast, overload
@@ -110,8 +110,12 @@ class SparkConf(object):
_jconf: Optional[JavaObject]
_conf: Optional[Dict[str, str]]
- def __init__(self, loadDefaults: bool = True, _jvm: Optional[JVMView] = None,
- _jconf: Optional[JavaObject] = None):
+ def __init__(
+ self,
+ loadDefaults: bool = True,
+ _jvm: Optional[JVMView] = None,
+ _jconf: Optional[JavaObject] = None,
+ ):
"""
Create a new Spark configuration.
"""
@@ -119,6 +123,7 @@ class SparkConf(object):
self._jconf = _jconf
else:
from pyspark.context import SparkContext
+
_jvm = _jvm or SparkContext._jvm # type: ignore[attr-defined]
if _jvm is not None:
@@ -169,8 +174,12 @@ class SparkConf(object):
def setExecutorEnv(self, *, pairs: List[Tuple[str, str]]) -> "SparkConf":
...
- def setExecutorEnv(self, key: Optional[str] = None, value: Optional[str] = None,
- pairs: Optional[List[Tuple[str, str]]] = None) -> "SparkConf":
+ def setExecutorEnv(
+ self,
+ key: Optional[str] = None,
+ value: Optional[str] = None,
+ pairs: Optional[List[Tuple[str, str]]] = None,
+ ) -> "SparkConf":
"""Set an environment variable to be passed to executors."""
if (key is not None and pairs is not None) or (key is None and pairs is None):
raise RuntimeError("Either pass one key-value pair or a list of pairs")
@@ -236,11 +245,12 @@ class SparkConf(object):
return self._jconf.toDebugString()
else:
assert self._conf is not None
- return '\n'.join('%s=%s' % (k, v) for k, v in self._conf.items())
+ return "\n".join("%s=%s" % (k, v) for k, v in self._conf.items())
def _test() -> None:
import doctest
+
(failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
if failure_count:
sys.exit(-1)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 9ed1a82..2c78994 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -34,8 +34,15 @@ from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway, local_connect_and_auth
-from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
- PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream
+from pyspark.serializers import (
+ PickleSerializer,
+ BatchedSerializer,
+ UTF8Deserializer,
+ PairDeserializer,
+ AutoBatchedSerializer,
+ NoOpSerializer,
+ ChunkedStream,
+)
from pyspark.storagelevel import StorageLevel
from pyspark.resource.information import ResourceInformation
from pyspark.rdd import RDD, _load_from_socket
@@ -45,7 +52,7 @@ from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler
-__all__ = ['SparkContext']
+__all__ = ["SparkContext"]
# These are special default configs for PySpark, they will overwrite
@@ -125,13 +132,23 @@ class SparkContext(object):
_lock = RLock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
- PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')
-
- def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
- environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
- gateway=None, jsc=None, profiler_cls=BasicProfiler):
- if (conf is None or
- conf.get("spark.executor.allowSparkContext", "false").lower() != "true"):
+ PACKAGE_EXTENSIONS = (".zip", ".egg", ".jar")
+
+ def __init__(
+ self,
+ master=None,
+ appName=None,
+ sparkHome=None,
+ pyFiles=None,
+ environment=None,
+ batchSize=0,
+ serializer=PickleSerializer(),
+ conf=None,
+ gateway=None,
+ jsc=None,
+ profiler_cls=BasicProfiler,
+ ):
+ if conf is None or conf.get("spark.executor.allowSparkContext", "false").lower() != "true":
# In order to prevent SparkContext from being created in executors.
SparkContext._assert_on_driver()
@@ -139,19 +156,41 @@ class SparkContext(object):
if gateway is not None and gateway.gateway_parameters.auth_token is None:
raise ValueError(
"You are trying to pass an insecure Py4j gateway to Spark. This"
- " is not allowed as it is a security risk.")
+ " is not allowed as it is a security risk."
+ )
SparkContext._ensure_initialized(self, gateway=gateway, conf=conf)
try:
- self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
- conf, jsc, profiler_cls)
+ self._do_init(
+ master,
+ appName,
+ sparkHome,
+ pyFiles,
+ environment,
+ batchSize,
+ serializer,
+ conf,
+ jsc,
+ profiler_cls,
+ )
except:
# If an error occurs, clean up in order to allow future SparkContext creation:
self.stop()
raise
- def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
- conf, jsc, profiler_cls):
+ def _do_init(
+ self,
+ master,
+ appName,
+ sparkHome,
+ pyFiles,
+ environment,
+ batchSize,
+ serializer,
+ conf,
+ jsc,
+ profiler_cls,
+ ):
self.environment = environment or {}
# java gateway must have been launched at this point.
if conf is not None and conf._jconf is not None:
@@ -170,8 +209,7 @@ class SparkContext(object):
if batchSize == 0:
self.serializer = AutoBatchedSerializer(self._unbatched_serializer)
else:
- self.serializer = BatchedSerializer(self._unbatched_serializer,
- batchSize)
+ self.serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
# Set any parameters passed directly to us on the conf
if master:
@@ -200,7 +238,7 @@ class SparkContext(object):
for (k, v) in self._conf.getAll():
if k.startswith("spark.executorEnv."):
- varName = k[len("spark.executorEnv."):]
+ varName = k[len("spark.executorEnv.") :]
self.environment[varName] = v
self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
@@ -222,21 +260,18 @@ class SparkContext(object):
# data via a socket.
# scala's mangled names w/ $ in them require special treatment.
self._encryption_enabled = self._jvm.PythonUtils.isEncryptionEnabled(self._jsc)
- os.environ["SPARK_AUTH_SOCKET_TIMEOUT"] = \
- str(self._jvm.PythonUtils.getPythonAuthSocketTimeout(self._jsc))
- os.environ["SPARK_BUFFER_SIZE"] = \
- str(self._jvm.PythonUtils.getSparkBufferSize(self._jsc))
+ os.environ["SPARK_AUTH_SOCKET_TIMEOUT"] = str(
+ self._jvm.PythonUtils.getPythonAuthSocketTimeout(self._jsc)
+ )
+ os.environ["SPARK_BUFFER_SIZE"] = str(self._jvm.PythonUtils.getSparkBufferSize(self._jsc))
- self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python3')
+ self.pythonExec = os.environ.get("PYSPARK_PYTHON", "python3")
self.pythonVer = "%d.%d" % sys.version_info[:2]
if sys.version_info[:2] < (3, 7):
with warnings.catch_warnings():
warnings.simplefilter("once")
- warnings.warn(
- "Python 3.6 support is deprecated in Spark 3.2.",
- FutureWarning
- )
+ warnings.warn("Python 3.6 support is deprecated in Spark 3.2.", FutureWarning)
# Broadcast's __reduce__ method stores Broadcast instances here.
# This allows other code to determine which Broadcast instances have
@@ -250,7 +285,7 @@ class SparkContext(object):
# Deploy any code dependencies specified in the constructor
self._python_includes = list()
- for path in (pyFiles or []):
+ for path in pyFiles or []:
self.addPyFile(path)
# Deploy code dependencies set by spark-submit; these will already have been added
@@ -272,13 +307,14 @@ class SparkContext(object):
warnings.warn(
"Failed to add file [%s] specified in 'spark.submit.pyFiles' to "
"Python path:\n %s" % (path, "\n ".join(sys.path)),
- RuntimeWarning)
+ RuntimeWarning,
+ )
# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
- self._temp_dir = \
- self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir, "pyspark") \
- .getAbsolutePath()
+ self._temp_dir = self._jvm.org.apache.spark.util.Utils.createTempDir(
+ local_dir, "pyspark"
+ ).getAbsolutePath()
# profiling stats collected for each PythonRDD
if self._conf.get("spark.python.profile", "false") == "true":
@@ -340,8 +376,10 @@ class SparkContext(object):
SparkContext._jvm = SparkContext._gateway.jvm
if instance:
- if (SparkContext._active_spark_context and
- SparkContext._active_spark_context != instance):
+ if (
+ SparkContext._active_spark_context
+ and SparkContext._active_spark_context != instance
+ ):
currentMaster = SparkContext._active_spark_context.master
currentAppName = SparkContext._active_spark_context.appName
callsite = SparkContext._active_spark_context._callsite
@@ -351,8 +389,14 @@ class SparkContext(object):
"Cannot run multiple SparkContexts at once; "
"existing SparkContext(app=%s, master=%s)"
" created by %s at %s:%s "
- % (currentAppName, currentMaster,
- callsite.function, callsite.file, callsite.linenum))
+ % (
+ currentAppName,
+ currentMaster,
+ callsite.function,
+ callsite.file,
+ callsite.linenum,
+ )
+ )
else:
SparkContext._active_spark_context = instance
@@ -466,10 +510,10 @@ class SparkContext(object):
except Py4JError:
# Case: SPARK-18523
warnings.warn(
- 'Unable to cleanly shutdown Spark JVM process.'
- ' It is possible that the process has crashed,'
- ' been killed or may also be in a zombie state.',
- RuntimeWarning
+ "Unable to cleanly shutdown Spark JVM process."
+ " It is possible that the process has crashed,"
+ " been killed or may also be in a zombie state.",
+ RuntimeWarning,
)
finally:
self._jsc = None
@@ -561,7 +605,7 @@ class SparkContext(object):
# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
- c = list(c) # Make it a list so we can compute its length
+ c = list(c) # Make it a list so we can compute its length
batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
@@ -652,8 +696,7 @@ class SparkContext(object):
['Hello world!']
"""
minPartitions = minPartitions or min(self.defaultParallelism, 2)
- return RDD(self._jsc.textFile(name, minPartitions), self,
- UTF8Deserializer(use_unicode))
+ return RDD(self._jsc.textFile(name, minPartitions), self, UTF8Deserializer(use_unicode))
def wholeTextFiles(self, path, minPartitions=None, use_unicode=True):
"""
@@ -704,8 +747,11 @@ class SparkContext(object):
[('.../1.txt', '1'), ('.../2.txt', '2')]
"""
minPartitions = minPartitions or self.defaultMinPartitions
- return RDD(self._jsc.wholeTextFiles(path, minPartitions), self,
- PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode)))
+ return RDD(
+ self._jsc.wholeTextFiles(path, minPartitions),
+ self,
+ PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode)),
+ )
def binaryFiles(self, path, minPartitions=None):
"""
@@ -720,8 +766,11 @@ class SparkContext(object):
Small files are preferred, large file is also allowable, but may cause bad performance.
"""
minPartitions = minPartitions or self.defaultMinPartitions
- return RDD(self._jsc.binaryFiles(path, minPartitions), self,
- PairDeserializer(UTF8Deserializer(), NoOpSerializer()))
+ return RDD(
+ self._jsc.binaryFiles(path, minPartitions),
+ self,
+ PairDeserializer(UTF8Deserializer(), NoOpSerializer()),
+ )
def binaryRecords(self, path, recordLength):
"""
@@ -746,8 +795,16 @@ class SparkContext(object):
jm[k] = v
return jm
- def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None,
- valueConverter=None, minSplits=None, batchSize=0):
+ def sequenceFile(
+ self,
+ path,
+ keyClass=None,
+ valueClass=None,
+ keyConverter=None,
+ valueConverter=None,
+ minSplits=None,
+ batchSize=0,
+ ):
"""
Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS,
a local file system (available on all nodes), or any Hadoop-supported file system URI.
@@ -779,12 +836,29 @@ class SparkContext(object):
Java object. (default 0, choose batchSize automatically)
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
- jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass,
- keyConverter, valueConverter, minSplits, batchSize)
+ jrdd = self._jvm.PythonRDD.sequenceFile(
+ self._jsc,
+ path,
+ keyClass,
+ valueClass,
+ keyConverter,
+ valueConverter,
+ minSplits,
+ batchSize,
+ )
return RDD(jrdd, self)
- def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
- valueConverter=None, conf=None, batchSize=0):
+ def newAPIHadoopFile(
+ self,
+ path,
+ inputFormatClass,
+ keyClass,
+ valueClass,
+ keyConverter=None,
+ valueConverter=None,
+ conf=None,
+ batchSize=0,
+ ):
"""
Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS,
a local file system (available on all nodes), or any Hadoop-supported file system URI.
@@ -820,13 +894,29 @@ class SparkContext(object):
Java object. (default 0, choose batchSize automatically)
"""
jconf = self._dictToJavaMap(conf)
- jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass,
- valueClass, keyConverter, valueConverter,
- jconf, batchSize)
+ jrdd = self._jvm.PythonRDD.newAPIHadoopFile(
+ self._jsc,
+ path,
+ inputFormatClass,
+ keyClass,
+ valueClass,
+ keyConverter,
+ valueConverter,
+ jconf,
+ batchSize,
+ )
return RDD(jrdd, self)
- def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
- valueConverter=None, conf=None, batchSize=0):
+ def newAPIHadoopRDD(
+ self,
+ inputFormatClass,
+ keyClass,
+ valueClass,
+ keyConverter=None,
+ valueConverter=None,
+ conf=None,
+ batchSize=0,
+ ):
"""
Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary
Hadoop configuration, which is passed in as a Python dict.
@@ -856,13 +946,29 @@ class SparkContext(object):
Java object. (default 0, choose batchSize automatically)
"""
jconf = self._dictToJavaMap(conf)
- jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass,
- valueClass, keyConverter, valueConverter,
- jconf, batchSize)
+ jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(
+ self._jsc,
+ inputFormatClass,
+ keyClass,
+ valueClass,
+ keyConverter,
+ valueConverter,
+ jconf,
+ batchSize,
+ )
return RDD(jrdd, self)
- def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
- valueConverter=None, conf=None, batchSize=0):
+ def hadoopFile(
+ self,
+ path,
+ inputFormatClass,
+ keyClass,
+ valueClass,
+ keyConverter=None,
+ valueConverter=None,
+ conf=None,
+ batchSize=0,
+ ):
"""
Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS,
a local file system (available on all nodes), or any Hadoop-supported file system URI.
@@ -894,13 +1000,29 @@ class SparkContext(object):
Java object. (default 0, choose batchSize automatically)
"""
jconf = self._dictToJavaMap(conf)
- jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass,
- valueClass, keyConverter, valueConverter,
- jconf, batchSize)
+ jrdd = self._jvm.PythonRDD.hadoopFile(
+ self._jsc,
+ path,
+ inputFormatClass,
+ keyClass,
+ valueClass,
+ keyConverter,
+ valueConverter,
+ jconf,
+ batchSize,
+ )
return RDD(jrdd, self)
- def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
- valueConverter=None, conf=None, batchSize=0):
+ def hadoopRDD(
+ self,
+ inputFormatClass,
+ keyClass,
+ valueClass,
+ keyConverter=None,
+ valueConverter=None,
+ conf=None,
+ batchSize=0,
+ ):
"""
Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary
Hadoop configuration, which is passed in as a Python dict.
@@ -930,9 +1052,16 @@ class SparkContext(object):
Java object. (default 0, choose batchSize automatically)
"""
jconf = self._dictToJavaMap(conf)
- jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass,
- valueClass, keyConverter, valueConverter,
- jconf, batchSize)
+ jrdd = self._jvm.PythonRDD.hadoopRDD(
+ self._jsc,
+ inputFormatClass,
+ keyClass,
+ valueClass,
+ keyConverter,
+ valueConverter,
+ jconf,
+ batchSize,
+ )
return RDD(jrdd, self)
def _checkpointFile(self, name, input_deserializer):
@@ -1087,11 +1216,13 @@ class SparkContext(object):
raise TypeError("storageLevel must be of type pyspark.StorageLevel")
newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
- return newStorageLevel(storageLevel.useDisk,
- storageLevel.useMemory,
- storageLevel.useOffHeap,
- storageLevel.deserialized,
- storageLevel.replication)
+ return newStorageLevel(
+ storageLevel.useDisk,
+ storageLevel.useMemory,
+ storageLevel.useOffHeap,
+ storageLevel.deserialized,
+ storageLevel.replication,
+ )
def setJobGroup(self, groupId, description, interruptOnCancel=False):
"""
@@ -1228,21 +1359,24 @@ class SparkContext(object):
return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer))
def show_profiles(self):
- """ Print the profile stats to stdout """
+ """Print the profile stats to stdout"""
if self.profiler_collector is not None:
self.profiler_collector.show_profiles()
else:
- raise RuntimeError("'spark.python.profile' configuration must be set "
- "to 'true' to enable Python profile.")
+ raise RuntimeError(
+ "'spark.python.profile' configuration must be set "
+ "to 'true' to enable Python profile."
+ )
def dump_profiles(self, path):
- """ Dump the profile stats into directory `path`
- """
+ """Dump the profile stats into directory `path`"""
if self.profiler_collector is not None:
self.profiler_collector.dump_profiles(path)
else:
- raise RuntimeError("'spark.python.profile' configuration must be set "
- "to 'true' to enable Python profile.")
+ raise RuntimeError(
+ "'spark.python.profile' configuration must be set "
+ "to 'true' to enable Python profile."
+ )
def getConf(self):
conf = SparkConf()
@@ -1275,12 +1409,13 @@ def _test():
import atexit
import doctest
import tempfile
+
globs = globals().copy()
- globs['sc'] = SparkContext('local[4]', 'PythonTest')
- globs['tempdir'] = tempfile.mkdtemp()
- atexit.register(lambda: shutil.rmtree(globs['tempdir']))
+ globs["sc"] = SparkContext("local[4]", "PythonTest")
+ globs["tempdir"] = tempfile.mkdtemp()
+ atexit.register(lambda: shutil.rmtree(globs["tempdir"]))
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
- globs['sc'].stop()
+ globs["sc"].stop()
if failure_count:
sys.exit(-1)
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 97b6b25..b8fd03d 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -88,13 +88,13 @@ def manager():
# Create a listening socket on the AF_INET loopback interface
listen_sock = socket.socket(AF_INET, SOCK_STREAM)
- listen_sock.bind(('127.0.0.1', 0))
+ listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(max(1024, SOMAXCONN))
listen_host, listen_port = listen_sock.getsockname()
# re-open stdin/stdout in 'wb' mode
- stdin_bin = os.fdopen(sys.stdin.fileno(), 'rb', 4)
- stdout_bin = os.fdopen(sys.stdout.fileno(), 'wb', 4)
+ stdin_bin = os.fdopen(sys.stdin.fileno(), "rb", 4)
+ stdout_bin = os.fdopen(sys.stdout.fileno(), "wb", 4)
write_int(listen_port, stdout_bin)
stdout_bin.flush()
@@ -106,6 +106,7 @@ def manager():
def handle_sigterm(*args):
shutdown(1)
+
signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM
signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP
signal.signal(SIGCHLD, SIG_IGN)
@@ -150,7 +151,7 @@ def manager():
time.sleep(1)
pid = os.fork() # error here will shutdown daemon
else:
- outfile = sock.makefile(mode='wb')
+ outfile = sock.makefile(mode="wb")
write_int(e.errno, outfile) # Signal that the fork failed
outfile.flush()
outfile.close()
@@ -171,7 +172,7 @@ def manager():
# Therefore, here we redirects it to '/dev/null' by duplicating
# another file descriptor for '/dev/null' to the standard input (0).
# See SPARK-26175.
- devnull = open(os.devnull, 'r')
+ devnull = open(os.devnull, "r")
os.dup2(devnull.fileno(), 0)
devnull.close()
@@ -207,5 +208,5 @@ def manager():
shutdown(1)
-if __name__ == '__main__':
+if __name__ == "__main__":
manager()
diff --git a/python/pyspark/files.py b/python/pyspark/files.py
index a0b74ca..59ab0b6 100644
--- a/python/pyspark/files.py
+++ b/python/pyspark/files.py
@@ -18,7 +18,7 @@
import os
-__all__ = ['SparkFiles']
+__all__ = ["SparkFiles"]
from typing import cast, ClassVar, Optional, TYPE_CHECKING
@@ -61,6 +61,5 @@ class SparkFiles(object):
else:
# This will have to change if we support multiple SparkContexts:
return cast(
- "SparkContext",
- cls._sc
+ "SparkContext", cls._sc
)._jvm.org.apache.spark.SparkFiles.getRootDirectory() # type: ignore[attr-defined]
diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py
index 62a36d4..09f4551 100755
--- a/python/pyspark/find_spark_home.py
+++ b/python/pyspark/find_spark_home.py
@@ -32,9 +32,10 @@ def _find_spark_home():
def is_spark_home(path):
"""Takes a path and returns true if the provided path could be a reasonable SPARK_HOME"""
- return (os.path.isfile(os.path.join(path, "bin/spark-submit")) and
- (os.path.isdir(os.path.join(path, "jars")) or
- os.path.isdir(os.path.join(path, "assembly"))))
+ return os.path.isfile(os.path.join(path, "bin/spark-submit")) and (
+ os.path.isdir(os.path.join(path, "jars"))
+ or os.path.isdir(os.path.join(path, "assembly"))
+ )
# Spark distribution can be downloaded when PYSPARK_HADOOP_VERSION environment variable is set.
# We should look up this directory first, see also SPARK-32017.
@@ -43,11 +44,13 @@ def _find_spark_home():
"../", # When we're in spark/python.
# Two case belows are valid when the current script is called as a library.
os.path.join(os.path.dirname(os.path.realpath(__file__)), spark_dist_dir),
- os.path.dirname(os.path.realpath(__file__))]
+ os.path.dirname(os.path.realpath(__file__)),
+ ]
# Add the path of the PySpark module if it exists
import_error_raised = False
from importlib.util import find_spec
+
try:
module_home = os.path.dirname(find_spec("pyspark").origin)
paths.append(os.path.join(module_home, spark_dist_dir))
@@ -78,7 +81,9 @@ def _find_spark_home():
"for example, 'python -m pip install pyspark [--user]'. Otherwise, you can also\n"
"explicitly set the Python executable, that has PySpark installed, to\n"
"PYSPARK_PYTHON or PYSPARK_DRIVER_PYTHON environment variables, for example,\n"
- "'PYSPARK_PYTHON=python3 pyspark'.\n", file=sys.stderr)
+ "'PYSPARK_PYTHON=python3 pyspark'.\n",
+ file=sys.stderr,
+ )
sys.exit(-1)
diff --git a/python/pyspark/install.py b/python/pyspark/install.py
index 7efee42..bfd7001 100644
--- a/python/pyspark/install.py
+++ b/python/pyspark/install.py
@@ -20,6 +20,7 @@ import tarfile
import traceback
import urllib.request
from shutil import rmtree
+
# NOTE that we shouldn't import pyspark here because this is used in
# setup.py, and assume there's no PySpark imported.
@@ -27,8 +28,7 @@ DEFAULT_HADOOP = "hadoop3.2"
DEFAULT_HIVE = "hive2.3"
SUPPORTED_HADOOP_VERSIONS = ["hadoop2.7", "hadoop3.2", "without-hadoop"]
SUPPORTED_HIVE_VERSIONS = ["hive2.3"]
-UNSUPPORTED_COMBINATIONS = [ # type: ignore
-]
+UNSUPPORTED_COMBINATIONS = [] # type: ignore
def checked_package_name(spark_version, hadoop_version, hive_version):
@@ -60,8 +60,8 @@ def checked_versions(spark_version, hadoop_version, hive_version):
spark_version = "spark-%s" % spark_version
if not spark_version.startswith("spark-"):
raise RuntimeError(
- "Spark version should start with 'spark-' prefix; however, "
- "got %s" % spark_version)
+ "Spark version should start with 'spark-' prefix; however, " "got %s" % spark_version
+ )
if hadoop_version == "without":
hadoop_version = "without-hadoop"
@@ -71,8 +71,8 @@ def checked_versions(spark_version, hadoop_version, hive_version):
if hadoop_version not in SUPPORTED_HADOOP_VERSIONS:
raise RuntimeError(
"Spark distribution of %s is not supported. Hadoop version should be "
- "one of [%s]" % (hadoop_version, ", ".join(
- SUPPORTED_HADOOP_VERSIONS)))
+ "one of [%s]" % (hadoop_version, ", ".join(SUPPORTED_HADOOP_VERSIONS))
+ )
if re.match("^[0-9]+\\.[0-9]+$", hive_version):
hive_version = "hive%s" % hive_version
@@ -80,8 +80,8 @@ def checked_versions(spark_version, hadoop_version, hive_version):
if hive_version not in SUPPORTED_HIVE_VERSIONS:
raise RuntimeError(
"Spark distribution of %s is not supported. Hive version should be "
- "one of [%s]" % (hive_version, ", ".join(
- SUPPORTED_HADOOP_VERSIONS)))
+ "one of [%s]" % (hive_version, ", ".join(SUPPORTED_HADOOP_VERSIONS))
+ )
return spark_version, hadoop_version, hive_version
@@ -114,7 +114,8 @@ def install_spark(dest, spark_version, hadoop_version, hive_version):
pretty_pkg_name = "%s for Hadoop %s" % (
spark_version,
- "Free build" if hadoop_version == "without" else hadoop_version)
+ "Free build" if hadoop_version == "without" else hadoop_version,
+ )
for site in sites:
os.makedirs(dest, exist_ok=True)
@@ -151,19 +152,22 @@ def get_preferred_mirrors():
for _ in range(3):
try:
response = urllib.request.urlopen(
- "https://www.apache.org/dyn/closer.lua?preferred=true")
- mirror_urls.append(response.read().decode('utf-8'))
+ "https://www.apache.org/dyn/closer.lua?preferred=true"
+ )
+ mirror_urls.append(response.read().decode("utf-8"))
except Exception:
# If we can't get a mirror URL, skip it. No retry.
pass
default_sites = [
- "https://archive.apache.org/dist", "https://dist.apache.org/repos/dist/release"]
+ "https://archive.apache.org/dist",
+ "https://dist.apache.org/repos/dist/release",
+ ]
return list(set(mirror_urls)) + default_sites
def download_to_file(response, path, chunk_size=1024 * 1024):
- total_size = int(response.info().get('Content-Length').strip())
+ total_size = int(response.info().get("Content-Length").strip())
bytes_so_far = 0
with open(path, mode="wb") as dest:
@@ -173,7 +177,7 @@ def download_to_file(response, path, chunk_size=1024 * 1024):
if not chunk:
break
dest.write(chunk)
- print("Downloaded %d of %d bytes (%0.2f%%)" % (
- bytes_so_far,
- total_size,
- round(float(bytes_so_far) / total_size * 100, 2)))
+ print(
+ "Downloaded %d of %d bytes (%0.2f%%)"
+ % (bytes_so_far, total_size, round(float(bytes_so_far) / total_size * 100, 2))
+ )
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index bffdc0b..a41ccfa 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -64,13 +64,10 @@ def launch_gateway(conf=None, popen_kwargs=None):
command = [os.path.join(SPARK_HOME, script)]
if conf:
for k, v in conf.getAll():
- command += ['--conf', '%s=%s' % (k, v)]
+ command += ["--conf", "%s=%s" % (k, v)]
submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
if os.environ.get("SPARK_TESTING"):
- submit_args = ' '.join([
- "--conf spark.ui.enabled=false",
- submit_args
- ])
+ submit_args = " ".join(["--conf spark.ui.enabled=false", submit_args])
command = command + shlex.split(submit_args)
# Create a temporary directory where the gateway server should write the connection
@@ -87,14 +84,15 @@ def launch_gateway(conf=None, popen_kwargs=None):
# Launch the Java gateway.
popen_kwargs = {} if popen_kwargs is None else popen_kwargs
# We open a pipe to stdin so that the Java gateway can die when the pipe is broken
- popen_kwargs['stdin'] = PIPE
+ popen_kwargs["stdin"] = PIPE
# We always set the necessary environment variables.
- popen_kwargs['env'] = env
+ popen_kwargs["env"] = env
if not on_windows:
# Don't send ctrl-c / SIGINT to the Java gateway:
def preexec_func():
signal.signal(signal.SIGINT, signal.SIG_IGN)
- popen_kwargs['preexec_fn'] = preexec_func
+
+ popen_kwargs["preexec_fn"] = preexec_func
proc = Popen(command, **popen_kwargs)
else:
# preexec_fn not supported on Windows
@@ -127,24 +125,23 @@ def launch_gateway(conf=None, popen_kwargs=None):
# child processes in the tree (http://technet.microsoft.com/en-us/library/bb491009.aspx)
def killChild():
Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
+
atexit.register(killChild)
# Connect to the gateway (or client server to pin the thread between JVM and Python)
if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
gateway = ClientServer(
java_parameters=JavaParameters(
- port=gateway_port,
- auth_token=gateway_secret,
- auto_convert=True),
- python_parameters=PythonParameters(
- port=0,
- eager_load=False))
+ port=gateway_port, auth_token=gateway_secret, auto_convert=True
+ ),
+ python_parameters=PythonParameters(port=0, eager_load=False),
+ )
else:
gateway = JavaGateway(
gateway_parameters=GatewayParameters(
- port=gateway_port,
- auth_token=gateway_secret,
- auto_convert=True))
+ port=gateway_port, auth_token=gateway_secret, auto_convert=True
+ )
+ )
# Store a reference to the Popen object for use by the caller (e.g., in reading stdout/stderr)
gateway.proc = proc
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index c1f5362..040c946 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -50,6 +50,7 @@ def python_join(rdd, other, numPartitions):
elif n == 2:
wbuf.append(v)
return ((v, w) for v in vbuf for w in wbuf)
+
return _do_python_join(rdd, other, numPartitions, dispatch)
@@ -64,6 +65,7 @@ def python_right_outer_join(rdd, other, numPartitions):
if not vbuf:
vbuf.append(None)
return ((v, w) for v in vbuf for w in wbuf)
+
return _do_python_join(rdd, other, numPartitions, dispatch)
@@ -78,6 +80,7 @@ def python_left_outer_join(rdd, other, numPartitions):
if not wbuf:
wbuf.append(None)
return ((v, w) for v in vbuf for w in wbuf)
+
return _do_python_join(rdd, other, numPartitions, dispatch)
@@ -94,12 +97,14 @@ def python_full_outer_join(rdd, other, numPartitions):
if not wbuf:
wbuf.append(None)
return ((v, w) for v in vbuf for w in wbuf)
+
return _do_python_join(rdd, other, numPartitions, dispatch)
def python_cogroup(rdds, numPartitions):
def make_mapper(i):
return lambda v: (i, v)
+
vrdds = [rdd.mapValues(make_mapper(i)) for i, rdd in enumerate(rdds)]
union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
rdd_len = len(vrdds)
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
index 7d0e55a..8a235f7 100644
--- a/python/pyspark/ml/__init__.py
+++ b/python/pyspark/ml/__init__.py
@@ -19,15 +19,51 @@
DataFrame-based machine learning APIs to let users quickly assemble and configure practical
machine learning pipelines.
"""
-from pyspark.ml.base import Estimator, Model, Predictor, PredictionModel, \
- Transformer, UnaryTransformer
+from pyspark.ml.base import (
+ Estimator,
+ Model,
+ Predictor,
+ PredictionModel,
+ Transformer,
+ UnaryTransformer,
+)
from pyspark.ml.pipeline import Pipeline, PipelineModel
-from pyspark.ml import classification, clustering, evaluation, feature, fpm, \
- image, recommendation, regression, stat, tuning, util, linalg, param
+from pyspark.ml import (
+ classification,
+ clustering,
+ evaluation,
+ feature,
+ fpm,
+ image,
+ recommendation,
+ regression,
+ stat,
+ tuning,
+ util,
+ linalg,
+ param,
+)
__all__ = [
- "Transformer", "UnaryTransformer", "Estimator", "Model",
- "Predictor", "PredictionModel", "Pipeline", "PipelineModel",
- "classification", "clustering", "evaluation", "feature", "fpm", "image",
- "recommendation", "regression", "stat", "tuning", "util", "linalg", "param",
+ "Transformer",
+ "UnaryTransformer",
+ "Estimator",
+ "Model",
+ "Predictor",
+ "PredictionModel",
+ "Pipeline",
+ "PipelineModel",
+ "classification",
+ "clustering",
+ "evaluation",
+ "feature",
+ "fpm",
+ "image",
+ "recommendation",
+ "regression",
+ "stat",
+ "tuning",
+ "util",
+ "linalg",
+ "param",
]
diff --git a/python/pyspark/ml/_typing.pyi b/python/pyspark/ml/_typing.pyi
index d966a78..40531d1 100644
--- a/python/pyspark/ml/_typing.pyi
+++ b/python/pyspark/ml/_typing.pyi
@@ -32,9 +32,7 @@ P = TypeVar("P", bound=pyspark.ml.param.Params)
M = TypeVar("M", bound=pyspark.ml.base.Transformer)
JM = TypeVar("JM", bound=pyspark.ml.wrapper.JavaTransformer)
-BinaryClassificationEvaluatorMetricType = Union[
- Literal["areaUnderROC"], Literal["areaUnderPR"]
-]
+BinaryClassificationEvaluatorMetricType = Union[Literal["areaUnderROC"], Literal["areaUnderPR"]]
RegressionEvaluatorMetricType = Union[
Literal["rmse"], Literal["mse"], Literal["r2"], Literal["mae"], Literal["var"]
]
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index 31ce93d..970444d 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -22,8 +22,14 @@ import threading
from pyspark import since
from pyspark.ml.common import inherit_doc
-from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasLabelCol, HasFeaturesCol, \
- HasPredictionCol, Params
+from pyspark.ml.param.shared import (
+ HasInputCol,
+ HasOutputCol,
+ HasLabelCol,
+ HasFeaturesCol,
+ HasPredictionCol,
+ Params,
+)
from pyspark.sql.functions import udf
from pyspark.sql.types import StructField, StructType
@@ -48,10 +54,9 @@ class _FitMultipleIterator(object):
-----
See :py:meth:`Estimator.fitMultiple` for more info.
"""
- def __init__(self, fitSingleModel, numModels):
- """
- """
+ def __init__(self, fitSingleModel, numModels):
+ """ """
self.fitSingleModel = fitSingleModel
self.numModel = numModels
self.counter = 0
@@ -80,6 +85,7 @@ class Estimator(Params, metaclass=ABCMeta):
.. versionadded:: 1.3.0
"""
+
pass
@abstractmethod
@@ -160,8 +166,10 @@ class Estimator(Params, metaclass=ABCMeta):
else:
return self._fit(dataset)
else:
- raise TypeError("Params must be either a param map or a list/tuple of param maps, "
- "but got %s." % type(params))
+ raise TypeError(
+ "Params must be either a param map or a list/tuple of param maps, "
+ "but got %s." % type(params)
+ )
@inherit_doc
@@ -171,6 +179,7 @@ class Transformer(Params, metaclass=ABCMeta):
.. versionadded:: 1.3.0
"""
+
pass
@abstractmethod
@@ -226,6 +235,7 @@ class Model(Transformer, metaclass=ABCMeta):
.. versionadded:: 1.4.0
"""
+
pass
@@ -279,16 +289,15 @@ class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
if self.getOutputCol() in schema.names:
raise ValueError("Output column %s already exists." % self.getOutputCol())
outputFields = copy.copy(schema.fields)
- outputFields.append(StructField(self.getOutputCol(),
- self.outputDataType(),
- nullable=False))
+ outputFields.append(StructField(self.getOutputCol(), self.outputDataType(), nullable=False))
return StructType(outputFields)
def _transform(self, dataset):
self.transformSchema(dataset.schema)
transformUDF = udf(self.createTransformFunc(), self.outputDataType())
- transformedDataset = dataset.withColumn(self.getOutputCol(),
- transformUDF(dataset[self.getInputCol()]))
+ transformedDataset = dataset.withColumn(
+ self.getOutputCol(), transformUDF(dataset[self.getInputCol()])
+ )
return transformedDataset
@@ -299,6 +308,7 @@ class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
.. versionadded:: 3.0.0
"""
+
pass
diff --git a/python/pyspark/ml/base.pyi b/python/pyspark/ml/base.pyi
index 4f1c6f9..37ae6de 100644
--- a/python/pyspark/ml/base.pyi
+++ b/python/pyspark/ml/base.pyi
@@ -57,9 +57,7 @@ class _FitMultipleIterator:
numModel: int
counter: int = ...
lock: _thread.LockType
- def __init__(
- self, fitSingleModel: Callable[[int], Transformer], numModels: int
- ) -> None: ...
+ def __init__(self, fitSingleModel: Callable[[int], Transformer], numModels: int) -> None: ...
def __iter__(self) -> _FitMultipleIterator: ...
def __next__(self) -> Tuple[int, Transformer]: ...
def next(self) -> Tuple[int, Transformer]: ...
@@ -76,9 +74,7 @@ class Estimator(Generic[M], Params, metaclass=abc.ABCMeta):
) -> Iterable[Tuple[int, M]]: ...
class Transformer(Params, metaclass=abc.ABCMeta):
- def transform(
- self, dataset: DataFrame, params: Optional[ParamMap] = ...
- ) -> DataFrame: ...
+ def transform(self, dataset: DataFrame, params: Optional[ParamMap] = ...) -> DataFrame: ...
class Model(Transformer, metaclass=abc.ABCMeta): ...
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 79b57d7..e6ce3e0 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -25,20 +25,54 @@ from multiprocessing.pool import ThreadPool
from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
from pyspark.ml import Estimator, Predictor, PredictionModel, Model
-from pyspark.ml.param.shared import HasRawPredictionCol, HasProbabilityCol, HasThresholds, \
- HasRegParam, HasMaxIter, HasFitIntercept, HasTol, HasStandardization, HasWeightCol, \
- HasAggregationDepth, HasThreshold, HasBlockSize, HasMaxBlockSizeInMB, Param, Params, \
- TypeConverters, HasElasticNetParam, HasSeed, HasStepSize, HasSolver, HasParallelism
-from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
- _TreeEnsembleModel, _RandomForestParams, _GBTParams, \
- _HasVarianceImpurity, _TreeClassifierParams
+from pyspark.ml.param.shared import (
+ HasRawPredictionCol,
+ HasProbabilityCol,
+ HasThresholds,
+ HasRegParam,
+ HasMaxIter,
+ HasFitIntercept,
+ HasTol,
+ HasStandardization,
+ HasWeightCol,
+ HasAggregationDepth,
+ HasThreshold,
+ HasBlockSize,
+ HasMaxBlockSizeInMB,
+ Param,
+ Params,
+ TypeConverters,
+ HasElasticNetParam,
+ HasSeed,
+ HasStepSize,
+ HasSolver,
+ HasParallelism,
+)
+from pyspark.ml.tree import (
+ _DecisionTreeModel,
+ _DecisionTreeParams,
+ _TreeEnsembleModel,
+ _RandomForestParams,
+ _GBTParams,
+ _HasVarianceImpurity,
+ _TreeClassifierParams,
+)
from pyspark.ml.regression import _FactorizationMachinesParams, DecisionTreeRegressionModel
from pyspark.ml.base import _PredictorParams
-from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, \
- JavaMLReadable, JavaMLReader, JavaMLWritable, JavaMLWriter, \
- MLReader, MLReadable, MLWriter, MLWritable, HasTrainingSummary
-from pyspark.ml.wrapper import JavaParams, \
- JavaPredictor, JavaPredictionModel, JavaWrapper
+from pyspark.ml.util import (
+ DefaultParamsReader,
+ DefaultParamsWriter,
+ JavaMLReadable,
+ JavaMLReader,
+ JavaMLWritable,
+ JavaMLWriter,
+ MLReader,
+ MLReadable,
+ MLWriter,
+ MLWritable,
+ HasTrainingSummary,
+)
+from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql import DataFrame
@@ -46,24 +80,40 @@ from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.storagelevel import StorageLevel
-__all__ = ['LinearSVC', 'LinearSVCModel',
- 'LinearSVCSummary', 'LinearSVCTrainingSummary',
- 'LogisticRegression', 'LogisticRegressionModel',
- 'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary',
- 'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary',
- 'DecisionTreeClassifier', 'DecisionTreeClassificationModel',
- 'GBTClassifier', 'GBTClassificationModel',
- 'RandomForestClassifier', 'RandomForestClassificationModel',
- 'RandomForestClassificationSummary', 'RandomForestClassificationTrainingSummary',
- 'BinaryRandomForestClassificationSummary',
- 'BinaryRandomForestClassificationTrainingSummary',
- 'NaiveBayes', 'NaiveBayesModel',
- 'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel',
- 'MultilayerPerceptronClassificationSummary',
- 'MultilayerPerceptronClassificationTrainingSummary',
- 'OneVsRest', 'OneVsRestModel',
- 'FMClassifier', 'FMClassificationModel', 'FMClassificationSummary',
- 'FMClassificationTrainingSummary']
+__all__ = [
+ "LinearSVC",
+ "LinearSVCModel",
+ "LinearSVCSummary",
+ "LinearSVCTrainingSummary",
+ "LogisticRegression",
+ "LogisticRegressionModel",
+ "LogisticRegressionSummary",
+ "LogisticRegressionTrainingSummary",
+ "BinaryLogisticRegressionSummary",
+ "BinaryLogisticRegressionTrainingSummary",
+ "DecisionTreeClassifier",
+ "DecisionTreeClassificationModel",
+ "GBTClassifier",
+ "GBTClassificationModel",
+ "RandomForestClassifier",
+ "RandomForestClassificationModel",
+ "RandomForestClassificationSummary",
+ "RandomForestClassificationTrainingSummary",
+ "BinaryRandomForestClassificationSummary",
+ "BinaryRandomForestClassificationTrainingSummary",
+ "NaiveBayes",
+ "NaiveBayesModel",
+ "MultilayerPerceptronClassifier",
+ "MultilayerPerceptronClassificationModel",
+ "MultilayerPerceptronClassificationSummary",
+ "MultilayerPerceptronClassificationTrainingSummary",
+ "OneVsRest",
+ "OneVsRestModel",
+ "FMClassifier",
+ "FMClassificationModel",
+ "FMClassificationSummary",
+ "FMClassificationTrainingSummary",
+]
class _ClassifierParams(HasRawPredictionCol, _PredictorParams):
@@ -72,6 +122,7 @@ class _ClassifierParams(HasRawPredictionCol, _PredictorParams):
.. versionadded:: 3.0.0
"""
+
pass
@@ -128,12 +179,12 @@ class _ProbabilisticClassifierParams(HasProbabilityCol, HasThresholds, _Classifi
.. versionadded:: 3.0.0
"""
+
pass
@inherit_doc
-class ProbabilisticClassifier(Classifier, _ProbabilisticClassifierParams,
- metaclass=ABCMeta):
+class ProbabilisticClassifier(Classifier, _ProbabilisticClassifierParams, metaclass=ABCMeta):
"""
Probabilistic Classifier for classification tasks.
"""
@@ -154,9 +205,9 @@ class ProbabilisticClassifier(Classifier, _ProbabilisticClassifierParams,
@inherit_doc
-class ProbabilisticClassificationModel(ClassificationModel,
- _ProbabilisticClassifierParams,
- metaclass=ABCMeta):
+class ProbabilisticClassificationModel(
+ ClassificationModel, _ProbabilisticClassifierParams, metaclass=ABCMeta
+):
"""
Model produced by a ``ProbabilisticClassifier``.
"""
@@ -224,17 +275,18 @@ class _JavaClassificationModel(ClassificationModel, JavaPredictionModel):
@inherit_doc
-class _JavaProbabilisticClassifier(ProbabilisticClassifier, _JavaClassifier,
- metaclass=ABCMeta):
+class _JavaProbabilisticClassifier(ProbabilisticClassifier, _JavaClassifier, metaclass=ABCMeta):
"""
Java Probabilistic Classifier for classification tasks.
"""
+
pass
@inherit_doc
-class _JavaProbabilisticClassificationModel(ProbabilisticClassificationModel,
- _JavaClassificationModel):
+class _JavaProbabilisticClassificationModel(
+ ProbabilisticClassificationModel, _JavaClassificationModel
+):
"""
Java Model produced by a ``ProbabilisticClassifier``.
"""
@@ -505,26 +557,45 @@ class _BinaryClassificationSummary(_ClassificationSummary):
return self._call_java("recallByThreshold")
-class _LinearSVCParams(_ClassifierParams, HasRegParam, HasMaxIter, HasFitIntercept, HasTol,
- HasStandardization, HasWeightCol, HasAggregationDepth, HasThreshold,
- HasMaxBlockSizeInMB):
+class _LinearSVCParams(
+ _ClassifierParams,
+ HasRegParam,
+ HasMaxIter,
+ HasFitIntercept,
+ HasTol,
+ HasStandardization,
+ HasWeightCol,
+ HasAggregationDepth,
+ HasThreshold,
+ HasMaxBlockSizeInMB,
+):
"""
Params for :py:class:`LinearSVC` and :py:class:`LinearSVCModel`.
.. versionadded:: 3.0.0
"""
- threshold = Param(Params._dummy(), "threshold",
- "The threshold in binary classification applied to the linear model"
- " prediction. This threshold can be any real number, where Inf will make"
- " all predictions 0.0 and -Inf will make all predictions 1.0.",
- typeConverter=TypeConverters.toFloat)
+ threshold = Param(
+ Params._dummy(),
+ "threshold",
+ "The threshold in binary classification applied to the linear model"
+ " prediction. This threshold can be any real number, where Inf will make"
+ " all predictions 0.0 and -Inf will make all predictions 1.0.",
+ typeConverter=TypeConverters.toFloat,
+ )
def __init__(self, *args):
super(_LinearSVCParams, self).__init__(*args)
- self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, fitIntercept=True,
- standardization=True, threshold=0.0, aggregationDepth=2,
- maxBlockSizeInMB=0.0)
+ self._setDefault(
+ maxIter=100,
+ regParam=0.0,
+ tol=1e-6,
+ fitIntercept=True,
+ standardization=True,
+ threshold=0.0,
+ aggregationDepth=2,
+ maxBlockSizeInMB=0.0,
+ )
@inherit_doc
@@ -605,10 +676,23 @@ class LinearSVC(_JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadabl
"""
@keyword_only
- def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
- fitIntercept=True, standardization=True, threshold=0.0, weightCol=None,
- aggregationDepth=2, maxBlockSizeInMB=0.0):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ maxIter=100,
+ regParam=0.0,
+ tol=1e-6,
+ rawPredictionCol="rawPrediction",
+ fitIntercept=True,
+ standardization=True,
+ threshold=0.0,
+ weightCol=None,
+ aggregationDepth=2,
+ maxBlockSizeInMB=0.0,
+ ):
"""
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
@@ -617,16 +701,30 @@ class LinearSVC(_JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadabl
"""
super(LinearSVC, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.classification.LinearSVC", self.uid)
+ "org.apache.spark.ml.classification.LinearSVC", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("2.2.0")
- def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
- fitIntercept=True, standardization=True, threshold=0.0, weightCol=None,
- aggregationDepth=2, maxBlockSizeInMB=0.0):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ maxIter=100,
+ regParam=0.0,
+ tol=1e-6,
+ rawPredictionCol="rawPrediction",
+ fitIntercept=True,
+ standardization=True,
+ threshold=0.0,
+ weightCol=None,
+ aggregationDepth=2,
+ maxBlockSizeInMB=0.0,
+ ):
"""
setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
@@ -704,8 +802,9 @@ class LinearSVC(_JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadabl
return self._set(maxBlockSizeInMB=value)
-class LinearSVCModel(_JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable,
- HasTrainingSummary):
+class LinearSVCModel(
+ _JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable, HasTrainingSummary
+):
"""
Model fitted by LinearSVC.
@@ -744,8 +843,9 @@ class LinearSVCModel(_JavaClassificationModel, _LinearSVCParams, JavaMLWritable,
if self.hasSummary:
return LinearSVCTrainingSummary(super(LinearSVCModel, self).summary)
else:
- raise RuntimeError("No training summary available for this %s" %
- self.__class__.__name__)
+ raise RuntimeError(
+ "No training summary available for this %s" % self.__class__.__name__
+ )
def evaluate(self, dataset):
"""
@@ -770,6 +870,7 @@ class LinearSVCSummary(_BinaryClassificationSummary):
.. versionadded:: 3.1.0
"""
+
pass
@@ -780,66 +881,95 @@ class LinearSVCTrainingSummary(LinearSVCSummary, _TrainingSummary):
.. versionadded:: 3.1.0
"""
+
pass
-class _LogisticRegressionParams(_ProbabilisticClassifierParams, HasRegParam,
- HasElasticNetParam, HasMaxIter, HasFitIntercept, HasTol,
- HasStandardization, HasWeightCol, HasAggregationDepth,
- HasThreshold, HasMaxBlockSizeInMB):
+class _LogisticRegressionParams(
+ _ProbabilisticClassifierParams,
+ HasRegParam,
+ HasElasticNetParam,
+ HasMaxIter,
+ HasFitIntercept,
+ HasTol,
+ HasStandardization,
+ HasWeightCol,
+ HasAggregationDepth,
+ HasThreshold,
+ HasMaxBlockSizeInMB,
+):
"""
Params for :py:class:`LogisticRegression` and :py:class:`LogisticRegressionModel`.
.. versionadded:: 3.0.0
"""
- threshold = Param(Params._dummy(), "threshold",
- "Threshold in binary classification prediction, in range [0, 1]." +
- " If threshold and thresholds are both set, they must match." +
- "e.g. if threshold is p, then thresholds must be equal to [1-p, p].",
- typeConverter=TypeConverters.toFloat)
-
- family = Param(Params._dummy(), "family",
- "The name of family which is a description of the label distribution to " +
- "be used in the model. Supported options: auto, binomial, multinomial",
- typeConverter=TypeConverters.toString)
-
- lowerBoundsOnCoefficients = Param(Params._dummy(), "lowerBoundsOnCoefficients",
- "The lower bounds on coefficients if fitting under bound "
- "constrained optimization. The bound matrix must be "
- "compatible with the shape "
- "(1, number of features) for binomial regression, or "
- "(number of classes, number of features) "
- "for multinomial regression.",
- typeConverter=TypeConverters.toMatrix)
-
- upperBoundsOnCoefficients = Param(Params._dummy(), "upperBoundsOnCoefficients",
- "The upper bounds on coefficients if fitting under bound "
- "constrained optimization. The bound matrix must be "
- "compatible with the shape "
- "(1, number of features) for binomial regression, or "
- "(number of classes, number of features) "
- "for multinomial regression.",
- typeConverter=TypeConverters.toMatrix)
-
- lowerBoundsOnIntercepts = Param(Params._dummy(), "lowerBoundsOnIntercepts",
- "The lower bounds on intercepts if fitting under bound "
- "constrained optimization. The bounds vector size must be"
- "equal with 1 for binomial regression, or the number of"
- "lasses for multinomial regression.",
- typeConverter=TypeConverters.toVector)
-
- upperBoundsOnIntercepts = Param(Params._dummy(), "upperBoundsOnIntercepts",
- "The upper bounds on intercepts if fitting under bound "
- "constrained optimization. The bound vector size must be "
- "equal with 1 for binomial regression, or the number of "
- "classes for multinomial regression.",
- typeConverter=TypeConverters.toVector)
+ threshold = Param(
+ Params._dummy(),
+ "threshold",
+ "Threshold in binary classification prediction, in range [0, 1]."
+ + " If threshold and thresholds are both set, they must match."
+ + "e.g. if threshold is p, then thresholds must be equal to [1-p, p].",
+ typeConverter=TypeConverters.toFloat,
+ )
+
+ family = Param(
+ Params._dummy(),
+ "family",
+ "The name of family which is a description of the label distribution to "
+ + "be used in the model. Supported options: auto, binomial, multinomial",
+ typeConverter=TypeConverters.toString,
+ )
+
+ lowerBoundsOnCoefficients = Param(
+ Params._dummy(),
+ "lowerBoundsOnCoefficients",
+ "The lower bounds on coefficients if fitting under bound "
+ "constrained optimization. The bound matrix must be "
+ "compatible with the shape "
+ "(1, number of features) for binomial regression, or "
+ "(number of classes, number of features) "
+ "for multinomial regression.",
+ typeConverter=TypeConverters.toMatrix,
+ )
+
+ upperBoundsOnCoefficients = Param(
+ Params._dummy(),
+ "upperBoundsOnCoefficients",
+ "The upper bounds on coefficients if fitting under bound "
+ "constrained optimization. The bound matrix must be "
+ "compatible with the shape "
+ "(1, number of features) for binomial regression, or "
+ "(number of classes, number of features) "
+ "for multinomial regression.",
+ typeConverter=TypeConverters.toMatrix,
+ )
+
+ lowerBoundsOnIntercepts = Param(
+ Params._dummy(),
+ "lowerBoundsOnIntercepts",
+ "The lower bounds on intercepts if fitting under bound "
+ "constrained optimization. The bounds vector size must be"
+ "equal with 1 for binomial regression, or the number of"
+ "lasses for multinomial regression.",
+ typeConverter=TypeConverters.toVector,
+ )
+
+ upperBoundsOnIntercepts = Param(
+ Params._dummy(),
+ "upperBoundsOnIntercepts",
+ "The upper bounds on intercepts if fitting under bound "
+ "constrained optimization. The bound vector size must be "
+ "equal with 1 for binomial regression, or the number of "
+ "classes for multinomial regression.",
+ typeConverter=TypeConverters.toVector,
+ )
def __init__(self, *args):
super(_LogisticRegressionParams, self).__init__(*args)
- self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5, family="auto",
- maxBlockSizeInMB=0.0)
+ self._setDefault(
+ maxIter=100, regParam=0.0, tol=1e-6, threshold=0.5, family="auto", maxBlockSizeInMB=0.0
+ )
@since("1.4.0")
def setThreshold(self, value):
@@ -865,10 +995,13 @@ class _LogisticRegressionParams(_ProbabilisticClassifierParams, HasRegParam,
if self.isSet(self.thresholds):
ts = self.getOrDefault(self.thresholds)
if len(ts) != 2:
- raise ValueError("Logistic Regression getThreshold only applies to" +
- " binary classification, but thresholds has length != 2." +
- " thresholds: " + ",".join(ts))
- return 1.0/(1.0 + ts[0]/ts[1])
+ raise ValueError(
+ "Logistic Regression getThreshold only applies to"
+ + " binary classification, but thresholds has length != 2."
+ + " thresholds: "
+ + ",".join(ts)
+ )
+ return 1.0 / (1.0 + ts[0] / ts[1])
else:
return self.getOrDefault(self.threshold)
@@ -893,7 +1026,7 @@ class _LogisticRegressionParams(_ProbabilisticClassifierParams, HasRegParam,
self._checkThresholdConsistency()
if not self.isSet(self.thresholds) and self.isSet(self.threshold):
t = self.getOrDefault(self.threshold)
- return [1.0-t, t]
+ return [1.0 - t, t]
else:
return self.getOrDefault(self.thresholds)
@@ -901,14 +1034,18 @@ class _LogisticRegressionParams(_ProbabilisticClassifierParams, HasRegParam,
if self.isSet(self.threshold) and self.isSet(self.thresholds):
ts = self.getOrDefault(self.thresholds)
if len(ts) != 2:
- raise ValueError("Logistic Regression getThreshold only applies to" +
- " binary classification, but thresholds has length != 2." +
- " thresholds: {0}".format(str(ts)))
- t = 1.0/(1.0 + ts[0]/ts[1])
+ raise ValueError(
+ "Logistic Regression getThreshold only applies to"
+ + " binary classification, but thresholds has length != 2."
+ + " thresholds: {0}".format(str(ts))
+ )
+ t = 1.0 / (1.0 + ts[0] / ts[1])
t2 = self.getOrDefault(self.threshold)
- if abs(t2 - t) >= 1E-5:
- raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
- " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
+ if abs(t2 - t) >= 1e-5:
+ raise ValueError(
+ "Logistic Regression getThreshold found inconsistent values for"
+ + " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)
+ )
@since("2.1.0")
def getFamily(self):
@@ -947,8 +1084,9 @@ class _LogisticRegressionParams(_ProbabilisticClassifierParams, HasRegParam,
@inherit_doc
-class LogisticRegression(_JavaProbabilisticClassifier, _LogisticRegressionParams, JavaMLWritable,
- JavaMLReadable):
+class LogisticRegression(
+ _JavaProbabilisticClassifier, _LogisticRegressionParams, JavaMLWritable, JavaMLReadable
+):
"""
Logistic regression.
This class supports multinomial logistic (softmax) and binomial logistic regression.
@@ -1043,14 +1181,31 @@ class LogisticRegression(_JavaProbabilisticClassifier, _LogisticRegressionParams
"""
@keyword_only
- def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- threshold=0.5, thresholds=None, probabilityCol="probability",
- rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
- aggregationDepth=2, family="auto",
- lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None,
- lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None,
- maxBlockSizeInMB=0.0):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ maxIter=100,
+ regParam=0.0,
+ elasticNetParam=0.0,
+ tol=1e-6,
+ fitIntercept=True,
+ threshold=0.5,
+ thresholds=None,
+ probabilityCol="probability",
+ rawPredictionCol="rawPrediction",
+ standardization=True,
+ weightCol=None,
+ aggregationDepth=2,
+ family="auto",
+ lowerBoundsOnCoefficients=None,
+ upperBoundsOnCoefficients=None,
+ lowerBoundsOnIntercepts=None,
+ upperBoundsOnIntercepts=None,
+ maxBlockSizeInMB=0.0,
+ ):
"""
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
@@ -1065,21 +1220,39 @@ class LogisticRegression(_JavaProbabilisticClassifier, _LogisticRegressionParams
"""
super(LogisticRegression, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.classification.LogisticRegression", self.uid)
+ "org.apache.spark.ml.classification.LogisticRegression", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
self._checkThresholdConsistency()
@keyword_only
@since("1.3.0")
- def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- threshold=0.5, thresholds=None, probabilityCol="probability",
- rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
- aggregationDepth=2, family="auto",
- lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None,
- lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None,
- maxBlockSizeInMB=0.0):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ maxIter=100,
+ regParam=0.0,
+ elasticNetParam=0.0,
+ tol=1e-6,
+ fitIntercept=True,
+ threshold=0.5,
+ thresholds=None,
+ probabilityCol="probability",
+ rawPredictionCol="rawPrediction",
+ standardization=True,
+ weightCol=None,
+ aggregationDepth=2,
+ family="auto",
+ lowerBoundsOnCoefficients=None,
+ upperBoundsOnCoefficients=None,
+ lowerBoundsOnIntercepts=None,
+ upperBoundsOnIntercepts=None,
+ maxBlockSizeInMB=0.0,
+ ):
"""
setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
@@ -1191,8 +1364,13 @@ class LogisticRegression(_JavaProbabilisticClassifier, _LogisticRegressionParams
return self._set(maxBlockSizeInMB=value)
-class LogisticRegressionModel(_JavaProbabilisticClassificationModel, _LogisticRegressionParams,
- JavaMLWritable, JavaMLReadable, HasTrainingSummary):
+class LogisticRegressionModel(
+ _JavaProbabilisticClassificationModel,
+ _LogisticRegressionParams,
+ JavaMLWritable,
+ JavaMLReadable,
+ HasTrainingSummary,
+):
"""
Model fitted by LogisticRegression.
@@ -1242,14 +1420,17 @@ class LogisticRegressionModel(_JavaProbabilisticClassificationModel, _LogisticRe
"""
if self.hasSummary:
if self.numClasses <= 2:
- return BinaryLogisticRegressionTrainingSummary(super(LogisticRegressionModel,
- self).summary)
+ return BinaryLogisticRegressionTrainingSummary(
+ super(LogisticRegressionModel, self).summary
+ )
else:
- return LogisticRegressionTrainingSummary(super(LogisticRegressionModel,
- self).summary)
+ return LogisticRegressionTrainingSummary(
+ super(LogisticRegressionModel, self).summary
+ )
else:
- raise RuntimeError("No training summary available for this %s" %
- self.__class__.__name__)
+ raise RuntimeError(
+ "No training summary available for this %s" % self.__class__.__name__
+ )
def evaluate(self, dataset):
"""
@@ -1304,28 +1485,31 @@ class LogisticRegressionTrainingSummary(LogisticRegressionSummary, _TrainingSumm
.. versionadded:: 2.0.0
"""
+
pass
@inherit_doc
-class BinaryLogisticRegressionSummary(_BinaryClassificationSummary,
- LogisticRegressionSummary):
+class BinaryLogisticRegressionSummary(_BinaryClassificationSummary, LogisticRegressionSummary):
"""
Binary Logistic regression results for a given model.
.. versionadded:: 2.0.0
"""
+
pass
@inherit_doc
-class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary,
- LogisticRegressionTrainingSummary):
+class BinaryLogisticRegressionTrainingSummary(
+ BinaryLogisticRegressionSummary, LogisticRegressionTrainingSummary
+):
"""
Binary Logistic regression training results for a given model.
.. versionadded:: 2.0.0
"""
+
pass
@@ -1337,14 +1521,24 @@ class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams):
def __init__(self, *args):
super(_DecisionTreeClassifierParams, self).__init__(*args)
- self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- impurity="gini", leafCol="", minWeightFractionPerNode=0.0)
+ self._setDefault(
+ maxDepth=5,
+ maxBins=32,
+ minInstancesPerNode=1,
+ minInfoGain=0.0,
+ maxMemoryInMB=256,
+ cacheNodeIds=False,
+ checkpointInterval=10,
+ impurity="gini",
+ leafCol="",
+ minWeightFractionPerNode=0.0,
+ )
@inherit_doc
-class DecisionTreeClassifier(_JavaProbabilisticClassifier, _DecisionTreeClassifierParams,
- JavaMLWritable, JavaMLReadable):
+class DecisionTreeClassifier(
+ _JavaProbabilisticClassifier, _DecisionTreeClassifierParams, JavaMLWritable, JavaMLReadable
+):
"""
`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
learning algorithm for classification.
@@ -1426,11 +1620,27 @@ class DecisionTreeClassifier(_JavaProbabilisticClassifier, _DecisionTreeClassifi
"""
@keyword_only
- def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- probabilityCol="probability", rawPredictionCol="rawPrediction",
- maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
- seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ probabilityCol="probability",
+ rawPredictionCol="rawPrediction",
+ maxDepth=5,
+ maxBins=32,
+ minInstancesPerNode=1,
+ minInfoGain=0.0,
+ maxMemoryInMB=256,
+ cacheNodeIds=False,
+ checkpointInterval=10,
+ impurity="gini",
+ seed=None,
+ weightCol=None,
+ leafCol="",
+ minWeightFractionPerNode=0.0,
+ ):
"""
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
@@ -1440,18 +1650,34 @@ class DecisionTreeClassifier(_JavaProbabilisticClassifier, _DecisionTreeClassifi
"""
super(DecisionTreeClassifier, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
+ "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.4.0")
- def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- probabilityCol="probability", rawPredictionCol="rawPrediction",
- maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- impurity="gini", seed=None, weightCol=None, leafCol="",
- minWeightFractionPerNode=0.0):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ probabilityCol="probability",
+ rawPredictionCol="rawPrediction",
+ maxDepth=5,
+ maxBins=32,
+ minInstancesPerNode=1,
+ minInfoGain=0.0,
+ maxMemoryInMB=256,
+ cacheNodeIds=False,
+ checkpointInterval=10,
+ impurity="gini",
+ seed=None,
+ weightCol=None,
+ leafCol="",
+ minWeightFractionPerNode=0.0,
+ ):
"""
setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
@@ -1538,9 +1764,13 @@ class DecisionTreeClassifier(_JavaProbabilisticClassifier, _DecisionTreeClassifi
@inherit_doc
-class DecisionTreeClassificationModel(_DecisionTreeModel, _JavaProbabilisticClassificationModel,
- _DecisionTreeClassifierParams, JavaMLWritable,
- JavaMLReadable):
+class DecisionTreeClassificationModel(
+ _DecisionTreeModel,
+ _JavaProbabilisticClassificationModel,
+ _DecisionTreeClassifierParams,
+ JavaMLWritable,
+ JavaMLReadable,
+):
"""
Model fitted by DecisionTreeClassifier.
@@ -1580,16 +1810,28 @@ class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams):
def __init__(self, *args):
super(_RandomForestClassifierParams, self).__init__(*args)
- self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- impurity="gini", numTrees=20, featureSubsetStrategy="auto",
- subsamplingRate=1.0, leafCol="", minWeightFractionPerNode=0.0,
- bootstrap=True)
+ self._setDefault(
+ maxDepth=5,
+ maxBins=32,
+ minInstancesPerNode=1,
+ minInfoGain=0.0,
+ maxMemoryInMB=256,
+ cacheNodeIds=False,
+ checkpointInterval=10,
+ impurity="gini",
+ numTrees=20,
+ featureSubsetStrategy="auto",
+ subsamplingRate=1.0,
+ leafCol="",
+ minWeightFractionPerNode=0.0,
+ bootstrap=True,
+ )
@inherit_doc
-class RandomForestClassifier(_JavaProbabilisticClassifier, _RandomForestClassifierParams,
- JavaMLWritable, JavaMLReadable):
+class RandomForestClassifier(
+ _JavaProbabilisticClassifier, _RandomForestClassifierParams, JavaMLWritable, JavaMLReadable
+):
"""
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
learning algorithm for classification.
@@ -1665,12 +1907,31 @@ class RandomForestClassifier(_JavaProbabilisticClassifier, _RandomForestClassifi
"""
@keyword_only
- def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- probabilityCol="probability", rawPredictionCol="rawPrediction",
- maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
- numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0,
- leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ probabilityCol="probability",
+ rawPredictionCol="rawPrediction",
+ maxDepth=5,
+ maxBins=32,
+ minInstancesPerNode=1,
+ minInfoGain=0.0,
+ maxMemoryInMB=256,
+ cacheNodeIds=False,
+ checkpointInterval=10,
+ impurity="gini",
+ numTrees=20,
+ featureSubsetStrategy="auto",
+ seed=None,
+ subsamplingRate=1.0,
+ leafCol="",
+ minWeightFractionPerNode=0.0,
+ weightCol=None,
+ bootstrap=True,
+ ):
"""
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
@@ -1681,18 +1942,38 @@ class RandomForestClassifier(_JavaProbabilisticClassifier, _RandomForestClassifi
"""
super(RandomForestClassifier, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.classification.RandomForestClassifier", self.uid)
+ "org.apache.spark.ml.classification.RandomForestClassifier", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.4.0")
- def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- probabilityCol="probability", rawPredictionCol="rawPrediction",
- maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
- impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0,
- leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ probabilityCol="probability",
+ rawPredictionCol="rawPrediction",
+ maxDepth=5,
+ maxBins=32,
+ minInstancesPerNode=1,
+ minInfoGain=0.0,
+ maxMemoryInMB=256,
+ cacheNodeIds=False,
+ checkpointInterval=10,
+ seed=None,
+ impurity="gini",
+ numTrees=20,
+ featureSubsetStrategy="auto",
+ subsamplingRate=1.0,
+ leafCol="",
+ minWeightFractionPerNode=0.0,
+ weightCol=None,
+ bootstrap=True,
+ ):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
@@ -1806,9 +2087,14 @@ class RandomForestClassifier(_JavaProbabilisticClassifier, _RandomForestClassifi
return self._set(minWeightFractionPerNode=value)
-class RandomForestClassificationModel(_TreeEnsembleModel, _JavaProbabilisticClassificationModel,
- _RandomForestClassifierParams, JavaMLWritable,
- JavaMLReadable, HasTrainingSummary):
+class RandomForestClassificationModel(
+ _TreeEnsembleModel,
+ _JavaProbabilisticClassificationModel,
+ _RandomForestClassifierParams,
+ JavaMLWritable,
+ JavaMLReadable,
+ HasTrainingSummary,
+):
"""
Model fitted by RandomForestClassifier.
@@ -1849,13 +2135,16 @@ class RandomForestClassificationModel(_TreeEnsembleModel, _JavaProbabilisticClas
if self.hasSummary:
if self.numClasses <= 2:
return BinaryRandomForestClassificationTrainingSummary(
- super(RandomForestClassificationModel, self).summary)
+ super(RandomForestClassificationModel, self).summary
+ )
else:
return RandomForestClassificationTrainingSummary(
- super(RandomForestClassificationModel, self).summary)
+ super(RandomForestClassificationModel, self).summary
+ )
else:
- raise RuntimeError("No training summary available for this %s" %
- self.__class__.__name__)
+ raise RuntimeError(
+ "No training summary available for this %s" % self.__class__.__name__
+ )
def evaluate(self, dataset):
"""
@@ -1883,17 +2172,20 @@ class RandomForestClassificationSummary(_ClassificationSummary):
.. versionadded:: 3.1.0
"""
+
pass
@inherit_doc
-class RandomForestClassificationTrainingSummary(RandomForestClassificationSummary,
- _TrainingSummary):
+class RandomForestClassificationTrainingSummary(
+ RandomForestClassificationSummary, _TrainingSummary
+):
"""
Abstraction for RandomForestClassificationTraining Training results.
.. versionadded:: 3.1.0
"""
+
pass
@@ -1904,17 +2196,20 @@ class BinaryRandomForestClassificationSummary(_BinaryClassificationSummary):
.. versionadded:: 3.1.0
"""
+
pass
@inherit_doc
-class BinaryRandomForestClassificationTrainingSummary(BinaryRandomForestClassificationSummary,
- RandomForestClassificationTrainingSummary):
+class BinaryRandomForestClassificationTrainingSummary(
+ BinaryRandomForestClassificationSummary, RandomForestClassificationTrainingSummary
+):
"""
BinaryRandomForestClassification training results for a given model.
.. versionadded:: 3.1.0
"""
+
pass
@@ -1927,18 +2222,35 @@ class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
supportedLossTypes = ["logistic"]
- lossType = Param(Params._dummy(), "lossType",
- "Loss function which GBT tries to minimize (case-insensitive). " +
- "Supported options: " + ", ".join(supportedLossTypes),
- typeConverter=TypeConverters.toString)
+ lossType = Param(
+ Params._dummy(),
+ "lossType",
+ "Loss function which GBT tries to minimize (case-insensitive). "
+ + "Supported options: "
+ + ", ".join(supportedLossTypes),
+ typeConverter=TypeConverters.toString,
+ )
def __init__(self, *args):
super(_GBTClassifierParams, self).__init__(*args)
- self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0,
- impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
- leafCol="", minWeightFractionPerNode=0.0)
+ self._setDefault(
+ maxDepth=5,
+ maxBins=32,
+ minInstancesPerNode=1,
+ minInfoGain=0.0,
+ maxMemoryInMB=256,
+ cacheNodeIds=False,
+ checkpointInterval=10,
+ lossType="logistic",
+ maxIter=20,
+ stepSize=0.1,
+ subsamplingRate=1.0,
+ impurity="variance",
+ featureSubsetStrategy="all",
+ validationTol=0.01,
+ leafCol="",
+ minWeightFractionPerNode=0.0,
+ )
@since("1.4.0")
def getLossType(self):
@@ -1949,8 +2261,9 @@ class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
@inherit_doc
-class GBTClassifier(_JavaProbabilisticClassifier, _GBTClassifierParams,
- JavaMLWritable, JavaMLReadable):
+class GBTClassifier(
+ _JavaProbabilisticClassifier, _GBTClassifierParams, JavaMLWritable, JavaMLReadable
+):
"""
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
learning algorithm for classification.
@@ -2056,12 +2369,32 @@ class GBTClassifier(_JavaProbabilisticClassifier, _GBTClassifierParams,
"""
@keyword_only
- def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
- maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
- featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None,
- leafCol="", minWeightFractionPerNode=0.0, weightCol=None):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ maxDepth=5,
+ maxBins=32,
+ minInstancesPerNode=1,
+ minInfoGain=0.0,
+ maxMemoryInMB=256,
+ cacheNodeIds=False,
+ checkpointInterval=10,
+ lossType="logistic",
+ maxIter=20,
+ stepSize=0.1,
+ seed=None,
+ subsamplingRate=1.0,
+ impurity="variance",
+ featureSubsetStrategy="all",
+ validationTol=0.01,
+ validationIndicatorCol=None,
+ leafCol="",
+ minWeightFractionPerNode=0.0,
+ weightCol=None,
+ ):
"""
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
@@ -2073,19 +2406,39 @@ class GBTClassifier(_JavaProbabilisticClassifier, _GBTClassifierParams,
"""
super(GBTClassifier, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.classification.GBTClassifier", self.uid)
+ "org.apache.spark.ml.classification.GBTClassifier", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.4.0")
- def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
- impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
- validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0,
- weightCol=None):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ maxDepth=5,
+ maxBins=32,
+ minInstancesPerNode=1,
+ minInfoGain=0.0,
+ maxMemoryInMB=256,
+ cacheNodeIds=False,
+ checkpointInterval=10,
+ lossType="logistic",
+ maxIter=20,
+ stepSize=0.1,
+ seed=None,
+ subsamplingRate=1.0,
+ impurity="variance",
+ featureSubsetStrategy="all",
+ validationTol=0.01,
+ validationIndicatorCol=None,
+ leafCol="",
+ minWeightFractionPerNode=0.0,
+ weightCol=None,
+ ):
"""
setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
@@ -2216,8 +2569,13 @@ class GBTClassifier(_JavaProbabilisticClassifier, _GBTClassifierParams,
return self._set(minWeightFractionPerNode=value)
-class GBTClassificationModel(_TreeEnsembleModel, _JavaProbabilisticClassificationModel,
- _GBTClassifierParams, JavaMLWritable, JavaMLReadable):
+class GBTClassificationModel(
+ _TreeEnsembleModel,
+ _JavaProbabilisticClassificationModel,
+ _GBTClassifierParams,
+ JavaMLWritable,
+ JavaMLReadable,
+):
"""
Model fitted by GBTClassifier.
@@ -2269,12 +2627,20 @@ class _NaiveBayesParams(_PredictorParams, HasWeightCol):
.. versionadded:: 3.0.0
"""
- smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " +
- "default is 1.0", typeConverter=TypeConverters.toFloat)
- modelType = Param(Params._dummy(), "modelType", "The model type which is a string " +
- "(case-sensitive). Supported options: multinomial (default), bernoulli " +
- "and gaussian.",
- typeConverter=TypeConverters.toString)
+ smoothing = Param(
+ Params._dummy(),
+ "smoothing",
+ "The smoothing parameter, should be >= 0, " + "default is 1.0",
+ typeConverter=TypeConverters.toFloat,
+ )
+ modelType = Param(
+ Params._dummy(),
+ "modelType",
+ "The model type which is a string "
+ + "(case-sensitive). Supported options: multinomial (default), bernoulli "
+ + "and gaussian.",
+ typeConverter=TypeConverters.toString,
+ )
def __init__(self, *args):
super(_NaiveBayesParams, self).__init__(*args)
@@ -2296,8 +2662,14 @@ class _NaiveBayesParams(_PredictorParams, HasWeightCol):
@inherit_doc
-class NaiveBayes(_JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds, HasWeightCol,
- JavaMLWritable, JavaMLReadable):
+class NaiveBayes(
+ _JavaProbabilisticClassifier,
+ _NaiveBayesParams,
+ HasThresholds,
+ HasWeightCol,
+ JavaMLWritable,
+ JavaMLReadable,
+):
"""
Naive Bayes Classifiers.
It supports both Multinomial and Bernoulli NB. `Multinomial NB \
@@ -2392,9 +2764,19 @@ class NaiveBayes(_JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
"""
@keyword_only
- def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
- modelType="multinomial", thresholds=None, weightCol=None):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ probabilityCol="probability",
+ rawPredictionCol="rawPrediction",
+ smoothing=1.0,
+ modelType="multinomial",
+ thresholds=None,
+ weightCol=None,
+ ):
"""
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
@@ -2402,15 +2784,26 @@ class NaiveBayes(_JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
"""
super(NaiveBayes, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.classification.NaiveBayes", self.uid)
+ "org.apache.spark.ml.classification.NaiveBayes", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.5.0")
- def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
- modelType="multinomial", thresholds=None, weightCol=None):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ probabilityCol="probability",
+ rawPredictionCol="rawPrediction",
+ smoothing=1.0,
+ modelType="multinomial",
+ thresholds=None,
+ weightCol=None,
+ ):
"""
setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
@@ -2444,8 +2837,9 @@ class NaiveBayes(_JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
return self._set(weightCol=value)
-class NaiveBayesModel(_JavaProbabilisticClassificationModel, _NaiveBayesParams, JavaMLWritable,
- JavaMLReadable):
+class NaiveBayesModel(
+ _JavaProbabilisticClassificationModel, _NaiveBayesParams, JavaMLWritable, JavaMLReadable
+):
"""
Model fitted by NaiveBayes.
@@ -2477,26 +2871,45 @@ class NaiveBayesModel(_JavaProbabilisticClassificationModel, _NaiveBayesParams,
return self._call_java("sigma")
-class _MultilayerPerceptronParams(_ProbabilisticClassifierParams, HasSeed, HasMaxIter,
- HasTol, HasStepSize, HasSolver, HasBlockSize):
+class _MultilayerPerceptronParams(
+ _ProbabilisticClassifierParams,
+ HasSeed,
+ HasMaxIter,
+ HasTol,
+ HasStepSize,
+ HasSolver,
+ HasBlockSize,
+):
"""
Params for :py:class:`MultilayerPerceptronClassifier`.
.. versionadded:: 3.0.0
"""
- layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " +
- "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " +
- "neurons and output layer of 10 neurons.",
- typeConverter=TypeConverters.toListInt)
- solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
- "options: l-bfgs, gd.", typeConverter=TypeConverters.toString)
- initialWeights = Param(Params._dummy(), "initialWeights", "The initial weights of the model.",
- typeConverter=TypeConverters.toVector)
+ layers = Param(
+ Params._dummy(),
+ "layers",
+ "Sizes of layers from input layer to output layer "
+ + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 "
+ + "neurons and output layer of 10 neurons.",
+ typeConverter=TypeConverters.toListInt,
+ )
+ solver = Param(
+ Params._dummy(),
+ "solver",
+ "The solver algorithm for optimization. Supported " + "options: l-bfgs, gd.",
+ typeConverter=TypeConverters.toString,
+ )
+ initialWeights = Param(
+ Params._dummy(),
+ "initialWeights",
+ "The initial weights of the model.",
+ typeConverter=TypeConverters.toVector,
+ )
def __init__(self, *args):
super(_MultilayerPerceptronParams, self).__init__(*args)
- self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
+ self._setDefault(maxIter=100, tol=1e-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
@since("1.6.0")
def getLayers(self):
@@ -2514,8 +2927,9 @@ class _MultilayerPerceptronParams(_ProbabilisticClassifierParams, HasSeed, HasMa
@inherit_doc
-class MultilayerPerceptronClassifier(_JavaProbabilisticClassifier, _MultilayerPerceptronParams,
- JavaMLWritable, JavaMLReadable):
+class MultilayerPerceptronClassifier(
+ _JavaProbabilisticClassifier, _MultilayerPerceptronParams, JavaMLWritable, JavaMLReadable
+):
"""
Classifier trainer based on the Multilayer Perceptron.
Each layer has sigmoid activation function, output layer has softmax.
@@ -2592,10 +3006,23 @@ class MultilayerPerceptronClassifier(_JavaProbabilisticClassifier, _MultilayerPe
"""
@keyword_only
- def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
- solver="l-bfgs", initialWeights=None, probabilityCol="probability",
- rawPredictionCol="rawPrediction"):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ maxIter=100,
+ tol=1e-6,
+ seed=None,
+ layers=None,
+ blockSize=128,
+ stepSize=0.03,
+ solver="l-bfgs",
+ initialWeights=None,
+ probabilityCol="probability",
+ rawPredictionCol="rawPrediction",
+ ):
"""
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
@@ -2604,16 +3031,30 @@ class MultilayerPerceptronClassifier(_JavaProbabilisticClassifier, _MultilayerPe
"""
super(MultilayerPerceptronClassifier, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
+ "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.6.0")
- def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
- solver="l-bfgs", initialWeights=None, probabilityCol="probability",
- rawPredictionCol="rawPrediction"):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ maxIter=100,
+ tol=1e-6,
+ seed=None,
+ layers=None,
+ blockSize=128,
+ stepSize=0.03,
+ solver="l-bfgs",
+ initialWeights=None,
+ probabilityCol="probability",
+ rawPredictionCol="rawPrediction",
+ ):
"""
setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
@@ -2680,9 +3121,13 @@ class MultilayerPerceptronClassifier(_JavaProbabilisticClassifier, _MultilayerPe
return self._set(solver=value)
-class MultilayerPerceptronClassificationModel(_JavaProbabilisticClassificationModel,
- _MultilayerPerceptronParams, JavaMLWritable,
- JavaMLReadable, HasTrainingSummary):
+class MultilayerPerceptronClassificationModel(
+ _JavaProbabilisticClassificationModel,
+ _MultilayerPerceptronParams,
+ JavaMLWritable,
+ JavaMLReadable,
+ HasTrainingSummary,
+):
"""
Model fitted by MultilayerPerceptronClassifier.
@@ -2705,10 +3150,12 @@ class MultilayerPerceptronClassificationModel(_JavaProbabilisticClassificationMo
"""
if self.hasSummary:
return MultilayerPerceptronClassificationTrainingSummary(
- super(MultilayerPerceptronClassificationModel, self).summary)
+ super(MultilayerPerceptronClassificationModel, self).summary
+ )
else:
- raise RuntimeError("No training summary available for this %s" %
- self.__class__.__name__)
+ raise RuntimeError(
+ "No training summary available for this %s" % self.__class__.__name__
+ )
def evaluate(self, dataset):
"""
@@ -2733,17 +3180,20 @@ class MultilayerPerceptronClassificationSummary(_ClassificationSummary):
.. versionadded:: 3.1.0
"""
+
pass
@inherit_doc
-class MultilayerPerceptronClassificationTrainingSummary(MultilayerPerceptronClassificationSummary,
- _TrainingSummary):
+class MultilayerPerceptronClassificationTrainingSummary(
+ MultilayerPerceptronClassificationSummary, _TrainingSummary
+):
"""
Abstraction for MultilayerPerceptronClassifier Training results.
.. versionadded:: 3.1.0
"""
+
pass
@@ -2815,8 +3265,17 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
"""
@keyword_only
- def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ rawPredictionCol="rawPrediction",
+ classifier=None,
+ weightCol=None,
+ parallelism=1,
+ ):
"""
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
@@ -2828,8 +3287,17 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
@keyword_only
@since("2.0.0")
- def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ rawPredictionCol="rawPrediction",
+ classifier=None,
+ weightCol=None,
+ parallelism=1,
+ ):
"""
setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
@@ -2887,15 +3355,16 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
predictionCol = self.getPredictionCol()
classifier = self.getClassifier()
- numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
+ numClasses = int(dataset.agg({labelCol: "max"}).head()["max(" + labelCol + ")"]) + 1
weightCol = None
- if (self.isDefined(self.weightCol) and self.getWeightCol()):
+ if self.isDefined(self.weightCol) and self.getWeightCol():
if isinstance(classifier, HasWeightCol):
weightCol = self.getWeightCol()
else:
- warnings.warn("weightCol is ignored, "
- "as it is not supported by {} now.".format(classifier))
+ warnings.warn(
+ "weightCol is ignored, " "as it is not supported by {} now.".format(classifier)
+ )
if weightCol:
multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
@@ -2911,10 +3380,15 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
binaryLabelCol = "mc2b$" + str(index)
trainingDataset = multiclassLabeled.withColumn(
binaryLabelCol,
- when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0))
- paramMap = dict([(classifier.labelCol, binaryLabelCol),
- (classifier.featuresCol, featuresCol),
- (classifier.predictionCol, predictionCol)])
+ when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
+ )
+ paramMap = dict(
+ [
+ (classifier.labelCol, binaryLabelCol),
+ (classifier.featuresCol, featuresCol),
+ (classifier.predictionCol, predictionCol),
+ ]
+ )
if weightCol:
paramMap[classifier.weightCol] = weightCol
return classifier.fit(trainingDataset, paramMap)
@@ -2965,9 +3439,14 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
rawPredictionCol = java_stage.getRawPredictionCol()
classifier = JavaParams._from_java(java_stage.getClassifier())
parallelism = java_stage.getParallelism()
- py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol,
- rawPredictionCol=rawPredictionCol, classifier=classifier,
- parallelism=parallelism)
+ py_stage = cls(
+ featuresCol=featuresCol,
+ labelCol=labelCol,
+ predictionCol=predictionCol,
+ rawPredictionCol=rawPredictionCol,
+ classifier=classifier,
+ parallelism=parallelism,
+ )
if java_stage.isDefined(java_stage.getParam("weightCol")):
py_stage.setWeightCol(java_stage.getWeightCol())
py_stage._resetUid(java_stage.uid())
@@ -2982,14 +3461,15 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
py4j.java_gateway.JavaObject
Java object equivalent to this instance.
"""
- _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
- self.uid)
+ _java_obj = JavaParams._new_java_obj(
+ "org.apache.spark.ml.classification.OneVsRest", self.uid
+ )
_java_obj.setClassifier(self.getClassifier()._to_java())
_java_obj.setParallelism(self.getParallelism())
_java_obj.setFeaturesCol(self.getFeaturesCol())
_java_obj.setLabelCol(self.getLabelCol())
_java_obj.setPredictionCol(self.getPredictionCol())
- if (self.isDefined(self.weightCol) and self.getWeightCol()):
+ if self.isDefined(self.weightCol) and self.getWeightCol():
_java_obj.setWeightCol(self.getWeightCol())
_java_obj.setRawPredictionCol(self.getRawPredictionCol())
return _java_obj
@@ -3008,16 +3488,17 @@ class OneVsRest(Estimator, _OneVsRestParams, HasParallelism, MLReadable, MLWrita
class _OneVsRestSharedReadWrite:
@staticmethod
def saveImpl(instance, sc, path, extraMetadata=None):
- skipParams = ['classifier']
+ skipParams = ["classifier"]
jsonParams = DefaultParamsWriter.extractJsonParams(instance, skipParams)
- DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams,
- extraMetadata=extraMetadata)
- classifierPath = os.path.join(path, 'classifier')
+ DefaultParamsWriter.saveMetadata(
+ instance, path, sc, paramMap=jsonParams, extraMetadata=extraMetadata
+ )
+ classifierPath = os.path.join(path, "classifier")
instance.getClassifier().save(classifierPath)
@staticmethod
def loadClassifier(path, sc):
- classifierPath = os.path.join(path, 'classifier')
+ classifierPath = os.path.join(path, "classifier")
return DefaultParamsReader.loadParamsInstance(classifierPath, sc)
@staticmethod
@@ -3028,8 +3509,10 @@ class _OneVsRestSharedReadWrite:
for elem in elems_to_check:
if not isinstance(elem, MLWritable):
- raise ValueError(f'OneVsRest write will fail because it contains {elem.uid} '
- f'which is not writable.')
+ raise ValueError(
+ f"OneVsRest write will fail because it contains {elem.uid} "
+ f"which is not writable."
+ )
@inherit_doc
@@ -3044,8 +3527,8 @@ class OneVsRestReader(MLReader):
return JavaMLReader(self.cls).load(path)
else:
classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
- ova = OneVsRest(classifier=classifier)._resetUid(metadata['uid'])
- DefaultParamsReader.getAndSetParams(ova, metadata, skipParams=['classifier'])
+ ova = OneVsRest(classifier=classifier)._resetUid(metadata["uid"])
+ DefaultParamsReader.getAndSetParams(ova, metadata, skipParams=["classifier"])
return ova
@@ -3096,14 +3579,17 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
# set java instance
java_models = [model._to_java() for model in self.models]
sc = SparkContext._active_spark_context
- java_models_array = JavaWrapper._new_java_array(java_models,
- sc._gateway.jvm.org.apache.spark.ml
- .classification.ClassificationModel)
+ java_models_array = JavaWrapper._new_java_array(
+ java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel
+ )
# TODO: need to set metadata
metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
- self._java_obj = \
- JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
- self.uid, metadata.empty(), java_models_array)
+ self._java_obj = JavaParams._new_java_obj(
+ "org.apache.spark.ml.classification.OneVsRestModel",
+ self.uid,
+ metadata.empty(),
+ java_models_array,
+ )
def _transform(self, dataset):
# determine the input columns: these need to be passed through
@@ -3130,21 +3616,25 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
tmpColName = "mbc$tmp" + str(uuid.uuid4())
updateUDF = udf(
lambda predictions, prediction: predictions + [prediction.tolist()[1]],
- ArrayType(DoubleType()))
+ ArrayType(DoubleType()),
+ )
transformedDataset = model.transform(aggregatedDataset).select(*columns)
updatedDataset = transformedDataset.withColumn(
tmpColName,
- updateUDF(transformedDataset[accColName], transformedDataset[rawPredictionCol]))
+ updateUDF(transformedDataset[accColName], transformedDataset[rawPredictionCol]),
+ )
newColumns = origCols + [tmpColName]
# switch out the intermediate column with the accumulator column
- aggregatedDataset = updatedDataset\
- .select(*newColumns).withColumnRenamed(tmpColName, accColName)
+ aggregatedDataset = updatedDataset.select(*newColumns).withColumnRenamed(
+ tmpColName, accColName
+ )
if handlePersistence:
newDataset.unpersist()
if self.getRawPredictionCol():
+
def func(predictions):
predArray = []
for x in predictions:
@@ -3153,14 +3643,20 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
rawPredictionUDF = udf(func, VectorUDT())
aggregatedDataset = aggregatedDataset.withColumn(
- self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName]))
+ self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName])
+ )
if self.getPredictionCol():
# output the index of the classifier with highest confidence as prediction
- labelUDF = udf(lambda predictions: float(max(enumerate(predictions),
- key=operator.itemgetter(1))[0]), DoubleType())
+ labelUDF = udf(
+ lambda predictions: float(
+ max(enumerate(predictions), key=operator.itemgetter(1))[0]
+ ),
+ DoubleType(),
+ )
aggregatedDataset = aggregatedDataset.withColumn(
- self.getPredictionCol(), labelUDF(aggregatedDataset[accColName]))
+ self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])
+ )
return aggregatedDataset.drop(accColName)
def copy(self, extra=None):
@@ -3198,8 +3694,7 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
predictionCol = java_stage.getPredictionCol()
classifier = JavaParams._from_java(java_stage.getClassifier())
models = [JavaParams._from_java(model) for model in java_stage.models()]
- py_stage = cls(models=models).setPredictionCol(predictionCol)\
- .setFeaturesCol(featuresCol)
+ py_stage = cls(models=models).setPredictionCol(predictionCol).setFeaturesCol(featuresCol)
py_stage._set(labelCol=labelCol)
if java_stage.isDefined(java_stage.getParam("weightCol")):
py_stage._set(weightCol=java_stage.getWeightCol())
@@ -3219,15 +3714,20 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
sc = SparkContext._active_spark_context
java_models = [model._to_java() for model in self.models]
java_models_array = JavaWrapper._new_java_array(
- java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel)
+ java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel
+ )
metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
- _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
- self.uid, metadata.empty(), java_models_array)
+ _java_obj = JavaParams._new_java_obj(
+ "org.apache.spark.ml.classification.OneVsRestModel",
+ self.uid,
+ metadata.empty(),
+ java_models_array,
+ )
_java_obj.set("classifier", self.getClassifier()._to_java())
_java_obj.set("featuresCol", self.getFeaturesCol())
_java_obj.set("labelCol", self.getLabelCol())
_java_obj.set("predictionCol", self.getPredictionCol())
- if (self.isDefined(self.weightCol) and self.getWeightCol()):
+ if self.isDefined(self.weightCol) and self.getWeightCol():
_java_obj.set("weightCol", self.getWeightCol())
return _java_obj
@@ -3236,8 +3736,9 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
return OneVsRestModelReader(cls)
def write(self):
- if all(map(lambda elem: isinstance(elem, JavaMLWritable),
- [self.getClassifier()] + self.models)):
+ if all(
+ map(lambda elem: isinstance(elem, JavaMLWritable), [self.getClassifier()] + self.models)
+ ):
return JavaMLWriter(self)
else:
return OneVsRestModelWriter(self)
@@ -3255,14 +3756,14 @@ class OneVsRestModelReader(MLReader):
return JavaMLReader(self.cls).load(path)
else:
classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
- numClasses = metadata['numClasses']
+ numClasses = metadata["numClasses"]
subModels = [None] * numClasses
for idx in range(numClasses):
- subModelPath = os.path.join(path, f'model_{idx}')
+ subModelPath = os.path.join(path, f"model_{idx}")
subModels[idx] = DefaultParamsReader.loadParamsInstance(subModelPath, self.sc)
- ovaModel = OneVsRestModel(subModels)._resetUid(metadata['uid'])
+ ovaModel = OneVsRestModel(subModels)._resetUid(metadata["uid"])
ovaModel.set(ovaModel.classifier, classifier)
- DefaultParamsReader.getAndSetParams(ovaModel, metadata, skipParams=['classifier'])
+ DefaultParamsReader.getAndSetParams(ovaModel, metadata, skipParams=["classifier"])
return ovaModel
@@ -3276,16 +3777,17 @@ class OneVsRestModelWriter(MLWriter):
_OneVsRestSharedReadWrite.validateParams(self.instance)
instance = self.instance
numClasses = len(instance.models)
- extraMetadata = {'numClasses': numClasses}
+ extraMetadata = {"numClasses": numClasses}
_OneVsRestSharedReadWrite.saveImpl(instance, self.sc, path, extraMetadata=extraMetadata)
for idx in range(numClasses):
- subModelPath = os.path.join(path, f'model_{idx}')
+ subModelPath = os.path.join(path, f"model_{idx}")
instance.models[idx].save(subModelPath)
@inherit_doc
-class FMClassifier(_JavaProbabilisticClassifier, _FactorizationMachinesParams, JavaMLWritable,
- JavaMLReadable):
+class FMClassifier(
+ _JavaProbabilisticClassifier, _FactorizationMachinesParams, JavaMLWritable, JavaMLReadable
+):
"""
Factorization Machines learning algorithm for classification.
@@ -3348,11 +3850,27 @@ class FMClassifier(_JavaProbabilisticClassifier, _FactorizationMachinesParams, J
"""
@keyword_only
- def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- probabilityCol="probability", rawPredictionCol="rawPrediction",
- factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
- miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
- tol=1e-6, solver="adamW", thresholds=None, seed=None):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ probabilityCol="probability",
+ rawPredictionCol="rawPrediction",
+ factorSize=8,
+ fitIntercept=True,
+ fitLinear=True,
+ regParam=0.0,
+ miniBatchFraction=1.0,
+ initStd=0.01,
+ maxIter=100,
+ stepSize=1.0,
+ tol=1e-6,
+ solver="adamW",
+ thresholds=None,
+ seed=None,
+ ):
"""
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
@@ -3362,17 +3880,34 @@ class FMClassifier(_JavaProbabilisticClassifier, _FactorizationMachinesParams, J
"""
super(FMClassifier, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.classification.FMClassifier", self.uid)
+ "org.apache.spark.ml.classification.FMClassifier", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("3.0.0")
- def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
- probabilityCol="probability", rawPredictionCol="rawPrediction",
- factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
- miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
- tol=1e-6, solver="adamW", thresholds=None, seed=None):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ labelCol="label",
+ predictionCol="prediction",
+ probabilityCol="probability",
+ rawPredictionCol="rawPrediction",
+ factorSize=8,
+ fitIntercept=True,
+ fitLinear=True,
+ regParam=0.0,
+ miniBatchFraction=1.0,
+ initStd=0.01,
+ maxIter=100,
+ stepSize=1.0,
+ tol=1e-6,
+ solver="adamW",
+ thresholds=None,
+ seed=None,
+ ):
"""
setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
@@ -3465,8 +4000,13 @@ class FMClassifier(_JavaProbabilisticClassifier, _FactorizationMachinesParams, J
return self._set(regParam=value)
-class FMClassificationModel(_JavaProbabilisticClassificationModel, _FactorizationMachinesParams,
- JavaMLWritable, JavaMLReadable, HasTrainingSummary):
+class FMClassificationModel(
+ _JavaProbabilisticClassificationModel,
+ _FactorizationMachinesParams,
+ JavaMLWritable,
+ JavaMLReadable,
+ HasTrainingSummary,
+):
"""
Model fitted by :class:`FMClassifier`.
@@ -3506,8 +4046,9 @@ class FMClassificationModel(_JavaProbabilisticClassificationModel, _Factorizatio
if self.hasSummary:
return FMClassificationTrainingSummary(super(FMClassificationModel, self).summary)
else:
- raise RuntimeError("No training summary available for this %s" %
- self.__class__.__name__)
+ raise RuntimeError(
+ "No training summary available for this %s" % self.__class__.__name__
+ )
def evaluate(self, dataset):
"""
@@ -3532,6 +4073,7 @@ class FMClassificationSummary(_BinaryClassificationSummary):
.. versionadded:: 3.1.0
"""
+
pass
@@ -3542,6 +4084,7 @@ class FMClassificationTrainingSummary(FMClassificationSummary, _TrainingSummary)
.. versionadded:: 3.1.0
"""
+
pass
@@ -3549,24 +4092,24 @@ if __name__ == "__main__":
import doctest
import pyspark.ml.classification
from pyspark.sql import SparkSession
+
globs = pyspark.ml.classification.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
- spark = SparkSession.builder\
- .master("local[2]")\
- .appName("ml.classification tests")\
- .getOrCreate()
+ spark = SparkSession.builder.master("local[2]").appName("ml.classification tests").getOrCreate()
sc = spark.sparkContext
- globs['sc'] = sc
- globs['spark'] = spark
+ globs["sc"] = sc
+ globs["spark"] = spark
import tempfile
+
temp_path = tempfile.mkdtemp()
- globs['temp_path'] = temp_path
+ globs["temp_path"] = temp_path
try:
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
finally:
from shutil import rmtree
+
try:
rmtree(temp_path)
except OSError:
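The classification.py hunks above (OneVsRest, FMClassifier, the doctest harness) only change layout, not behaviour. For context on the reformatted OneVsRest code path, a minimal usage sketch follows; the toy dataset, app name and parameter values are illustrative and not part of this patch, while "label"/"features" are the PySpark defaults.

```python
from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("ovr sketch").getOrCreate()

# Tiny three-class dataset; "label" and "features" are the default column names.
df = spark.createDataFrame(
    [
        (0.0, Vectors.dense(0.0, 1.0)),
        (1.0, Vectors.dense(1.0, 0.5)),
        (2.0, Vectors.dense(2.0, 3.0)),
    ],
    ["label", "features"],
)

# OneVsRest trains one binary LogisticRegression per class (the _fit path shown above).
ovr = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=0.01))
model = ovr.fit(df)
model.transform(df).select("label", "prediction").show()

spark.stop()
```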
diff --git a/python/pyspark/ml/classification.pyi b/python/pyspark/ml/classification.pyi
index a4a3d21..bb4fb05 100644
--- a/python/pyspark/ml/classification.pyi
+++ b/python/pyspark/ml/classification.pyi
@@ -53,8 +53,15 @@ from pyspark.ml.tree import (
_TreeClassifierParams,
_TreeEnsembleModel,
)
-from pyspark.ml.util import HasTrainingSummary, JavaMLReadable, JavaMLWritable, \
- MLReader, MLReadable, MLWriter, MLWritable
+from pyspark.ml.util import (
+ HasTrainingSummary,
+ JavaMLReadable,
+ JavaMLWritable,
+ MLReader,
+ MLReadable,
+ MLWriter,
+ MLWritable,
+)
from pyspark.ml.wrapper import JavaPredictionModel, JavaPredictor, JavaWrapper
from pyspark.ml.linalg import Matrix, Vector
@@ -75,13 +82,9 @@ class ClassificationModel(PredictionModel, _ClassifierParams, metaclass=abc.ABCM
@abstractmethod
def predictRaw(self, value: Vector) -> Vector: ...
-class _ProbabilisticClassifierParams(
- HasProbabilityCol, HasThresholds, _ClassifierParams
-): ...
+class _ProbabilisticClassifierParams(HasProbabilityCol, HasThresholds, _ClassifierParams): ...
-class ProbabilisticClassifier(
- Classifier, _ProbabilisticClassifierParams, metaclass=abc.ABCMeta
-):
+class ProbabilisticClassifier(Classifier, _ProbabilisticClassifierParams, metaclass=abc.ABCMeta):
def setProbabilityCol(self: P, value: str) -> P: ...
def setThresholds(self: P, value: List[float]) -> P: ...
@@ -200,7 +203,7 @@ class LinearSVC(
threshold: float = ...,
weightCol: Optional[str] = ...,
aggregationDepth: int = ...,
- maxBlockSizeInMB: float = ...
+ maxBlockSizeInMB: float = ...,
) -> None: ...
def setParams(
self,
@@ -217,7 +220,7 @@ class LinearSVC(
threshold: float = ...,
weightCol: Optional[str] = ...,
aggregationDepth: int = ...,
- maxBlockSizeInMB: float = ...
+ maxBlockSizeInMB: float = ...,
) -> LinearSVC: ...
def setMaxIter(self, value: int) -> LinearSVC: ...
def setRegParam(self, value: float) -> LinearSVC: ...
@@ -306,7 +309,7 @@ class LogisticRegression(
upperBoundsOnCoefficients: Optional[Matrix] = ...,
lowerBoundsOnIntercepts: Optional[Vector] = ...,
upperBoundsOnIntercepts: Optional[Vector] = ...,
- maxBlockSizeInMB: float = ...
+ maxBlockSizeInMB: float = ...,
) -> None: ...
def setParams(
self,
@@ -331,7 +334,7 @@ class LogisticRegression(
upperBoundsOnCoefficients: Optional[Matrix] = ...,
lowerBoundsOnIntercepts: Optional[Vector] = ...,
upperBoundsOnIntercepts: Optional[Vector] = ...,
- maxBlockSizeInMB: float = ...
+ maxBlockSizeInMB: float = ...,
) -> LogisticRegression: ...
def setFamily(self, value: str) -> LogisticRegression: ...
def setLowerBoundsOnCoefficients(self, value: Matrix) -> LogisticRegression: ...
@@ -373,12 +376,8 @@ class LogisticRegressionSummary(_ClassificationSummary):
@property
def featuresCol(self) -> str: ...
-class LogisticRegressionTrainingSummary(
- LogisticRegressionSummary, _TrainingSummary
-): ...
-class BinaryLogisticRegressionSummary(
- _BinaryClassificationSummary, LogisticRegressionSummary
-): ...
+class LogisticRegressionTrainingSummary(LogisticRegressionSummary, _TrainingSummary): ...
+class BinaryLogisticRegressionSummary(_BinaryClassificationSummary, LogisticRegressionSummary): ...
class BinaryLogisticRegressionTrainingSummary(
BinaryLogisticRegressionSummary, LogisticRegressionTrainingSummary
): ...
@@ -411,7 +410,7 @@ class DecisionTreeClassifier(
seed: Optional[int] = ...,
weightCol: Optional[str] = ...,
leafCol: str = ...,
- minWeightFractionPerNode: float = ...
+ minWeightFractionPerNode: float = ...,
) -> None: ...
def setParams(
self,
@@ -432,7 +431,7 @@ class DecisionTreeClassifier(
seed: Optional[int] = ...,
weightCol: Optional[str] = ...,
leafCol: str = ...,
- minWeightFractionPerNode: float = ...
+ minWeightFractionPerNode: float = ...,
) -> DecisionTreeClassifier: ...
def setMaxDepth(self, value: int) -> DecisionTreeClassifier: ...
def setMaxBins(self, value: int) -> DecisionTreeClassifier: ...
@@ -488,7 +487,7 @@ class RandomForestClassifier(
leafCol: str = ...,
minWeightFractionPerNode: float = ...,
weightCol: Optional[str] = ...,
- bootstrap: Optional[bool] = ...
+ bootstrap: Optional[bool] = ...,
) -> None: ...
def setParams(
self,
@@ -513,7 +512,7 @@ class RandomForestClassifier(
leafCol: str = ...,
minWeightFractionPerNode: float = ...,
weightCol: Optional[str] = ...,
- bootstrap: Optional[bool] = ...
+ bootstrap: Optional[bool] = ...,
) -> RandomForestClassifier: ...
def setMaxDepth(self, value: int) -> RandomForestClassifier: ...
def setMaxBins(self, value: int) -> RandomForestClassifier: ...
@@ -590,7 +589,7 @@ class GBTClassifier(
validationIndicatorCol: Optional[str] = ...,
leafCol: str = ...,
minWeightFractionPerNode: float = ...,
- weightCol: Optional[str] = ...
+ weightCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
@@ -615,7 +614,7 @@ class GBTClassifier(
validationIndicatorCol: Optional[str] = ...,
leafCol: str = ...,
minWeightFractionPerNode: float = ...,
- weightCol: Optional[str] = ...
+ weightCol: Optional[str] = ...,
) -> GBTClassifier: ...
def setMaxDepth(self, value: int) -> GBTClassifier: ...
def setMaxBins(self, value: int) -> GBTClassifier: ...
@@ -674,7 +673,7 @@ class NaiveBayes(
smoothing: float = ...,
modelType: str = ...,
thresholds: Optional[List[float]] = ...,
- weightCol: Optional[str] = ...
+ weightCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
@@ -687,7 +686,7 @@ class NaiveBayes(
smoothing: float = ...,
modelType: str = ...,
thresholds: Optional[List[float]] = ...,
- weightCol: Optional[str] = ...
+ weightCol: Optional[str] = ...,
) -> NaiveBayes: ...
def setSmoothing(self, value: float) -> NaiveBayes: ...
def setModelType(self, value: str) -> NaiveBayes: ...
@@ -743,7 +742,7 @@ class MultilayerPerceptronClassifier(
solver: str = ...,
initialWeights: Optional[Vector] = ...,
probabilityCol: str = ...,
- rawPredictionCol: str = ...
+ rawPredictionCol: str = ...,
) -> None: ...
def setParams(
self,
@@ -760,7 +759,7 @@ class MultilayerPerceptronClassifier(
solver: str = ...,
initialWeights: Optional[Vector] = ...,
probabilityCol: str = ...,
- rawPredictionCol: str = ...
+ rawPredictionCol: str = ...,
) -> MultilayerPerceptronClassifier: ...
def setLayers(self, value: List[int]) -> MultilayerPerceptronClassifier: ...
def setBlockSize(self, value: int) -> MultilayerPerceptronClassifier: ...
@@ -781,9 +780,7 @@ class MultilayerPerceptronClassificationModel(
@property
def weights(self) -> Vector: ...
def summary(self) -> MultilayerPerceptronClassificationTrainingSummary: ...
- def evaluate(
- self, dataset: DataFrame
- ) -> MultilayerPerceptronClassificationSummary: ...
+ def evaluate(self, dataset: DataFrame) -> MultilayerPerceptronClassificationSummary: ...
class MultilayerPerceptronClassificationSummary(_ClassificationSummary): ...
class MultilayerPerceptronClassificationTrainingSummary(
@@ -810,7 +807,7 @@ class OneVsRest(
rawPredictionCol: str = ...,
classifier: Optional[Estimator[M]] = ...,
weightCol: Optional[str] = ...,
- parallelism: int = ...
+ parallelism: int = ...,
) -> None: ...
def setParams(
self,
@@ -821,7 +818,7 @@ class OneVsRest(
rawPredictionCol: str = ...,
classifier: Optional[Estimator[M]] = ...,
weightCol: Optional[str] = ...,
- parallelism: int = ...
+ parallelism: int = ...,
) -> OneVsRest: ...
def setClassifier(self, value: Estimator[M]) -> OneVsRest: ...
def setLabelCol(self, value: str) -> OneVsRest: ...
@@ -832,9 +829,7 @@ class OneVsRest(
def setParallelism(self, value: int) -> OneVsRest: ...
def copy(self, extra: Optional[ParamMap] = ...) -> OneVsRest: ...
-class OneVsRestModel(
- Model, _OneVsRestParams, MLReadable[OneVsRestModel], MLWritable
-):
+class OneVsRestModel(Model, _OneVsRestParams, MLReadable[OneVsRestModel], MLWritable):
models: List[Transformer]
def __init__(self, models: List[Transformer]) -> None: ...
def setFeaturesCol(self, value: str) -> OneVsRestModel: ...
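The trailing commas added throughout the stub signatures above are black's "magic trailing comma": when a call or signature ends with a trailing comma, black keeps it exploded one argument per line; without one, it packs the arguments up to the configured line length. A small self-contained sketch of the two shapes (the helper function is hypothetical, not from the stubs):

```python
def fit_model(features_col="features", label_col="label", max_iter=20):
    # Hypothetical helper, only here to show the two call shapes black produces.
    return (features_col, label_col, max_iter)


# Without a trailing comma, black packs the arguments onto one line if they fit:
packed = fit_model(features_col="features", label_col="label", max_iter=20)

# With a trailing ("magic") comma, black keeps the call exploded, one argument per line:
exploded = fit_model(
    features_col="features",
    label_col="label",
    max_iter=20,
)
assert packed == exploded
```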
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index f2a248c..11fbdf5 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -19,20 +19,49 @@ import sys
import warnings
from pyspark import since, keyword_only
-from pyspark.ml.param.shared import HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, \
- HasAggregationDepth, HasWeightCol, HasTol, HasProbabilityCol, HasDistanceMeasure, \
- HasCheckpointInterval, Param, Params, TypeConverters
-from pyspark.ml.util import JavaMLWritable, JavaMLReadable, GeneralJavaMLWritable, \
- HasTrainingSummary, SparkContext
+from pyspark.ml.param.shared import (
+ HasMaxIter,
+ HasFeaturesCol,
+ HasSeed,
+ HasPredictionCol,
+ HasAggregationDepth,
+ HasWeightCol,
+ HasTol,
+ HasProbabilityCol,
+ HasDistanceMeasure,
+ HasCheckpointInterval,
+ Param,
+ Params,
+ TypeConverters,
+)
+from pyspark.ml.util import (
+ JavaMLWritable,
+ JavaMLReadable,
+ GeneralJavaMLWritable,
+ HasTrainingSummary,
+ SparkContext,
+)
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper
from pyspark.ml.common import inherit_doc, _java2py
from pyspark.ml.stat import MultivariateGaussian
from pyspark.sql import DataFrame
-__all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary',
- 'KMeans', 'KMeansModel', 'KMeansSummary',
- 'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary',
- 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel', 'PowerIterationClustering']
+__all__ = [
+ "BisectingKMeans",
+ "BisectingKMeansModel",
+ "BisectingKMeansSummary",
+ "KMeans",
+ "KMeansModel",
+ "KMeansSummary",
+ "GaussianMixture",
+ "GaussianMixtureModel",
+ "GaussianMixtureSummary",
+ "LDA",
+ "LDAModel",
+ "LocalLDAModel",
+ "DistributedLDAModel",
+ "PowerIterationClustering",
+]
class ClusteringSummary(JavaWrapper):
@@ -100,16 +129,28 @@ class ClusteringSummary(JavaWrapper):
@inherit_doc
-class _GaussianMixtureParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol,
- HasProbabilityCol, HasTol, HasAggregationDepth, HasWeightCol):
+class _GaussianMixtureParams(
+ HasMaxIter,
+ HasFeaturesCol,
+ HasSeed,
+ HasPredictionCol,
+ HasProbabilityCol,
+ HasTol,
+ HasAggregationDepth,
+ HasWeightCol,
+):
"""
Params for :py:class:`GaussianMixture` and :py:class:`GaussianMixtureModel`.
.. versionadded:: 3.0.0
"""
- k = Param(Params._dummy(), "k", "Number of independent Gaussians in the mixture model. " +
- "Must be > 1.", typeConverter=TypeConverters.toInt)
+ k = Param(
+ Params._dummy(),
+ "k",
+ "Number of independent Gaussians in the mixture model. " + "Must be > 1.",
+ typeConverter=TypeConverters.toInt,
+ )
def __init__(self, *args):
super(_GaussianMixtureParams, self).__init__(*args)
@@ -123,8 +164,9 @@ class _GaussianMixtureParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionC
return self.getOrDefault(self.k)
-class GaussianMixtureModel(JavaModel, _GaussianMixtureParams, JavaMLWritable, JavaMLReadable,
- HasTrainingSummary):
+class GaussianMixtureModel(
+ JavaModel, _GaussianMixtureParams, JavaMLWritable, JavaMLReadable, HasTrainingSummary
+):
"""
Model fitted by GaussianMixture.
@@ -173,7 +215,8 @@ class GaussianMixtureModel(JavaModel, _GaussianMixtureParams, JavaMLWritable, Ja
jgaussians = self._java_obj.gaussians()
return [
MultivariateGaussian(_java2py(sc, jgaussian.mean()), _java2py(sc, jgaussian.cov()))
- for jgaussian in jgaussians]
+ for jgaussian in jgaussians
+ ]
@property
@since("2.0.0")
@@ -195,8 +238,9 @@ class GaussianMixtureModel(JavaModel, _GaussianMixtureParams, JavaMLWritable, Ja
if self.hasSummary:
return GaussianMixtureSummary(super(GaussianMixtureModel, self).summary)
else:
- raise RuntimeError("No training summary available for this %s" %
- self.__class__.__name__)
+ raise RuntimeError(
+ "No training summary available for this %s" % self.__class__.__name__
+ )
@since("3.0.0")
def predict(self, value):
@@ -336,17 +380,28 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav
"""
@keyword_only
- def __init__(self, *, featuresCol="features", predictionCol="prediction", k=2,
- probabilityCol="probability", tol=0.01, maxIter=100, seed=None,
- aggregationDepth=2, weightCol=None):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ predictionCol="prediction",
+ k=2,
+ probabilityCol="probability",
+ tol=0.01,
+ maxIter=100,
+ seed=None,
+ aggregationDepth=2,
+ weightCol=None,
+ ):
"""
__init__(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \
probabilityCol="probability", tol=0.01, maxIter=100, seed=None, \
aggregationDepth=2, weightCol=None)
"""
super(GaussianMixture, self).__init__()
- self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.GaussianMixture",
- self.uid)
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.clustering.GaussianMixture", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -355,9 +410,19 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav
@keyword_only
@since("2.0.0")
- def setParams(self, *, featuresCol="features", predictionCol="prediction", k=2,
- probabilityCol="probability", tol=0.01, maxIter=100, seed=None,
- aggregationDepth=2, weightCol=None):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ predictionCol="prediction",
+ k=2,
+ probabilityCol="probability",
+ tol=0.01,
+ maxIter=100,
+ seed=None,
+ aggregationDepth=2,
+ weightCol=None,
+ ):
"""
setParams(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \
probabilityCol="probability", tol=0.01, maxIter=100, seed=None, \
@@ -482,28 +547,46 @@ class KMeansSummary(ClusteringSummary):
@inherit_doc
-class _KMeansParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, HasTol,
- HasDistanceMeasure, HasWeightCol):
+class _KMeansParams(
+ HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, HasTol, HasDistanceMeasure, HasWeightCol
+):
"""
Params for :py:class:`KMeans` and :py:class:`KMeansModel`.
.. versionadded:: 3.0.0
"""
- k = Param(Params._dummy(), "k", "The number of clusters to create. Must be > 1.",
- typeConverter=TypeConverters.toInt)
- initMode = Param(Params._dummy(), "initMode",
- "The initialization algorithm. This can be either \"random\" to " +
- "choose random points as initial cluster centers, or \"k-means||\" " +
- "to use a parallel variant of k-means++",
- typeConverter=TypeConverters.toString)
- initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " +
- "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt)
+ k = Param(
+ Params._dummy(),
+ "k",
+ "The number of clusters to create. Must be > 1.",
+ typeConverter=TypeConverters.toInt,
+ )
+ initMode = Param(
+ Params._dummy(),
+ "initMode",
+ 'The initialization algorithm. This can be either "random" to '
+ + 'choose random points as initial cluster centers, or "k-means||" '
+ + "to use a parallel variant of k-means++",
+ typeConverter=TypeConverters.toString,
+ )
+ initSteps = Param(
+ Params._dummy(),
+ "initSteps",
+ "The number of steps for k-means|| " + "initialization mode. Must be > 0.",
+ typeConverter=TypeConverters.toInt,
+ )
def __init__(self, *args):
super(_KMeansParams, self).__init__(*args)
- self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20,
- distanceMeasure="euclidean")
+ self._setDefault(
+ k=2,
+ initMode="k-means||",
+ initSteps=2,
+ tol=1e-4,
+ maxIter=20,
+ distanceMeasure="euclidean",
+ )
@since("1.5.0")
def getK(self):
@@ -527,8 +610,9 @@ class _KMeansParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, HasTo
return self.getOrDefault(self.initSteps)
-class KMeansModel(JavaModel, _KMeansParams, GeneralJavaMLWritable, JavaMLReadable,
- HasTrainingSummary):
+class KMeansModel(
+ JavaModel, _KMeansParams, GeneralJavaMLWritable, JavaMLReadable, HasTrainingSummary
+):
"""
Model fitted by KMeans.
@@ -564,8 +648,9 @@ class KMeansModel(JavaModel, _KMeansParams, GeneralJavaMLWritable, JavaMLReadabl
if self.hasSummary:
return KMeansSummary(super(KMeansModel, self).summary)
else:
- raise RuntimeError("No training summary available for this %s" %
- self.__class__.__name__)
+ raise RuntimeError(
+ "No training summary available for this %s" % self.__class__.__name__
+ )
@since("3.0.0")
def predict(self, value):
@@ -643,9 +728,20 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
"""
@keyword_only
- def __init__(self, *, featuresCol="features", predictionCol="prediction", k=2,
- initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None,
- distanceMeasure="euclidean", weightCol=None):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ predictionCol="prediction",
+ k=2,
+ initMode="k-means||",
+ initSteps=2,
+ tol=1e-4,
+ maxIter=20,
+ seed=None,
+ distanceMeasure="euclidean",
+ weightCol=None,
+ ):
"""
__init__(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \
initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \
@@ -661,9 +757,20 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
@keyword_only
@since("1.5.0")
- def setParams(self, *, featuresCol="features", predictionCol="prediction", k=2,
- initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None,
- distanceMeasure="euclidean", weightCol=None):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ predictionCol="prediction",
+ k=2,
+ initMode="k-means||",
+ initSteps=2,
+ tol=1e-4,
+ maxIter=20,
+ seed=None,
+ distanceMeasure="euclidean",
+ weightCol=None,
+ ):
"""
setParams(self, \\*, featuresCol="features", predictionCol="prediction", k=2, \
initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \
@@ -746,20 +853,28 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
@inherit_doc
-class _BisectingKMeansParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol,
- HasDistanceMeasure, HasWeightCol):
+class _BisectingKMeansParams(
+ HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol, HasDistanceMeasure, HasWeightCol
+):
"""
Params for :py:class:`BisectingKMeans` and :py:class:`BisectingKMeansModel`.
.. versionadded:: 3.0.0
"""
- k = Param(Params._dummy(), "k", "The desired number of leaf clusters. Must be > 1.",
- typeConverter=TypeConverters.toInt)
- minDivisibleClusterSize = Param(Params._dummy(), "minDivisibleClusterSize",
- "The minimum number of points (if >= 1.0) or the minimum " +
- "proportion of points (if < 1.0) of a divisible cluster.",
- typeConverter=TypeConverters.toFloat)
+ k = Param(
+ Params._dummy(),
+ "k",
+ "The desired number of leaf clusters. Must be > 1.",
+ typeConverter=TypeConverters.toInt,
+ )
+ minDivisibleClusterSize = Param(
+ Params._dummy(),
+ "minDivisibleClusterSize",
+ "The minimum number of points (if >= 1.0) or the minimum "
+ + "proportion of points (if < 1.0) of a divisible cluster.",
+ typeConverter=TypeConverters.toFloat,
+ )
def __init__(self, *args):
super(_BisectingKMeansParams, self).__init__(*args)
@@ -780,8 +895,9 @@ class _BisectingKMeansParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionC
return self.getOrDefault(self.minDivisibleClusterSize)
-class BisectingKMeansModel(JavaModel, _BisectingKMeansParams, JavaMLWritable, JavaMLReadable,
- HasTrainingSummary):
+class BisectingKMeansModel(
+ JavaModel, _BisectingKMeansParams, JavaMLWritable, JavaMLReadable, HasTrainingSummary
+):
"""
Model fitted by BisectingKMeans.
@@ -817,9 +933,12 @@ class BisectingKMeansModel(JavaModel, _BisectingKMeansParams, JavaMLWritable, Ja
It will be removed in future versions. Use :py:class:`ClusteringEvaluator` instead.
You can also get the cost on the training dataset in the summary.
"""
- warnings.warn("Deprecated in 3.0.0. It will be removed in future versions. Use "
- "ClusteringEvaluator instead. You can also get the cost on the training "
- "dataset in the summary.", FutureWarning)
+ warnings.warn(
+ "Deprecated in 3.0.0. It will be removed in future versions. Use "
+ "ClusteringEvaluator instead. You can also get the cost on the training "
+ "dataset in the summary.",
+ FutureWarning,
+ )
return self._call_java("computeCost", dataset)
@property
@@ -832,8 +951,9 @@ class BisectingKMeansModel(JavaModel, _BisectingKMeansParams, JavaMLWritable, Ja
if self.hasSummary:
return BisectingKMeansSummary(super(BisectingKMeansModel, self).summary)
else:
- raise RuntimeError("No training summary available for this %s" %
- self.__class__.__name__)
+ raise RuntimeError(
+ "No training summary available for this %s" % self.__class__.__name__
+ )
@since("3.0.0")
def predict(self, value):
@@ -924,25 +1044,44 @@ class BisectingKMeans(JavaEstimator, _BisectingKMeansParams, JavaMLWritable, Jav
"""
@keyword_only
- def __init__(self, *, featuresCol="features", predictionCol="prediction", maxIter=20,
- seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean",
- weightCol=None):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ predictionCol="prediction",
+ maxIter=20,
+ seed=None,
+ k=4,
+ minDivisibleClusterSize=1.0,
+ distanceMeasure="euclidean",
+ weightCol=None,
+ ):
"""
__init__(self, \\*, featuresCol="features", predictionCol="prediction", maxIter=20, \
seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean", \
weightCol=None)
"""
super(BisectingKMeans, self).__init__()
- self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.BisectingKMeans",
- self.uid)
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.clustering.BisectingKMeans", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("2.0.0")
- def setParams(self, *, featuresCol="features", predictionCol="prediction", maxIter=20,
- seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean",
- weightCol=None):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ predictionCol="prediction",
+ maxIter=20,
+ seed=None,
+ k=4,
+ minDivisibleClusterSize=1.0,
+ distanceMeasure="euclidean",
+ weightCol=None,
+ ):
"""
setParams(self, \\*, featuresCol="features", predictionCol="prediction", maxIter=20, \
seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean", \
@@ -1037,52 +1176,95 @@ class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval):
.. versionadded:: 3.0.0
"""
- k = Param(Params._dummy(), "k", "The number of topics (clusters) to infer. Must be > 1.",
- typeConverter=TypeConverters.toInt)
- optimizer = Param(Params._dummy(), "optimizer",
- "Optimizer or inference algorithm used to estimate the LDA model. "
- "Supported: online, em", typeConverter=TypeConverters.toString)
- learningOffset = Param(Params._dummy(), "learningOffset",
- "A (positive) learning parameter that downweights early iterations."
- " Larger values make early iterations count less",
- typeConverter=TypeConverters.toFloat)
- learningDecay = Param(Params._dummy(), "learningDecay", "Learning rate, set as an"
- "exponential decay rate. This should be between (0.5, 1.0] to "
- "guarantee asymptotic convergence.", typeConverter=TypeConverters.toFloat)
- subsamplingRate = Param(Params._dummy(), "subsamplingRate",
- "Fraction of the corpus to be sampled and used in each iteration "
- "of mini-batch gradient descent, in range (0, 1].",
- typeConverter=TypeConverters.toFloat)
- optimizeDocConcentration = Param(Params._dummy(), "optimizeDocConcentration",
- "Indicates whether the docConcentration (Dirichlet parameter "
- "for document-topic distribution) will be optimized during "
- "training.", typeConverter=TypeConverters.toBoolean)
- docConcentration = Param(Params._dummy(), "docConcentration",
- "Concentration parameter (commonly named \"alpha\") for the "
- "prior placed on documents' distributions over topics (\"theta\").",
- typeConverter=TypeConverters.toListFloat)
- topicConcentration = Param(Params._dummy(), "topicConcentration",
- "Concentration parameter (commonly named \"beta\" or \"eta\") for "
- "the prior placed on topic' distributions over terms.",
- typeConverter=TypeConverters.toFloat)
- topicDistributionCol = Param(Params._dummy(), "topicDistributionCol",
- "Output column with estimates of the topic mixture distribution "
- "for each document (often called \"theta\" in the literature). "
- "Returns a vector of zeros for an empty document.",
- typeConverter=TypeConverters.toString)
- keepLastCheckpoint = Param(Params._dummy(), "keepLastCheckpoint",
- "(For EM optimizer) If using checkpointing, this indicates whether"
- " to keep the last checkpoint. If false, then the checkpoint will be"
- " deleted. Deleting the checkpoint can cause failures if a data"
- " partition is lost, so set this bit with care.",
- TypeConverters.toBoolean)
+ k = Param(
+ Params._dummy(),
+ "k",
+ "The number of topics (clusters) to infer. Must be > 1.",
+ typeConverter=TypeConverters.toInt,
+ )
+ optimizer = Param(
+ Params._dummy(),
+ "optimizer",
+ "Optimizer or inference algorithm used to estimate the LDA model. "
+ "Supported: online, em",
+ typeConverter=TypeConverters.toString,
+ )
+ learningOffset = Param(
+ Params._dummy(),
+ "learningOffset",
+ "A (positive) learning parameter that downweights early iterations."
+ " Larger values make early iterations count less",
+ typeConverter=TypeConverters.toFloat,
+ )
+ learningDecay = Param(
+ Params._dummy(),
+ "learningDecay",
+ "Learning rate, set as an"
+ "exponential decay rate. This should be between (0.5, 1.0] to "
+ "guarantee asymptotic convergence.",
+ typeConverter=TypeConverters.toFloat,
+ )
+ subsamplingRate = Param(
+ Params._dummy(),
+ "subsamplingRate",
+ "Fraction of the corpus to be sampled and used in each iteration "
+ "of mini-batch gradient descent, in range (0, 1].",
+ typeConverter=TypeConverters.toFloat,
+ )
+ optimizeDocConcentration = Param(
+ Params._dummy(),
+ "optimizeDocConcentration",
+ "Indicates whether the docConcentration (Dirichlet parameter "
+ "for document-topic distribution) will be optimized during "
+ "training.",
+ typeConverter=TypeConverters.toBoolean,
+ )
+ docConcentration = Param(
+ Params._dummy(),
+ "docConcentration",
+ 'Concentration parameter (commonly named "alpha") for the '
+ 'prior placed on documents\' distributions over topics ("theta").',
+ typeConverter=TypeConverters.toListFloat,
+ )
+ topicConcentration = Param(
+ Params._dummy(),
+ "topicConcentration",
+ 'Concentration parameter (commonly named "beta" or "eta") for '
+ "the prior placed on topic' distributions over terms.",
+ typeConverter=TypeConverters.toFloat,
+ )
+ topicDistributionCol = Param(
+ Params._dummy(),
+ "topicDistributionCol",
+ "Output column with estimates of the topic mixture distribution "
+ 'for each document (often called "theta" in the literature). '
+ "Returns a vector of zeros for an empty document.",
+ typeConverter=TypeConverters.toString,
+ )
+ keepLastCheckpoint = Param(
+ Params._dummy(),
+ "keepLastCheckpoint",
+ "(For EM optimizer) If using checkpointing, this indicates whether"
+ " to keep the last checkpoint. If false, then the checkpoint will be"
+ " deleted. Deleting the checkpoint can cause failures if a data"
+ " partition is lost, so set this bit with care.",
+ TypeConverters.toBoolean,
+ )
def __init__(self, *args):
super(_LDAParams, self).__init__(*args)
- self._setDefault(maxIter=20, checkpointInterval=10,
- k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,
- subsamplingRate=0.05, optimizeDocConcentration=True,
- topicDistributionCol="topicDistribution", keepLastCheckpoint=True)
+ self._setDefault(
+ maxIter=20,
+ checkpointInterval=10,
+ k=10,
+ optimizer="online",
+ learningOffset=1024.0,
+ learningDecay=0.51,
+ subsamplingRate=0.05,
+ optimizeDocConcentration=True,
+ topicDistributionCol="topicDistribution",
+ keepLastCheckpoint=True,
+ )
@since("2.0.0")
def getK(self):
@@ -1336,6 +1518,7 @@ class LocalLDAModel(LDAModel, JavaMLReadable, JavaMLWritable):
.. versionadded:: 2.0.0
"""
+
pass
@@ -1411,11 +1594,24 @@ class LDA(JavaEstimator, _LDAParams, JavaMLReadable, JavaMLWritable):
"""
@keyword_only
- def __init__(self, *, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,
- k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,
- subsamplingRate=0.05, optimizeDocConcentration=True,
- docConcentration=None, topicConcentration=None,
- topicDistributionCol="topicDistribution", keepLastCheckpoint=True):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ maxIter=20,
+ seed=None,
+ checkpointInterval=10,
+ k=10,
+ optimizer="online",
+ learningOffset=1024.0,
+ learningDecay=0.51,
+ subsamplingRate=0.05,
+ optimizeDocConcentration=True,
+ docConcentration=None,
+ topicConcentration=None,
+ topicDistributionCol="topicDistribution",
+ keepLastCheckpoint=True,
+ ):
"""
__init__(self, \\*, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,\
k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\
@@ -1436,11 +1632,24 @@ class LDA(JavaEstimator, _LDAParams, JavaMLReadable, JavaMLWritable):
@keyword_only
@since("2.0.0")
- def setParams(self, *, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,
- k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,
- subsamplingRate=0.05, optimizeDocConcentration=True,
- docConcentration=None, topicConcentration=None,
- topicDistributionCol="topicDistribution", keepLastCheckpoint=True):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ maxIter=20,
+ seed=None,
+ checkpointInterval=10,
+ k=10,
+ optimizer="online",
+ learningOffset=1024.0,
+ learningDecay=0.51,
+ subsamplingRate=0.05,
+ optimizeDocConcentration=True,
+ docConcentration=None,
+ topicConcentration=None,
+ topicDistributionCol="topicDistribution",
+ keepLastCheckpoint=True,
+ ):
"""
setParams(self, \\*, featuresCol="features", maxIter=20, seed=None, checkpointInterval=10,\
k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\
@@ -1619,21 +1828,33 @@ class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol):
.. versionadded:: 3.0.0
"""
- k = Param(Params._dummy(), "k",
- "The number of clusters to create. Must be > 1.",
- typeConverter=TypeConverters.toInt)
- initMode = Param(Params._dummy(), "initMode",
- "The initialization algorithm. This can be either " +
- "'random' to use a random vector as vertex properties, or 'degree' to use " +
- "a normalized sum of similarities with other vertices. Supported options: " +
- "'random' and 'degree'.",
- typeConverter=TypeConverters.toString)
- srcCol = Param(Params._dummy(), "srcCol",
- "Name of the input column for source vertex IDs.",
- typeConverter=TypeConverters.toString)
- dstCol = Param(Params._dummy(), "dstCol",
- "Name of the input column for destination vertex IDs.",
- typeConverter=TypeConverters.toString)
+ k = Param(
+ Params._dummy(),
+ "k",
+ "The number of clusters to create. Must be > 1.",
+ typeConverter=TypeConverters.toInt,
+ )
+ initMode = Param(
+ Params._dummy(),
+ "initMode",
+ "The initialization algorithm. This can be either "
+ + "'random' to use a random vector as vertex properties, or 'degree' to use "
+ + "a normalized sum of similarities with other vertices. Supported options: "
+ + "'random' and 'degree'.",
+ typeConverter=TypeConverters.toString,
+ )
+ srcCol = Param(
+ Params._dummy(),
+ "srcCol",
+ "Name of the input column for source vertex IDs.",
+ typeConverter=TypeConverters.toString,
+ )
+ dstCol = Param(
+ Params._dummy(),
+ "dstCol",
+ "Name of the input column for destination vertex IDs.",
+ typeConverter=TypeConverters.toString,
+ )
def __init__(self, *args):
super(_PowerIterationClusteringParams, self).__init__(*args)
@@ -1669,8 +1890,9 @@ class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol):
@inherit_doc
-class PowerIterationClustering(_PowerIterationClusteringParams, JavaParams, JavaMLReadable,
- JavaMLWritable):
+class PowerIterationClustering(
+ _PowerIterationClusteringParams, JavaParams, JavaMLReadable, JavaMLWritable
+):
"""
Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by
`Lin and Cohen <http://www.cs.cmu.edu/~frank/papers/icml2010-pic-final.pdf>`_. From the
@@ -1722,22 +1944,25 @@ class PowerIterationClustering(_PowerIterationClusteringParams, JavaParams, Java
"""
@keyword_only
- def __init__(self, *, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",
- weightCol=None):
+ def __init__(
+ self, *, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", weightCol=None
+ ):
"""
__init__(self, \\*, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\
weightCol=None)
"""
super(PowerIterationClustering, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.clustering.PowerIterationClustering", self.uid)
+ "org.apache.spark.ml.clustering.PowerIterationClustering", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("2.4.0")
- def setParams(self, *, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",
- weightCol=None):
+ def setParams(
+ self, *, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", weightCol=None
+ ):
"""
setParams(self, \\*, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\
weightCol=None)
@@ -1822,29 +2047,29 @@ if __name__ == "__main__":
import numpy
import pyspark.ml.clustering
from pyspark.sql import SparkSession
+
try:
# Numpy 1.14+ changed it's string format.
- numpy.set_printoptions(legacy='1.13')
+ numpy.set_printoptions(legacy="1.13")
except TypeError:
pass
globs = pyspark.ml.clustering.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
- spark = SparkSession.builder\
- .master("local[2]")\
- .appName("ml.clustering tests")\
- .getOrCreate()
+ spark = SparkSession.builder.master("local[2]").appName("ml.clustering tests").getOrCreate()
sc = spark.sparkContext
- globs['sc'] = sc
- globs['spark'] = spark
+ globs["sc"] = sc
+ globs["spark"] = spark
import tempfile
+
temp_path = tempfile.mkdtemp()
- globs['temp_path'] = temp_path
+ globs["temp_path"] = temp_path
try:
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
finally:
from shutil import rmtree
+
try:
rmtree(temp_path)
except OSError:
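As with classification.py, the clustering.py hunks above are layout-only. A minimal KMeans sketch against the reformatted estimator, with toy data and the default "features" column (values chosen only for illustration):

```python
from pyspark.ml.clustering import KMeans
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("kmeans sketch").getOrCreate()

# Two well-separated clusters; "features" is the default featuresCol.
df = spark.createDataFrame(
    [
        (Vectors.dense(0.0, 0.0),),
        (Vectors.dense(0.5, 0.5),),
        (Vectors.dense(9.0, 9.0),),
        (Vectors.dense(9.5, 8.5),),
    ],
    ["features"],
)

model = KMeans(k=2, seed=1).fit(df)
print(model.clusterCenters())
print(model.summary.trainingCost)

spark.stop()
```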
diff --git a/python/pyspark/ml/clustering.pyi b/python/pyspark/ml/clustering.pyi
index e899b60..81074fc 100644
--- a/python/pyspark/ml/clustering.pyi
+++ b/python/pyspark/ml/clustering.pyi
@@ -113,7 +113,7 @@ class GaussianMixture(
maxIter: int = ...,
seed: Optional[int] = ...,
aggregationDepth: int = ...,
- weightCol: Optional[str] = ...
+ weightCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
@@ -126,7 +126,7 @@ class GaussianMixture(
maxIter: int = ...,
seed: Optional[int] = ...,
aggregationDepth: int = ...,
- weightCol: Optional[str] = ...
+ weightCol: Optional[str] = ...,
) -> GaussianMixture: ...
def setK(self, value: int) -> GaussianMixture: ...
def setMaxIter(self, value: int) -> GaussianMixture: ...
@@ -180,9 +180,7 @@ class KMeansModel(
def summary(self) -> KMeansSummary: ...
def predict(self, value: Vector) -> int: ...
-class KMeans(
- JavaEstimator[KMeansModel], _KMeansParams, JavaMLWritable, JavaMLReadable[KMeans]
-):
+class KMeans(JavaEstimator[KMeansModel], _KMeansParams, JavaMLWritable, JavaMLReadable[KMeans]):
def __init__(
self,
*,
@@ -195,7 +193,7 @@ class KMeans(
maxIter: int = ...,
seed: Optional[int] = ...,
distanceMeasure: str = ...,
- weightCol: Optional[str] = ...
+ weightCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
@@ -209,7 +207,7 @@ class KMeans(
maxIter: int = ...,
seed: Optional[int] = ...,
distanceMeasure: str = ...,
- weightCol: Optional[str] = ...
+ weightCol: Optional[str] = ...,
) -> KMeans: ...
def setK(self, value: int) -> KMeans: ...
def setInitMode(self, value: str) -> KMeans: ...
@@ -267,7 +265,7 @@ class BisectingKMeans(
k: int = ...,
minDivisibleClusterSize: float = ...,
distanceMeasure: str = ...,
- weightCol: Optional[str] = ...
+ weightCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
@@ -279,7 +277,7 @@ class BisectingKMeans(
k: int = ...,
minDivisibleClusterSize: float = ...,
distanceMeasure: str = ...,
- weightCol: Optional[str] = ...
+ weightCol: Optional[str] = ...,
) -> BisectingKMeans: ...
def setK(self, value: int) -> BisectingKMeans: ...
def setMinDivisibleClusterSize(self, value: float) -> BisectingKMeans: ...
@@ -329,9 +327,7 @@ class LDAModel(JavaModel, _LDAParams):
def describeTopics(self, maxTermsPerTopic: int = ...) -> DataFrame: ...
def estimatedDocConcentration(self) -> Vector: ...
-class DistributedLDAModel(
- LDAModel, JavaMLReadable[DistributedLDAModel], JavaMLWritable
-):
+class DistributedLDAModel(LDAModel, JavaMLReadable[DistributedLDAModel], JavaMLWritable):
def toLocal(self) -> LDAModel: ...
def trainingLogLikelihood(self) -> float: ...
def logPrior(self) -> float: ...
@@ -356,7 +352,7 @@ class LDA(JavaEstimator[LDAModel], _LDAParams, JavaMLReadable[LDA], JavaMLWritab
docConcentration: Optional[List[float]] = ...,
topicConcentration: Optional[float] = ...,
topicDistributionCol: str = ...,
- keepLastCheckpoint: bool = ...
+ keepLastCheckpoint: bool = ...,
) -> None: ...
def setParams(
self,
@@ -374,7 +370,7 @@ class LDA(JavaEstimator[LDAModel], _LDAParams, JavaMLReadable[LDA], JavaMLWritab
docConcentration: Optional[List[float]] = ...,
topicConcentration: Optional[float] = ...,
topicDistributionCol: str = ...,
- keepLastCheckpoint: bool = ...
+ keepLastCheckpoint: bool = ...,
) -> LDA: ...
def setCheckpointInterval(self, value: int) -> LDA: ...
def setSeed(self, value: int) -> LDA: ...
@@ -416,7 +412,7 @@ class PowerIterationClustering(
initMode: str = ...,
srcCol: str = ...,
dstCol: str = ...,
- weightCol: Optional[str] = ...
+ weightCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
@@ -426,7 +422,7 @@ class PowerIterationClustering(
initMode: str = ...,
srcCol: str = ...,
dstCol: str = ...,
- weightCol: Optional[str] = ...
+ weightCol: Optional[str] = ...,
) -> PowerIterationClustering: ...
def setK(self, value: int) -> PowerIterationClustering: ...
def setInitMode(self, value: str) -> PowerIterationClustering: ...
diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py
index 165dc29..8107f91 100644
--- a/python/pyspark/ml/common.py
+++ b/python/pyspark/ml/common.py
@@ -28,9 +28,9 @@ from pyspark.sql import DataFrame, SQLContext
_old_smart_decode = py4j.protocol.smart_decode
_float_str_mapping = {
- 'nan': 'NaN',
- 'inf': 'Infinity',
- '-inf': '-Infinity',
+ "nan": "NaN",
+ "inf": "Infinity",
+ "-inf": "-Infinity",
}
@@ -40,20 +40,21 @@ def _new_smart_decode(obj):
return _float_str_mapping.get(s, s)
return _old_smart_decode(obj)
+
py4j.protocol.smart_decode = _new_smart_decode
_picklable_classes = [
- 'SparseVector',
- 'DenseVector',
- 'SparseMatrix',
- 'DenseMatrix',
+ "SparseVector",
+ "DenseVector",
+ "SparseMatrix",
+ "DenseMatrix",
]
# this will call the ML version of pythonToJava()
def _to_java_object_rdd(rdd):
- """ Return an JavaRDD of Object by unpickling
+ """Return an JavaRDD of Object by unpickling
It will convert each Python object into Java object by Pickle, whenever the
RDD is serialized in batch or not.
@@ -63,7 +64,7 @@ def _to_java_object_rdd(rdd):
def _py2java(sc, obj):
- """ Convert Python object into Java """
+ """Convert Python object into Java"""
if isinstance(obj, RDD):
obj = _to_java_object_rdd(obj)
elif isinstance(obj, DataFrame):
@@ -86,15 +87,15 @@ def _java2py(sc, r, encoding="bytes"):
if isinstance(r, JavaObject):
clsName = r.getClass().getSimpleName()
# convert RDD into JavaRDD
- if clsName != 'JavaRDD' and clsName.endswith("RDD"):
+ if clsName != "JavaRDD" and clsName.endswith("RDD"):
r = r.toJavaRDD()
- clsName = 'JavaRDD'
+ clsName = "JavaRDD"
- if clsName == 'JavaRDD':
+ if clsName == "JavaRDD":
jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r)
return RDD(jrdd, sc)
- if clsName == 'Dataset':
+ if clsName == "Dataset":
return DataFrame(r, SQLContext.getOrCreate(sc))
if clsName in _picklable_classes:
@@ -111,7 +112,7 @@ def _java2py(sc, r, encoding="bytes"):
def callJavaFunc(sc, func, *args):
- """ Call Java Function """
+ """Call Java Function"""
args = [_py2java(sc, a) for a in args]
return _java2py(sc, func(*args))
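The evaluation.py hunks below reformat the evaluator classes, starting with BinaryClassificationEvaluator. For context, a minimal usage sketch with toy scores and the default "rawPrediction"/"label" columns (data values are illustrative only):

```python
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("evaluator sketch").getOrCreate()

# rawPrediction holds per-class scores; column names match the evaluator defaults.
scoreAndLabels = spark.createDataFrame(
    [
        (Vectors.dense(0.1, 0.9), 1.0),
        (Vectors.dense(0.8, 0.2), 0.0),
        (Vectors.dense(0.3, 0.7), 0.0),
    ],
    ["rawPrediction", "label"],
)

evaluator = BinaryClassificationEvaluator(metricName="areaUnderROC")
print(evaluator.evaluate(scoreAndLabels))

spark.stop()
```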
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index e8cada9..be63a8f 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -21,14 +21,26 @@ from abc import abstractmethod, ABCMeta
from pyspark import since, keyword_only
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.param import Param, Params, TypeConverters
-from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasProbabilityCol, \
- HasRawPredictionCol, HasFeaturesCol, HasWeightCol
+from pyspark.ml.param.shared import (
+ HasLabelCol,
+ HasPredictionCol,
+ HasProbabilityCol,
+ HasRawPredictionCol,
+ HasFeaturesCol,
+ HasWeightCol,
+)
from pyspark.ml.common import inherit_doc
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
-__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
- 'MulticlassClassificationEvaluator', 'MultilabelClassificationEvaluator',
- 'ClusteringEvaluator', 'RankingEvaluator']
+__all__ = [
+ "Evaluator",
+ "BinaryClassificationEvaluator",
+ "RegressionEvaluator",
+ "MulticlassClassificationEvaluator",
+ "MultilabelClassificationEvaluator",
+ "ClusteringEvaluator",
+ "RankingEvaluator",
+]
@inherit_doc
@@ -38,6 +50,7 @@ class Evaluator(Params, metaclass=ABCMeta):
.. versionadded:: 1.4.0
"""
+
pass
@abstractmethod
@@ -125,8 +138,9 @@ class JavaEvaluator(JavaParams, Evaluator, metaclass=ABCMeta):
@inherit_doc
-class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol, HasWeightCol,
- JavaMLReadable, JavaMLWritable):
+class BinaryClassificationEvaluator(
+ JavaEvaluator, HasLabelCol, HasRawPredictionCol, HasWeightCol, JavaMLReadable, JavaMLWritable
+):
"""
Evaluator for binary classification, which expects input columns rawPrediction, label
and an optional weight column.
@@ -168,25 +182,40 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
1000
"""
- metricName = Param(Params._dummy(), "metricName",
- "metric name in evaluation (areaUnderROC|areaUnderPR)",
- typeConverter=TypeConverters.toString)
-
- numBins = Param(Params._dummy(), "numBins", "Number of bins to down-sample the curves "
- "(ROC curve, PR curve) in area computation. If 0, no down-sampling will "
- "occur. Must be >= 0.",
- typeConverter=TypeConverters.toInt)
+ metricName = Param(
+ Params._dummy(),
+ "metricName",
+ "metric name in evaluation (areaUnderROC|areaUnderPR)",
+ typeConverter=TypeConverters.toString,
+ )
+
+ numBins = Param(
+ Params._dummy(),
+ "numBins",
+ "Number of bins to down-sample the curves "
+ "(ROC curve, PR curve) in area computation. If 0, no down-sampling will "
+ "occur. Must be >= 0.",
+ typeConverter=TypeConverters.toInt,
+ )
@keyword_only
- def __init__(self, *, rawPredictionCol="rawPrediction", labelCol="label",
- metricName="areaUnderROC", weightCol=None, numBins=1000):
+ def __init__(
+ self,
+ *,
+ rawPredictionCol="rawPrediction",
+ labelCol="label",
+ metricName="areaUnderROC",
+ weightCol=None,
+ numBins=1000,
+ ):
"""
__init__(self, \\*, rawPredictionCol="rawPrediction", labelCol="label", \
metricName="areaUnderROC", weightCol=None, numBins=1000)
"""
super(BinaryClassificationEvaluator, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid)
+ "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid
+ )
self._setDefault(metricName="areaUnderROC", numBins=1000)
kwargs = self._input_kwargs
self._set(**kwargs)
@@ -240,8 +269,15 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
@keyword_only
@since("1.4.0")
- def setParams(self, *, rawPredictionCol="rawPrediction", labelCol="label",
- metricName="areaUnderROC", weightCol=None, numBins=1000):
+ def setParams(
+ self,
+ *,
+ rawPredictionCol="rawPrediction",
+ labelCol="label",
+ metricName="areaUnderROC",
+ weightCol=None,
+ numBins=1000,
+ ):
"""
setParams(self, \\*, rawPredictionCol="rawPrediction", labelCol="label", \
metricName="areaUnderROC", weightCol=None, numBins=1000)
@@ -252,8 +288,9 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
@inherit_doc
-class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeightCol,
- JavaMLReadable, JavaMLWritable):
+class RegressionEvaluator(
+ JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeightCol, JavaMLReadable, JavaMLWritable
+):
"""
Evaluator for Regression, which expects input columns prediction, label
and an optional weight column.
@@ -290,29 +327,44 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeigh
>>> evaluator.getThroughOrigin()
False
"""
- metricName = Param(Params._dummy(), "metricName",
- """metric name in evaluation - one of:
+
+ metricName = Param(
+ Params._dummy(),
+ "metricName",
+ """metric name in evaluation - one of:
rmse - root mean squared error (default)
mse - mean squared error
r2 - r^2 metric
mae - mean absolute error
var - explained variance.""",
- typeConverter=TypeConverters.toString)
+ typeConverter=TypeConverters.toString,
+ )
- throughOrigin = Param(Params._dummy(), "throughOrigin",
- "whether the regression is through the origin.",
- typeConverter=TypeConverters.toBoolean)
+ throughOrigin = Param(
+ Params._dummy(),
+ "throughOrigin",
+ "whether the regression is through the origin.",
+ typeConverter=TypeConverters.toBoolean,
+ )
@keyword_only
- def __init__(self, *, predictionCol="prediction", labelCol="label",
- metricName="rmse", weightCol=None, throughOrigin=False):
+ def __init__(
+ self,
+ *,
+ predictionCol="prediction",
+ labelCol="label",
+ metricName="rmse",
+ weightCol=None,
+ throughOrigin=False,
+ ):
"""
__init__(self, \\*, predictionCol="prediction", labelCol="label", \
metricName="rmse", weightCol=None, throughOrigin=False)
"""
super(RegressionEvaluator, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid)
+ "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid
+ )
self._setDefault(metricName="rmse", throughOrigin=False)
kwargs = self._input_kwargs
self._set(**kwargs)
@@ -366,8 +418,15 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeigh
@keyword_only
@since("1.4.0")
- def setParams(self, *, predictionCol="prediction", labelCol="label",
- metricName="rmse", weightCol=None, throughOrigin=False):
+ def setParams(
+ self,
+ *,
+ predictionCol="prediction",
+ labelCol="label",
+ metricName="rmse",
+ weightCol=None,
+ throughOrigin=False,
+ ):
"""
setParams(self, \\*, predictionCol="prediction", labelCol="label", \
metricName="rmse", weightCol=None, throughOrigin=False)
@@ -378,8 +437,15 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeigh
@inherit_doc
-class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeightCol,
- HasProbabilityCol, JavaMLReadable, JavaMLWritable):
+class MulticlassClassificationEvaluator(
+ JavaEvaluator,
+ HasLabelCol,
+ HasPredictionCol,
+ HasWeightCol,
+ HasProbabilityCol,
+ JavaMLReadable,
+ JavaMLWritable,
+):
"""
Evaluator for Multiclass Classification, which expects input
columns: prediction, label, weight (optional) and probabilityCol (only for logLoss).
@@ -432,32 +498,54 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
>>> evaluator.evaluate(dataset)
0.9682...
"""
- metricName = Param(Params._dummy(), "metricName",
- "metric name in evaluation "
- "(f1|accuracy|weightedPrecision|weightedRecall|weightedTruePositiveRate| "
- "weightedFalsePositiveRate|weightedFMeasure|truePositiveRateByLabel| "
- "falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel| "
- "logLoss|hammingLoss)",
- typeConverter=TypeConverters.toString)
- metricLabel = Param(Params._dummy(), "metricLabel",
- "The class whose metric will be computed in truePositiveRateByLabel|"
- "falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel."
- " Must be >= 0. The default value is 0.",
- typeConverter=TypeConverters.toFloat)
- beta = Param(Params._dummy(), "beta",
- "The beta value used in weightedFMeasure|fMeasureByLabel."
- " Must be > 0. The default value is 1.",
- typeConverter=TypeConverters.toFloat)
- eps = Param(Params._dummy(), "eps",
- "log-loss is undefined for p=0 or p=1, so probabilities are clipped to "
- "max(eps, min(1 - eps, p)). "
- "Must be in range (0, 0.5). The default value is 1e-15.",
- typeConverter=TypeConverters.toFloat)
+
+ metricName = Param(
+ Params._dummy(),
+ "metricName",
+ "metric name in evaluation "
+ "(f1|accuracy|weightedPrecision|weightedRecall|weightedTruePositiveRate| "
+ "weightedFalsePositiveRate|weightedFMeasure|truePositiveRateByLabel| "
+ "falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel| "
+ "logLoss|hammingLoss)",
+ typeConverter=TypeConverters.toString,
+ )
+ metricLabel = Param(
+ Params._dummy(),
+ "metricLabel",
+ "The class whose metric will be computed in truePositiveRateByLabel|"
+ "falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel."
+ " Must be >= 0. The default value is 0.",
+ typeConverter=TypeConverters.toFloat,
+ )
+ beta = Param(
+ Params._dummy(),
+ "beta",
+ "The beta value used in weightedFMeasure|fMeasureByLabel."
+ " Must be > 0. The default value is 1.",
+ typeConverter=TypeConverters.toFloat,
+ )
+ eps = Param(
+ Params._dummy(),
+ "eps",
+ "log-loss is undefined for p=0 or p=1, so probabilities are clipped to "
+ "max(eps, min(1 - eps, p)). "
+ "Must be in range (0, 0.5). The default value is 1e-15.",
+ typeConverter=TypeConverters.toFloat,
+ )
@keyword_only
- def __init__(self, *, predictionCol="prediction", labelCol="label",
- metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0,
- probabilityCol="probability", eps=1e-15):
+ def __init__(
+ self,
+ *,
+ predictionCol="prediction",
+ labelCol="label",
+ metricName="f1",
+ weightCol=None,
+ metricLabel=0.0,
+ beta=1.0,
+ probabilityCol="probability",
+ eps=1e-15,
+ ):
"""
__init__(self, \\*, predictionCol="prediction", labelCol="label", \
metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0, \
@@ -465,7 +553,8 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
"""
super(MulticlassClassificationEvaluator, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid)
+ "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid
+ )
self._setDefault(metricName="f1", metricLabel=0.0, beta=1.0, eps=1e-15)
kwargs = self._input_kwargs
self._set(**kwargs)
@@ -554,9 +643,18 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
@keyword_only
@since("1.5.0")
- def setParams(self, *, predictionCol="prediction", labelCol="label",
- metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0,
- probabilityCol="probability", eps=1e-15):
+ def setParams(
+ self,
+ *,
+ predictionCol="prediction",
+ labelCol="label",
+ metricName="f1",
+ weightCol=None,
+ metricLabel=0.0,
+ beta=1.0,
+ probabilityCol="probability",
+ eps=1e-15,
+ ):
"""
setParams(self, \\*, predictionCol="prediction", labelCol="label", \
metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0, \
@@ -568,8 +666,9 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
@inherit_doc
-class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
- JavaMLReadable, JavaMLWritable):
+class MultilabelClassificationEvaluator(
+ JavaEvaluator, HasLabelCol, HasPredictionCol, JavaMLReadable, JavaMLWritable
+):
"""
Evaluator for Multilabel Classification, which expects two input
columns: prediction and label.
@@ -600,28 +699,42 @@ class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
>>> str(evaluator2.getPredictionCol())
'prediction'
"""
- metricName = Param(Params._dummy(), "metricName",
- "metric name in evaluation "
- "(subsetAccuracy|accuracy|hammingLoss|precision|recall|f1Measure|"
- "precisionByLabel|recallByLabel|f1MeasureByLabel|microPrecision|"
- "microRecall|microF1Measure)",
- typeConverter=TypeConverters.toString)
- metricLabel = Param(Params._dummy(), "metricLabel",
- "The class whose metric will be computed in precisionByLabel|"
- "recallByLabel|f1MeasureByLabel. "
- "Must be >= 0. The default value is 0.",
- typeConverter=TypeConverters.toFloat)
+
+ metricName = Param(
+ Params._dummy(),
+ "metricName",
+ "metric name in evaluation "
+ "(subsetAccuracy|accuracy|hammingLoss|precision|recall|f1Measure|"
+ "precisionByLabel|recallByLabel|f1MeasureByLabel|microPrecision|"
+ "microRecall|microF1Measure)",
+ typeConverter=TypeConverters.toString,
+ )
+ metricLabel = Param(
+ Params._dummy(),
+ "metricLabel",
+ "The class whose metric will be computed in precisionByLabel|"
+ "recallByLabel|f1MeasureByLabel. "
+ "Must be >= 0. The default value is 0.",
+ typeConverter=TypeConverters.toFloat,
+ )
@keyword_only
- def __init__(self, *, predictionCol="prediction", labelCol="label",
- metricName="f1Measure", metricLabel=0.0):
+ def __init__(
+ self,
+ *,
+ predictionCol="prediction",
+ labelCol="label",
+ metricName="f1Measure",
+ metricLabel=0.0,
+ ):
"""
__init__(self, \\*, predictionCol="prediction", labelCol="label", \
metricName="f1Measure", metricLabel=0.0)
"""
super(MultilabelClassificationEvaluator, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator", self.uid)
+ "org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator", self.uid
+ )
self._setDefault(metricName="f1Measure", metricLabel=0.0)
kwargs = self._input_kwargs
self._set(**kwargs)
@@ -670,8 +783,14 @@ class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
@keyword_only
@since("3.0.0")
- def setParams(self, *, predictionCol="prediction", labelCol="label",
- metricName="f1Measure", metricLabel=0.0):
+ def setParams(
+ self,
+ *,
+ predictionCol="prediction",
+ labelCol="label",
+ metricName="f1Measure",
+ metricLabel=0.0,
+ ):
"""
setParams(self, \\*, predictionCol="prediction", labelCol="label", \
metricName="f1Measure", metricLabel=0.0)
@@ -682,8 +801,9 @@ class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
@inherit_doc
-class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, HasWeightCol,
- JavaMLReadable, JavaMLWritable):
+class ClusteringEvaluator(
+ JavaEvaluator, HasPredictionCol, HasFeaturesCol, HasWeightCol, JavaMLReadable, JavaMLWritable
+):
"""
Evaluator for Clustering results, which expects two input
columns: prediction and features. The metric computes the Silhouette
@@ -727,31 +847,53 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, HasWe
>>> str(evaluator2.getPredictionCol())
'prediction'
"""
- metricName = Param(Params._dummy(), "metricName",
- "metric name in evaluation (silhouette)",
- typeConverter=TypeConverters.toString)
- distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " +
- "Supported options: 'squaredEuclidean' and 'cosine'.",
- typeConverter=TypeConverters.toString)
+
+ metricName = Param(
+ Params._dummy(),
+ "metricName",
+ "metric name in evaluation (silhouette)",
+ typeConverter=TypeConverters.toString,
+ )
+ distanceMeasure = Param(
+ Params._dummy(),
+ "distanceMeasure",
+ "The distance measure. " + "Supported options: 'squaredEuclidean' and 'cosine'.",
+ typeConverter=TypeConverters.toString,
+ )
@keyword_only
- def __init__(self, *, predictionCol="prediction", featuresCol="features",
- metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None):
+ def __init__(
+ self,
+ *,
+ predictionCol="prediction",
+ featuresCol="features",
+ metricName="silhouette",
+ distanceMeasure="squaredEuclidean",
+ weightCol=None,
+ ):
"""
__init__(self, \\*, predictionCol="prediction", featuresCol="features", \
metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None)
"""
super(ClusteringEvaluator, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid)
+ "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid
+ )
self._setDefault(metricName="silhouette", distanceMeasure="squaredEuclidean")
kwargs = self._input_kwargs
self._set(**kwargs)
@keyword_only
@since("2.3.0")
- def setParams(self, *, predictionCol="prediction", featuresCol="features",
- metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None):
+ def setParams(
+ self,
+ *,
+ predictionCol="prediction",
+ featuresCol="features",
+ metricName="silhouette",
+ distanceMeasure="squaredEuclidean",
+ weightCol=None,
+ ):
"""
setParams(self, \\*, predictionCol="prediction", featuresCol="features", \
metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None)
@@ -809,8 +951,9 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, HasWe
@inherit_doc
-class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
- JavaMLReadable, JavaMLWritable):
+class RankingEvaluator(
+ JavaEvaluator, HasLabelCol, HasPredictionCol, JavaMLReadable, JavaMLWritable
+):
"""
Evaluator for Ranking, which expects two input
columns: prediction and label.
@@ -842,26 +985,40 @@ class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
>>> str(evaluator2.getPredictionCol())
'prediction'
"""
- metricName = Param(Params._dummy(), "metricName",
- "metric name in evaluation "
- "(meanAveragePrecision|meanAveragePrecisionAtK|"
- "precisionAtK|ndcgAtK|recallAtK)",
- typeConverter=TypeConverters.toString)
- k = Param(Params._dummy(), "k",
- "The ranking position value used in meanAveragePrecisionAtK|precisionAtK|"
- "ndcgAtK|recallAtK. Must be > 0. The default value is 10.",
- typeConverter=TypeConverters.toInt)
+
+ metricName = Param(
+ Params._dummy(),
+ "metricName",
+ "metric name in evaluation "
+ "(meanAveragePrecision|meanAveragePrecisionAtK|"
+ "precisionAtK|ndcgAtK|recallAtK)",
+ typeConverter=TypeConverters.toString,
+ )
+ k = Param(
+ Params._dummy(),
+ "k",
+ "The ranking position value used in meanAveragePrecisionAtK|precisionAtK|"
+ "ndcgAtK|recallAtK. Must be > 0. The default value is 10.",
+ typeConverter=TypeConverters.toInt,
+ )
@keyword_only
- def __init__(self, *, predictionCol="prediction", labelCol="label",
- metricName="meanAveragePrecision", k=10):
+ def __init__(
+ self,
+ *,
+ predictionCol="prediction",
+ labelCol="label",
+ metricName="meanAveragePrecision",
+ k=10,
+ ):
"""
__init__(self, \\*, predictionCol="prediction", labelCol="label", \
metricName="meanAveragePrecision", k=10)
"""
super(RankingEvaluator, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.evaluation.RankingEvaluator", self.uid)
+ "org.apache.spark.ml.evaluation.RankingEvaluator", self.uid
+ )
self._setDefault(metricName="meanAveragePrecision", k=10)
kwargs = self._input_kwargs
self._set(**kwargs)
@@ -910,8 +1067,14 @@ class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
@keyword_only
@since("3.0.0")
- def setParams(self, *, predictionCol="prediction", labelCol="label",
- metricName="meanAveragePrecision", k=10):
+ def setParams(
+ self,
+ *,
+ predictionCol="prediction",
+ labelCol="label",
+ metricName="meanAveragePrecision",
+ k=10,
+ ):
"""
setParams(self, \\*, predictionCol="prediction", labelCol="label", \
metricName="meanAveragePrecision", k=10)
@@ -926,21 +1089,20 @@ if __name__ == "__main__":
import tempfile
import pyspark.ml.evaluation
from pyspark.sql import SparkSession
+
globs = pyspark.ml.evaluation.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
- spark = SparkSession.builder\
- .master("local[2]")\
- .appName("ml.evaluation tests")\
- .getOrCreate()
- globs['spark'] = spark
+ spark = SparkSession.builder.master("local[2]").appName("ml.evaluation tests").getOrCreate()
+ globs["spark"] = spark
temp_path = tempfile.mkdtemp()
- globs['temp_path'] = temp_path
+ globs["temp_path"] = temp_path
try:
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
finally:
from shutil import rmtree
+
try:
rmtree(temp_path)
except OSError:
diff --git a/python/pyspark/ml/evaluation.pyi b/python/pyspark/ml/evaluation.pyi
index 55a3ae2..d7883f4 100644
--- a/python/pyspark/ml/evaluation.pyi
+++ b/python/pyspark/ml/evaluation.pyi
@@ -42,9 +42,7 @@ from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.sql.dataframe import DataFrame
class Evaluator(Params, metaclass=abc.ABCMeta):
- def evaluate(
- self, dataset: DataFrame, params: Optional[ParamMap] = ...
- ) -> float: ...
+ def evaluate(self, dataset: DataFrame, params: Optional[ParamMap] = ...) -> float: ...
def isLargerBetter(self) -> bool: ...
class JavaEvaluator(JavaParams, Evaluator, metaclass=abc.ABCMeta):
@@ -67,7 +65,7 @@ class BinaryClassificationEvaluator(
labelCol: str = ...,
metricName: BinaryClassificationEvaluatorMetricType = ...,
weightCol: Optional[str] = ...,
- numBins: int = ...
+ numBins: int = ...,
) -> None: ...
def setMetricName(
self, value: BinaryClassificationEvaluatorMetricType
@@ -85,7 +83,7 @@ class BinaryClassificationEvaluator(
labelCol: str = ...,
metricName: BinaryClassificationEvaluatorMetricType = ...,
weightCol: Optional[str] = ...,
- numBins: int = ...
+ numBins: int = ...,
) -> BinaryClassificationEvaluator: ...
class RegressionEvaluator(
@@ -105,11 +103,9 @@ class RegressionEvaluator(
labelCol: str = ...,
metricName: RegressionEvaluatorMetricType = ...,
weightCol: Optional[str] = ...,
- throughOrigin: bool = ...
+ throughOrigin: bool = ...,
) -> None: ...
- def setMetricName(
- self, value: RegressionEvaluatorMetricType
- ) -> RegressionEvaluator: ...
+ def setMetricName(self, value: RegressionEvaluatorMetricType) -> RegressionEvaluator: ...
def getMetricName(self) -> RegressionEvaluatorMetricType: ...
def setThroughOrigin(self, value: bool) -> RegressionEvaluator: ...
def getThroughOrigin(self) -> bool: ...
@@ -123,7 +119,7 @@ class RegressionEvaluator(
labelCol: str = ...,
metricName: RegressionEvaluatorMetricType = ...,
weightCol: Optional[str] = ...,
- throughOrigin: bool = ...
+ throughOrigin: bool = ...,
) -> RegressionEvaluator: ...
class MulticlassClassificationEvaluator(
@@ -149,7 +145,7 @@ class MulticlassClassificationEvaluator(
metricLabel: float = ...,
beta: float = ...,
probabilityCol: str = ...,
- eps: float = ...
+ eps: float = ...,
) -> None: ...
def setMetricName(
self, value: MulticlassClassificationEvaluatorMetricType
@@ -175,7 +171,7 @@ class MulticlassClassificationEvaluator(
metricLabel: float = ...,
beta: float = ...,
probabilityCol: str = ...,
- eps: float = ...
+ eps: float = ...,
) -> MulticlassClassificationEvaluator: ...
class MultilabelClassificationEvaluator(
@@ -193,7 +189,7 @@ class MultilabelClassificationEvaluator(
predictionCol: str = ...,
labelCol: str = ...,
metricName: MultilabelClassificationEvaluatorMetricType = ...,
- metricLabel: float = ...
+ metricLabel: float = ...,
) -> None: ...
def setMetricName(
self, value: MultilabelClassificationEvaluatorMetricType
@@ -209,7 +205,7 @@ class MultilabelClassificationEvaluator(
predictionCol: str = ...,
labelCol: str = ...,
metricName: MultilabelClassificationEvaluatorMetricType = ...,
- metricLabel: float = ...
+ metricLabel: float = ...,
) -> MultilabelClassificationEvaluator: ...
class ClusteringEvaluator(
@@ -229,7 +225,7 @@ class ClusteringEvaluator(
featuresCol: str = ...,
metricName: ClusteringEvaluatorMetricType = ...,
distanceMeasure: str = ...,
- weightCol: Optional[str] = ...
+ weightCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
@@ -238,11 +234,9 @@ class ClusteringEvaluator(
featuresCol: str = ...,
metricName: ClusteringEvaluatorMetricType = ...,
distanceMeasure: str = ...,
- weightCol: Optional[str] = ...
- ) -> ClusteringEvaluator: ...
- def setMetricName(
- self, value: ClusteringEvaluatorMetricType
+ weightCol: Optional[str] = ...,
) -> ClusteringEvaluator: ...
+ def setMetricName(self, value: ClusteringEvaluatorMetricType) -> ClusteringEvaluator: ...
def getMetricName(self) -> ClusteringEvaluatorMetricType: ...
def setDistanceMeasure(self, value: str) -> ClusteringEvaluator: ...
def getDistanceMeasure(self) -> str: ...
@@ -265,7 +259,7 @@ class RankingEvaluator(
predictionCol: str = ...,
labelCol: str = ...,
metricName: RankingEvaluatorMetricType = ...,
- k: int = ...
+ k: int = ...,
) -> None: ...
def setMetricName(self, value: RankingEvaluatorMetricType) -> RankingEvaluator: ...
def getMetricName(self) -> RankingEvaluatorMetricType: ...
@@ -279,5 +273,5 @@ class RankingEvaluator(
predictionCol: str = ...,
labelCol: str = ...,
metricName: RankingEvaluatorMetricType = ...,
- k: int = ...
+ k: int = ...,
) -> RankingEvaluator: ...
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index cf6b91c..18731ae 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -17,55 +17,100 @@
from pyspark import since, keyword_only, SparkContext
from pyspark.ml.linalg import _convert_to_vector
-from pyspark.ml.param.shared import HasThreshold, HasThresholds, HasInputCol, HasOutputCol, \
- HasInputCols, HasOutputCols, HasHandleInvalid, HasRelativeError, HasFeaturesCol, HasLabelCol, \
- HasSeed, HasNumFeatures, HasStepSize, HasMaxIter, TypeConverters, Param, Params
+from pyspark.ml.param.shared import (
+ HasThreshold,
+ HasThresholds,
+ HasInputCol,
+ HasOutputCol,
+ HasInputCols,
+ HasOutputCols,
+ HasHandleInvalid,
+ HasRelativeError,
+ HasFeaturesCol,
+ HasLabelCol,
+ HasSeed,
+ HasNumFeatures,
+ HasStepSize,
+ HasMaxIter,
+ TypeConverters,
+ Param,
+ Params,
+)
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm
from pyspark.ml.common import inherit_doc
-__all__ = ['Binarizer',
- 'BucketedRandomProjectionLSH', 'BucketedRandomProjectionLSHModel',
- 'Bucketizer',
- 'ChiSqSelector', 'ChiSqSelectorModel',
- 'CountVectorizer', 'CountVectorizerModel',
- 'DCT',
- 'ElementwiseProduct',
- 'FeatureHasher',
- 'HashingTF',
- 'IDF', 'IDFModel',
- 'Imputer', 'ImputerModel',
- 'IndexToString',
- 'Interaction',
- 'MaxAbsScaler', 'MaxAbsScalerModel',
- 'MinHashLSH', 'MinHashLSHModel',
- 'MinMaxScaler', 'MinMaxScalerModel',
- 'NGram',
- 'Normalizer',
- 'OneHotEncoder', 'OneHotEncoderModel',
- 'PCA', 'PCAModel',
- 'PolynomialExpansion',
- 'QuantileDiscretizer',
- 'RobustScaler', 'RobustScalerModel',
- 'RegexTokenizer',
- 'RFormula', 'RFormulaModel',
- 'SQLTransformer',
- 'StandardScaler', 'StandardScalerModel',
- 'StopWordsRemover',
- 'StringIndexer', 'StringIndexerModel',
- 'Tokenizer',
- 'UnivariateFeatureSelector', 'UnivariateFeatureSelectorModel',
- 'VarianceThresholdSelector', 'VarianceThresholdSelectorModel',
- 'VectorAssembler',
- 'VectorIndexer', 'VectorIndexerModel',
- 'VectorSizeHint',
- 'VectorSlicer',
- 'Word2Vec', 'Word2VecModel']
+__all__ = [
+ "Binarizer",
+ "BucketedRandomProjectionLSH",
+ "BucketedRandomProjectionLSHModel",
+ "Bucketizer",
+ "ChiSqSelector",
+ "ChiSqSelectorModel",
+ "CountVectorizer",
+ "CountVectorizerModel",
+ "DCT",
+ "ElementwiseProduct",
+ "FeatureHasher",
+ "HashingTF",
+ "IDF",
+ "IDFModel",
+ "Imputer",
+ "ImputerModel",
+ "IndexToString",
+ "Interaction",
+ "MaxAbsScaler",
+ "MaxAbsScalerModel",
+ "MinHashLSH",
+ "MinHashLSHModel",
+ "MinMaxScaler",
+ "MinMaxScalerModel",
+ "NGram",
+ "Normalizer",
+ "OneHotEncoder",
+ "OneHotEncoderModel",
+ "PCA",
+ "PCAModel",
+ "PolynomialExpansion",
+ "QuantileDiscretizer",
+ "RobustScaler",
+ "RobustScalerModel",
+ "RegexTokenizer",
+ "RFormula",
+ "RFormulaModel",
+ "SQLTransformer",
+ "StandardScaler",
+ "StandardScalerModel",
+ "StopWordsRemover",
+ "StringIndexer",
+ "StringIndexerModel",
+ "Tokenizer",
+ "UnivariateFeatureSelector",
+ "UnivariateFeatureSelectorModel",
+ "VarianceThresholdSelector",
+ "VarianceThresholdSelectorModel",
+ "VectorAssembler",
+ "VectorIndexer",
+ "VectorIndexerModel",
+ "VectorSizeHint",
+ "VectorSlicer",
+ "Word2Vec",
+ "Word2VecModel",
+]
@inherit_doc
-class Binarizer(JavaTransformer, HasThreshold, HasThresholds, HasInputCol, HasOutputCol,
- HasInputCols, HasOutputCols, JavaMLReadable, JavaMLWritable):
+class Binarizer(
+ JavaTransformer,
+ HasThreshold,
+ HasThresholds,
+ HasInputCol,
+ HasOutputCol,
+ HasInputCols,
+ HasOutputCols,
+ JavaMLReadable,
+ JavaMLWritable,
+):
"""
Binarize a column of continuous features given a threshold. Since 3.0.0,
:py:class:`Binarizer` can map multiple columns at once by setting the :py:attr:`inputCols`
@@ -112,21 +157,35 @@ class Binarizer(JavaTransformer, HasThreshold, HasThresholds, HasInputCol, HasOu
...
"""
- threshold = Param(Params._dummy(), "threshold",
- "Param for threshold used to binarize continuous features. " +
- "The features greater than the threshold will be binarized to 1.0. " +
- "The features equal to or less than the threshold will be binarized to 0.0",
- typeConverter=TypeConverters.toFloat)
- thresholds = Param(Params._dummy(), "thresholds",
- "Param for array of threshold used to binarize continuous features. " +
- "This is for multiple columns input. If transforming multiple columns " +
- "and thresholds is not set, but threshold is set, then threshold will " +
- "be applied across all columns.",
- typeConverter=TypeConverters.toListFloat)
+ threshold = Param(
+ Params._dummy(),
+ "threshold",
+ "Param for threshold used to binarize continuous features. "
+ + "The features greater than the threshold will be binarized to 1.0. "
+ + "The features equal to or less than the threshold will be binarized to 0.0",
+ typeConverter=TypeConverters.toFloat,
+ )
+ thresholds = Param(
+ Params._dummy(),
+ "thresholds",
+ "Param for array of threshold used to binarize continuous features. "
+ + "This is for multiple columns input. If transforming multiple columns "
+ + "and thresholds is not set, but threshold is set, then threshold will "
+ + "be applied across all columns.",
+ typeConverter=TypeConverters.toListFloat,
+ )
@keyword_only
- def __init__(self, *, threshold=0.0, inputCol=None, outputCol=None, thresholds=None,
- inputCols=None, outputCols=None):
+ def __init__(
+ self,
+ *,
+ threshold=0.0,
+ inputCol=None,
+ outputCol=None,
+ thresholds=None,
+ inputCols=None,
+ outputCols=None,
+ ):
"""
__init__(self, \\*, threshold=0.0, inputCol=None, outputCol=None, thresholds=None, \
inputCols=None, outputCols=None)
@@ -139,8 +198,16 @@ class Binarizer(JavaTransformer, HasThreshold, HasThresholds, HasInputCol, HasOu
@keyword_only
@since("1.4.0")
- def setParams(self, *, threshold=0.0, inputCol=None, outputCol=None, thresholds=None,
- inputCols=None, outputCols=None):
+ def setParams(
+ self,
+ *,
+ threshold=0.0,
+ inputCol=None,
+ outputCol=None,
+ thresholds=None,
+ inputCols=None,
+ outputCols=None,
+ ):
"""
setParams(self, \\*, threshold=0.0, inputCol=None, outputCol=None, thresholds=None, \
inputCols=None, outputCols=None)
@@ -195,10 +262,14 @@ class _LSHParams(HasInputCol, HasOutputCol):
Mixin for Locality Sensitive Hashing (LSH) algorithm parameters.
"""
- numHashTables = Param(Params._dummy(), "numHashTables", "number of hash tables, where " +
- "increasing number of hash tables lowers the false negative rate, " +
- "and decreasing it improves the running performance.",
- typeConverter=TypeConverters.toInt)
+ numHashTables = Param(
+ Params._dummy(),
+ "numHashTables",
+ "number of hash tables, where "
+ + "increasing number of hash tables lowers the false negative rate, "
+ + "and decreasing it improves the running performance.",
+ typeConverter=TypeConverters.toInt,
+ )
def __init__(self, *args):
super(_LSHParams, self).__init__(*args)
@@ -281,8 +352,7 @@ class _LSHModel(JavaModel, _LSHParams):
A dataset containing at most k items closest to the key. A column "distCol" is
added to show the distance between each row and the key.
"""
- return self._call_java("approxNearestNeighbors", dataset, key, numNearestNeighbors,
- distCol)
+ return self._call_java("approxNearestNeighbors", dataset, key, numNearestNeighbors, distCol)
def approxSimilarityJoin(self, datasetA, datasetB, threshold, distCol="distCol"):
"""
@@ -314,7 +384,7 @@ class _LSHModel(JavaModel, _LSHParams):
return self._call_java("approxSimilarityJoin", datasetA, datasetB, threshold, distCol)
-class _BucketedRandomProjectionLSHParams():
+class _BucketedRandomProjectionLSHParams:
"""
Params for :py:class:`BucketedRandomProjectionLSH` and
:py:class:`BucketedRandomProjectionLSHModel`.
@@ -322,9 +392,12 @@ class _BucketedRandomProjectionLSHParams():
.. versionadded:: 3.0.0
"""
- bucketLength = Param(Params._dummy(), "bucketLength", "the length of each hash bucket, " +
- "a larger bucket lowers the false negative rate.",
- typeConverter=TypeConverters.toFloat)
+ bucketLength = Param(
+ Params._dummy(),
+ "bucketLength",
+ "the length of each hash bucket, " + "a larger bucket lowers the false negative rate.",
+ typeConverter=TypeConverters.toFloat,
+ )
@since("2.2.0")
def getBucketLength(self):
@@ -335,8 +408,9 @@ class _BucketedRandomProjectionLSHParams():
@inherit_doc
-class BucketedRandomProjectionLSH(_LSH, _BucketedRandomProjectionLSHParams,
- HasSeed, JavaMLReadable, JavaMLWritable):
+class BucketedRandomProjectionLSH(
+ _LSH, _BucketedRandomProjectionLSHParams, HasSeed, JavaMLReadable, JavaMLWritable
+):
"""
LSH class for Euclidean distance metrics.
The input is dense or sparse vectors, each of which represents a point in the Euclidean
@@ -417,22 +491,25 @@ class BucketedRandomProjectionLSH(_LSH, _BucketedRandomProjectionLSHParams,
"""
@keyword_only
- def __init__(self, *, inputCol=None, outputCol=None, seed=None, numHashTables=1,
- bucketLength=None):
+ def __init__(
+ self, *, inputCol=None, outputCol=None, seed=None, numHashTables=1, bucketLength=None
+ ):
"""
__init__(self, \\*, inputCol=None, outputCol=None, seed=None, numHashTables=1, \
bucketLength=None)
"""
super(BucketedRandomProjectionLSH, self).__init__()
- self._java_obj = \
- self._new_java_obj("org.apache.spark.ml.feature.BucketedRandomProjectionLSH", self.uid)
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.feature.BucketedRandomProjectionLSH", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("2.2.0")
- def setParams(self, *, inputCol=None, outputCol=None, seed=None, numHashTables=1,
- bucketLength=None):
+ def setParams(
+ self, *, inputCol=None, outputCol=None, seed=None, numHashTables=1, bucketLength=None
+ ):
"""
setParams(self, \\*, inputCol=None, outputCol=None, seed=None, numHashTables=1, \
bucketLength=None)
@@ -458,8 +535,9 @@ class BucketedRandomProjectionLSH(_LSH, _BucketedRandomProjectionLSHParams,
return BucketedRandomProjectionLSHModel(java_model)
-class BucketedRandomProjectionLSHModel(_LSHModel, _BucketedRandomProjectionLSHParams,
- JavaMLReadable, JavaMLWritable):
+class BucketedRandomProjectionLSHModel(
+ _LSHModel, _BucketedRandomProjectionLSHParams, JavaMLReadable, JavaMLWritable
+):
r"""
Model fitted by :py:class:`BucketedRandomProjectionLSH`, where multiple random vectors are
stored. The vectors are normalized to be unit vectors and each vector is used in a hash
@@ -472,8 +550,16 @@ class BucketedRandomProjectionLSHModel(_LSHModel, _BucketedRandomProjectionLSHPa
@inherit_doc
-class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
- HasHandleInvalid, JavaMLReadable, JavaMLWritable):
+class Bucketizer(
+ JavaTransformer,
+ HasInputCol,
+ HasOutputCol,
+ HasInputCols,
+ HasOutputCols,
+ HasHandleInvalid,
+ JavaMLReadable,
+ JavaMLWritable,
+):
"""
Maps a column of continuous features to a column of feature buckets. Since 3.0.0,
:py:class:`Bucketizer` can map multiple columns at once by setting the :py:attr:`inputCols`
@@ -539,40 +625,59 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu
...
"""
- splits = \
- Param(Params._dummy(), "splits",
- "Split points for mapping continuous features into buckets. With n+1 splits, " +
- "there are n buckets. A bucket defined by splits x,y holds values in the " +
- "range [x,y) except the last bucket, which also includes y. The splits " +
- "should be of length >= 3 and strictly increasing. Values at -inf, inf must be " +
- "explicitly provided to cover all Double values; otherwise, values outside the " +
- "splits specified will be treated as errors.",
- typeConverter=TypeConverters.toListFloat)
-
- handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries "
- "containing NaN values. Values outside the splits will always be treated "
- "as errors. Options are 'skip' (filter out rows with invalid values), " +
- "'error' (throw an error), or 'keep' (keep invalid values in a " +
- "special additional bucket). Note that in the multiple column " +
- "case, the invalid handling is applied to all columns. That said " +
- "for 'error' it will throw an error if any invalids are found in " +
- "any column, for 'skip' it will skip rows with any invalids in " +
- "any columns, etc.",
- typeConverter=TypeConverters.toString)
-
- splitsArray = Param(Params._dummy(), "splitsArray", "The array of split points for mapping " +
- "continuous features into buckets for multiple columns. For each input " +
- "column, with n+1 splits, there are n buckets. A bucket defined by " +
- "splits x,y holds values in the range [x,y) except the last bucket, " +
- "which also includes y. The splits should be of length >= 3 and " +
- "strictly increasing. Values at -inf, inf must be explicitly provided " +
- "to cover all Double values; otherwise, values outside the splits " +
- "specified will be treated as errors.",
- typeConverter=TypeConverters.toListListFloat)
+ splits = Param(
+ Params._dummy(),
+ "splits",
+ "Split points for mapping continuous features into buckets. With n+1 splits, "
+ + "there are n buckets. A bucket defined by splits x,y holds values in the "
+ + "range [x,y) except the last bucket, which also includes y. The splits "
+ + "should be of length >= 3 and strictly increasing. Values at -inf, inf must be "
+ + "explicitly provided to cover all Double values; otherwise, values outside the "
+ + "splits specified will be treated as errors.",
+ typeConverter=TypeConverters.toListFloat,
+ )
+
+ handleInvalid = Param(
+ Params._dummy(),
+ "handleInvalid",
+ "how to handle invalid entries "
+ "containing NaN values. Values outside the splits will always be treated "
+ "as errors. Options are 'skip' (filter out rows with invalid values), "
+ + "'error' (throw an error), or 'keep' (keep invalid values in a "
+ + "special additional bucket). Note that in the multiple column "
+ + "case, the invalid handling is applied to all columns. That said "
+ + "for 'error' it will throw an error if any invalids are found in "
+ + "any column, for 'skip' it will skip rows with any invalids in "
+ + "any columns, etc.",
+ typeConverter=TypeConverters.toString,
+ )
+
+ splitsArray = Param(
+ Params._dummy(),
+ "splitsArray",
+ "The array of split points for mapping "
+ + "continuous features into buckets for multiple columns. For each input "
+ + "column, with n+1 splits, there are n buckets. A bucket defined by "
+ + "splits x,y holds values in the range [x,y) except the last bucket, "
+ + "which also includes y. The splits should be of length >= 3 and "
+ + "strictly increasing. Values at -inf, inf must be explicitly provided "
+ + "to cover all Double values; otherwise, values outside the splits "
+ + "specified will be treated as errors.",
+ typeConverter=TypeConverters.toListListFloat,
+ )
@keyword_only
- def __init__(self, *, splits=None, inputCol=None, outputCol=None, handleInvalid="error",
- splitsArray=None, inputCols=None, outputCols=None):
+ def __init__(
+ self,
+ *,
+ splits=None,
+ inputCol=None,
+ outputCol=None,
+ handleInvalid="error",
+ splitsArray=None,
+ inputCols=None,
+ outputCols=None,
+ ):
"""
__init__(self, \\*, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
splitsArray=None, inputCols=None, outputCols=None)
@@ -585,8 +690,17 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu
@keyword_only
@since("1.4.0")
- def setParams(self, *, splits=None, inputCol=None, outputCol=None, handleInvalid="error",
- splitsArray=None, inputCols=None, outputCols=None):
+ def setParams(
+ self,
+ *,
+ splits=None,
+ inputCol=None,
+ outputCol=None,
+ handleInvalid="error",
+ splitsArray=None,
+ inputCols=None,
+ outputCols=None,
+ ):
"""
setParams(self, \\*, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
splitsArray=None, inputCols=None, outputCols=None)
@@ -662,35 +776,53 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol):
"""
minTF = Param(
- Params._dummy(), "minTF", "Filter to ignore rare words in" +
- " a document. For each document, terms with frequency/count less than the given" +
- " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
- " times the term must appear in the document); if this is a double in [0,1), then this " +
- "specifies a fraction (out of the document's token count). Note that the parameter is " +
- "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
- typeConverter=TypeConverters.toFloat)
+ Params._dummy(),
+ "minTF",
+ "Filter to ignore rare words in"
+ + " a document. For each document, terms with frequency/count less than the given"
+ + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of"
+ + " times the term must appear in the document); if this is a double in [0,1), then this "
+ + "specifies a fraction (out of the document's token count). Note that the parameter is "
+ + "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
+ typeConverter=TypeConverters.toFloat,
+ )
minDF = Param(
- Params._dummy(), "minDF", "Specifies the minimum number of" +
- " different documents a term must appear in to be included in the vocabulary." +
- " If this is an integer >= 1, this specifies the number of documents the term must" +
- " appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
- " Default 1.0", typeConverter=TypeConverters.toFloat)
+ Params._dummy(),
+ "minDF",
+ "Specifies the minimum number of"
+ + " different documents a term must appear in to be included in the vocabulary."
+ + " If this is an integer >= 1, this specifies the number of documents the term must"
+ + " appear in; if this is a double in [0,1), then this specifies the fraction of documents."
+ + " Default 1.0",
+ typeConverter=TypeConverters.toFloat,
+ )
maxDF = Param(
- Params._dummy(), "maxDF", "Specifies the maximum number of" +
- " different documents a term could appear in to be included in the vocabulary." +
- " A term that appears more than the threshold will be ignored. If this is an" +
- " integer >= 1, this specifies the maximum number of documents the term could appear in;" +
- " if this is a double in [0,1), then this specifies the maximum" +
- " fraction of documents the term could appear in." +
- " Default (2^63) - 1", typeConverter=TypeConverters.toFloat)
+ Params._dummy(),
+ "maxDF",
+ "Specifies the maximum number of"
+ + " different documents a term could appear in to be included in the vocabulary."
+ + " A term that appears more than the threshold will be ignored. If this is an"
+ + " integer >= 1, this specifies the maximum number of documents the term could appear in;"
+ + " if this is a double in [0,1), then this specifies the maximum"
+ + " fraction of documents the term could appear in."
+ + " Default (2^63) - 1",
+ typeConverter=TypeConverters.toFloat,
+ )
vocabSize = Param(
- Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
- typeConverter=TypeConverters.toInt)
+ Params._dummy(),
+ "vocabSize",
+ "max size of the vocabulary. Default 1 << 18.",
+ typeConverter=TypeConverters.toInt,
+ )
binary = Param(
- Params._dummy(), "binary", "Binary toggle to control the output vector values." +
- " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
- " for discrete probabilistic models that model binary events rather than integer counts." +
- " Default False", typeConverter=TypeConverters.toBoolean)
+ Params._dummy(),
+ "binary",
+ "Binary toggle to control the output vector values."
+ + " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful"
+ + " for discrete probabilistic models that model binary events rather than integer counts."
+ + " Default False",
+ typeConverter=TypeConverters.toBoolean,
+ )
def __init__(self, *args):
super(_CountVectorizerParams, self).__init__(*args)
@@ -791,22 +923,39 @@ class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, Jav
"""
@keyword_only
- def __init__(self, *, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18,
- binary=False, inputCol=None, outputCol=None):
+ def __init__(
+ self,
+ *,
+ minTF=1.0,
+ minDF=1.0,
+ maxDF=2 ** 63 - 1,
+ vocabSize=1 << 18,
+ binary=False,
+ inputCol=None,
+ outputCol=None,
+ ):
"""
__init__(self, \\*, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18,\
binary=False, inputCol=None,outputCol=None)
"""
super(CountVectorizer, self).__init__()
- self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer",
- self.uid)
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.6.0")
- def setParams(self, *, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18,
- binary=False, inputCol=None, outputCol=None):
+ def setParams(
+ self,
+ *,
+ minTF=1.0,
+ minDF=1.0,
+ maxDF=2 ** 63 - 1,
+ vocabSize=1 << 18,
+ binary=False,
+ inputCol=None,
+ outputCol=None,
+ ):
"""
setParams(self, \\*, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18,\
binary=False, inputCol=None, outputCol=None)
@@ -899,7 +1048,8 @@ class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, Ja
java_class = sc._gateway.jvm.java.lang.String
jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class)
model = CountVectorizerModel._create_from_java_class(
- "org.apache.spark.ml.feature.CountVectorizerModel", jvocab)
+ "org.apache.spark.ml.feature.CountVectorizerModel", jvocab
+ )
model.setInputCol(inputCol)
if outputCol is not None:
model.setOutputCol(outputCol)
@@ -975,8 +1125,12 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit
False
"""
- inverse = Param(Params._dummy(), "inverse", "Set transformer to perform inverse DCT, " +
- "default False.", typeConverter=TypeConverters.toBoolean)
+ inverse = Param(
+ Params._dummy(),
+ "inverse",
+ "Set transformer to perform inverse DCT, " + "default False.",
+ typeConverter=TypeConverters.toBoolean,
+ )
@keyword_only
def __init__(self, *, inverse=False, inputCol=None, outputCol=None):
@@ -1027,8 +1181,9 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit
@inherit_doc
-class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
- JavaMLWritable):
+class ElementwiseProduct(
+ JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable
+):
"""
Outputs the Hadamard product (i.e., the element-wise product) of each input vector
with a provided "weight" vector. In other words, it scales each column of the dataset
@@ -1060,8 +1215,12 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada
True
"""
- scalingVec = Param(Params._dummy(), "scalingVec", "Vector for hadamard product.",
- typeConverter=TypeConverters.toVector)
+ scalingVec = Param(
+ Params._dummy(),
+ "scalingVec",
+ "Vector for hadamard product.",
+ typeConverter=TypeConverters.toVector,
+ )
@keyword_only
def __init__(self, *, scalingVec=None, inputCol=None, outputCol=None):
@@ -1069,8 +1228,9 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada
__init__(self, \\*, scalingVec=None, inputCol=None, outputCol=None)
"""
super(ElementwiseProduct, self).__init__()
- self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct",
- self.uid)
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.feature.ElementwiseProduct", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -1112,8 +1272,9 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada
@inherit_doc
-class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, JavaMLReadable,
- JavaMLWritable):
+class FeatureHasher(
+ JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, JavaMLReadable, JavaMLWritable
+):
"""
Feature hashing projects a set of categorical or numerical features into a feature vector of
specified dimension (typically substantially smaller than that of the original feature
@@ -1171,13 +1332,17 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
True
"""
- categoricalCols = Param(Params._dummy(), "categoricalCols",
- "numeric columns to treat as categorical",
- typeConverter=TypeConverters.toListString)
+ categoricalCols = Param(
+ Params._dummy(),
+ "categoricalCols",
+ "numeric columns to treat as categorical",
+ typeConverter=TypeConverters.toListString,
+ )
@keyword_only
- def __init__(self, *, numFeatures=1 << 18, inputCols=None, outputCol=None,
- categoricalCols=None):
+ def __init__(
+ self, *, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None
+ ):
"""
__init__(self, \\*, numFeatures=1 << 18, inputCols=None, outputCol=None, \
categoricalCols=None)
@@ -1190,8 +1355,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
@keyword_only
@since("2.3.0")
- def setParams(self, *, numFeatures=1 << 18, inputCols=None, outputCol=None,
- categoricalCols=None):
+ def setParams(
+ self, *, numFeatures=1 << 18, inputCols=None, outputCol=None, categoricalCols=None
+ ):
"""
setParams(self, \\*, numFeatures=1 << 18, inputCols=None, outputCol=None, \
categoricalCols=None)
@@ -1234,8 +1400,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
@inherit_doc
-class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable,
- JavaMLWritable):
+class HashingTF(
+ JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable, JavaMLWritable
+):
"""
Maps a sequence of terms to their term frequencies using the hashing trick.
Currently we use Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32)
@@ -1270,10 +1437,14 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
5
"""
- binary = Param(Params._dummy(), "binary", "If True, all non zero counts are set to 1. " +
- "This is useful for discrete probabilistic models that model binary events " +
- "rather than integer counts. Default False.",
- typeConverter=TypeConverters.toBoolean)
+ binary = Param(
+ Params._dummy(),
+ "binary",
+ "If True, all non zero counts are set to 1. "
+ + "This is useful for discrete probabilistic models that model binary events "
+ + "rather than integer counts. Default False.",
+ typeConverter=TypeConverters.toBoolean,
+ )
@keyword_only
def __init__(self, *, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None):
@@ -1344,9 +1515,12 @@ class _IDFParams(HasInputCol, HasOutputCol):
.. versionadded:: 3.0.0
"""
- minDocFreq = Param(Params._dummy(), "minDocFreq",
- "minimum number of documents in which a term should appear for filtering",
- typeConverter=TypeConverters.toInt)
+ minDocFreq = Param(
+ Params._dummy(),
+ "minDocFreq",
+ "minimum number of documents in which a term should appear for filtering",
+ typeConverter=TypeConverters.toInt,
+ )
@since("1.4.0")
def getMinDocFreq(self):
@@ -1503,16 +1677,23 @@ class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, Has
.. versionadded:: 3.0.0
"""
- strategy = Param(Params._dummy(), "strategy",
- "strategy for imputation. If mean, then replace missing values using the mean "
- "value of the feature. If median, then replace missing values using the "
- "median value of the feature. If mode, then replace missing using the most "
- "frequent value of the feature.",
- typeConverter=TypeConverters.toString)
-
- missingValue = Param(Params._dummy(), "missingValue",
- "The placeholder for the missing values. All occurrences of missingValue "
- "will be imputed.", typeConverter=TypeConverters.toFloat)
+ strategy = Param(
+ Params._dummy(),
+ "strategy",
+ "strategy for imputation. If mean, then replace missing values using the mean "
+ "value of the feature. If median, then replace missing values using the "
+ "median value of the feature. If mode, then replace missing using the most "
+ "frequent value of the feature.",
+ typeConverter=TypeConverters.toString,
+ )
+
+ missingValue = Param(
+ Params._dummy(),
+ "missingValue",
+ "The placeholder for the missing values. All occurrences of missingValue "
+ "will be imputed.",
+ typeConverter=TypeConverters.toFloat,
+ )
def __init__(self, *args):
super(_ImputerParams, self).__init__(*args)
@@ -1649,8 +1830,17 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable):
"""
@keyword_only
- def __init__(self, *, strategy="mean", missingValue=float("nan"), inputCols=None,
- outputCols=None, inputCol=None, outputCol=None, relativeError=0.001):
+ def __init__(
+ self,
+ *,
+ strategy="mean",
+ missingValue=float("nan"),
+ inputCols=None,
+ outputCols=None,
+ inputCol=None,
+ outputCol=None,
+ relativeError=0.001,
+ ):
"""
__init__(self, \\*, strategy="mean", missingValue=float("nan"), inputCols=None, \
outputCols=None, inputCol=None, outputCol=None, relativeError=0.001):
@@ -1662,8 +1852,17 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable):
@keyword_only
@since("2.2.0")
- def setParams(self, *, strategy="mean", missingValue=float("nan"), inputCols=None,
- outputCols=None, inputCol=None, outputCol=None, relativeError=0.001):
+ def setParams(
+ self,
+ *,
+ strategy="mean",
+ missingValue=float("nan"),
+ inputCols=None,
+ outputCols=None,
+ inputCol=None,
+ outputCol=None,
+ relativeError=0.001,
+ ):
"""
setParams(self, \\*, strategy="mean", missingValue=float("nan"), inputCols=None, \
outputCols=None, inputCol=None, outputCol=None, relativeError=0.001)
@@ -1849,6 +2048,7 @@ class _MaxAbsScalerParams(HasInputCol, HasOutputCol):
.. versionadded:: 3.0.0
"""
+
pass
@@ -2082,10 +2282,18 @@ class _MinMaxScalerParams(HasInputCol, HasOutputCol):
.. versionadded:: 3.0.0
"""
- min = Param(Params._dummy(), "min", "Lower bound of the output feature range",
- typeConverter=TypeConverters.toFloat)
- max = Param(Params._dummy(), "max", "Upper bound of the output feature range",
- typeConverter=TypeConverters.toFloat)
+ min = Param(
+ Params._dummy(),
+ "min",
+ "Lower bound of the output feature range",
+ typeConverter=TypeConverters.toFloat,
+ )
+ max = Param(
+ Params._dummy(),
+ "max",
+ "Upper bound of the output feature range",
+ typeConverter=TypeConverters.toFloat,
+ )
def __init__(self, *args):
super(_MinMaxScalerParams, self).__init__(*args)
@@ -2311,8 +2519,12 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWr
True
"""
- n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)",
- typeConverter=TypeConverters.toInt)
+ n = Param(
+ Params._dummy(),
+ "n",
+ "number of elements per n-gram (>=1)",
+ typeConverter=TypeConverters.toInt,
+ )
@keyword_only
def __init__(self, *, n=2, inputCol=None, outputCol=None):
@@ -2395,8 +2607,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
True
"""
- p = Param(Params._dummy(), "p", "the p norm value.",
- typeConverter=TypeConverters.toFloat)
+ p = Param(Params._dummy(), "p", "the p norm value.", typeConverter=TypeConverters.toFloat)
@keyword_only
def __init__(self, *, p=2.0, inputCol=None, outputCol=None):
@@ -2446,23 +2657,32 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
return self._set(outputCol=value)
-class _OneHotEncoderParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols,
- HasHandleInvalid):
+class _OneHotEncoderParams(
+ HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, HasHandleInvalid
+):
"""
Params for :py:class:`OneHotEncoder` and :py:class:`OneHotEncoderModel`.
.. versionadded:: 3.0.0
"""
- handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data during " +
- "transform(). Options are 'keep' (invalid data presented as an extra " +
- "categorical feature) or error (throw an error). Note that this Param " +
- "is only used during transform; during fitting, invalid data will " +
- "result in an error.",
- typeConverter=TypeConverters.toString)
-
- dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category",
- typeConverter=TypeConverters.toBoolean)
+ handleInvalid = Param(
+ Params._dummy(),
+ "handleInvalid",
+ "How to handle invalid data during "
+ + "transform(). Options are 'keep' (invalid data presented as an extra "
+ + "categorical feature) or error (throw an error). Note that this Param "
+ + "is only used during transform; during fitting, invalid data will "
+ + "result in an error.",
+ typeConverter=TypeConverters.toString,
+ )
+
+ dropLast = Param(
+ Params._dummy(),
+ "dropLast",
+ "whether to drop the last category",
+ typeConverter=TypeConverters.toBoolean,
+ )
def __init__(self, *args):
super(_OneHotEncoderParams, self).__init__(*args)
@@ -2541,22 +2761,37 @@ class OneHotEncoder(JavaEstimator, _OneHotEncoderParams, JavaMLReadable, JavaMLW
"""
@keyword_only
- def __init__(self, *, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True,
- inputCol=None, outputCol=None):
+ def __init__(
+ self,
+ *,
+ inputCols=None,
+ outputCols=None,
+ handleInvalid="error",
+ dropLast=True,
+ inputCol=None,
+ outputCol=None,
+ ):
"""
__init__(self, \\*, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, \
inputCol=None, outputCol=None)
"""
super(OneHotEncoder, self).__init__()
- self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.feature.OneHotEncoder", self.uid)
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("2.3.0")
- def setParams(self, *, inputCols=None, outputCols=None, handleInvalid="error",
- dropLast=True, inputCol=None, outputCol=None):
+ def setParams(
+ self,
+ *,
+ inputCols=None,
+ outputCols=None,
+ handleInvalid="error",
+ dropLast=True,
+ inputCol=None,
+ outputCol=None,
+ ):
"""
setParams(self, \\*, inputCols=None, outputCols=None, handleInvalid="error", \
dropLast=True, inputCol=None, outputCol=None)
@@ -2671,8 +2906,9 @@ class OneHotEncoderModel(JavaModel, _OneHotEncoderParams, JavaMLReadable, JavaML
@inherit_doc
-class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
- JavaMLWritable):
+class PolynomialExpansion(
+ JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable
+):
"""
Perform feature expansion in a polynomial space. As said in `wikipedia of Polynomial Expansion
<http://en.wikipedia.org/wiki/Polynomial_expansion>`_, "In mathematics, an
@@ -2704,8 +2940,12 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead
True
"""
- degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)",
- typeConverter=TypeConverters.toInt)
+ degree = Param(
+ Params._dummy(),
+ "degree",
+ "the polynomial degree to expand (>= 1)",
+ typeConverter=TypeConverters.toInt,
+ )
@keyword_only
def __init__(self, *, degree=2, inputCol=None, outputCol=None):
@@ -2714,7 +2954,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead
"""
super(PolynomialExpansion, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.feature.PolynomialExpansion", self.uid)
+ "org.apache.spark.ml.feature.PolynomialExpansion", self.uid
+ )
self._setDefault(degree=2)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -2757,8 +2998,17 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead
@inherit_doc
-class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
- HasHandleInvalid, HasRelativeError, JavaMLReadable, JavaMLWritable):
+class QuantileDiscretizer(
+ JavaEstimator,
+ HasInputCol,
+ HasOutputCol,
+ HasInputCols,
+ HasOutputCols,
+ HasHandleInvalid,
+ HasRelativeError,
+ JavaMLReadable,
+ JavaMLWritable,
+):
"""
:py:class:`QuantileDiscretizer` takes a column with continuous features and outputs a column
with binned categorical features. The number of bins can be set using the :py:attr:`numBuckets`
@@ -2852,46 +3102,78 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasInputCols
...
"""
- numBuckets = Param(Params._dummy(), "numBuckets",
- "Maximum number of buckets (quantiles, or " +
- "categories) into which data points are grouped. Must be >= 2.",
- typeConverter=TypeConverters.toInt)
-
- handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
- "Options are skip (filter out rows with invalid values), " +
- "error (throw an error), or keep (keep invalid values in a special " +
- "additional bucket). Note that in the multiple columns " +
- "case, the invalid handling is applied to all columns. That said " +
- "for 'error' it will throw an error if any invalids are found in " +
- "any columns, for 'skip' it will skip rows with any invalids in " +
- "any columns, etc.",
- typeConverter=TypeConverters.toString)
-
- numBucketsArray = Param(Params._dummy(), "numBucketsArray", "Array of number of buckets " +
- "(quantiles, or categories) into which data points are grouped. " +
- "This is for multiple columns input. If transforming multiple " +
- "columns and numBucketsArray is not set, but numBuckets is set, " +
- "then numBuckets will be applied across all columns.",
- typeConverter=TypeConverters.toListInt)
+ numBuckets = Param(
+ Params._dummy(),
+ "numBuckets",
+ "Maximum number of buckets (quantiles, or "
+ + "categories) into which data points are grouped. Must be >= 2.",
+ typeConverter=TypeConverters.toInt,
+ )
+
+ handleInvalid = Param(
+ Params._dummy(),
+ "handleInvalid",
+ "how to handle invalid entries. "
+ + "Options are skip (filter out rows with invalid values), "
+ + "error (throw an error), or keep (keep invalid values in a special "
+ + "additional bucket). Note that in the multiple columns "
+ + "case, the invalid handling is applied to all columns. That said "
+ + "for 'error' it will throw an error if any invalids are found in "
+ + "any columns, for 'skip' it will skip rows with any invalids in "
+ + "any columns, etc.",
+ typeConverter=TypeConverters.toString,
+ )
+
+ numBucketsArray = Param(
+ Params._dummy(),
+ "numBucketsArray",
+ "Array of number of buckets "
+ + "(quantiles, or categories) into which data points are grouped. "
+ + "This is for multiple columns input. If transforming multiple "
+ + "columns and numBucketsArray is not set, but numBuckets is set, "
+ + "then numBuckets will be applied across all columns.",
+ typeConverter=TypeConverters.toListInt,
+ )
@keyword_only
- def __init__(self, *, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
- handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None):
+ def __init__(
+ self,
+ *,
+ numBuckets=2,
+ inputCol=None,
+ outputCol=None,
+ relativeError=0.001,
+ handleInvalid="error",
+ numBucketsArray=None,
+ inputCols=None,
+ outputCols=None,
+ ):
"""
__init__(self, \\*, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None)
"""
super(QuantileDiscretizer, self).__init__()
- self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer",
- self.uid)
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.feature.QuantileDiscretizer", self.uid
+ )
self._setDefault(numBuckets=2, relativeError=0.001, handleInvalid="error")
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("2.0.0")
- def setParams(self, *, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
- handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None):
+ def setParams(
+ self,
+ *,
+ numBuckets=2,
+ inputCol=None,
+ outputCol=None,
+ relativeError=0.001,
+ handleInvalid="error",
+ numBucketsArray=None,
+ inputCols=None,
+ outputCols=None,
+ ):
"""
setParams(self, \\*, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None)
@@ -2971,17 +3253,21 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasInputCols
"""
Private method to convert the java_model to a Python model.
"""
- if (self.isSet(self.inputCol)):
- return Bucketizer(splits=list(java_model.getSplits()),
- inputCol=self.getInputCol(),
- outputCol=self.getOutputCol(),
- handleInvalid=self.getHandleInvalid())
+ if self.isSet(self.inputCol):
+ return Bucketizer(
+ splits=list(java_model.getSplits()),
+ inputCol=self.getInputCol(),
+ outputCol=self.getOutputCol(),
+ handleInvalid=self.getHandleInvalid(),
+ )
else:
splitsArrayList = [list(x) for x in list(java_model.getSplitsArray())]
- return Bucketizer(splitsArray=splitsArrayList,
- inputCols=self.getInputCols(),
- outputCols=self.getOutputCols(),
- handleInvalid=self.getHandleInvalid())
+ return Bucketizer(
+ splitsArray=splitsArrayList,
+ inputCols=self.getInputCols(),
+ outputCols=self.getOutputCols(),
+ handleInvalid=self.getHandleInvalid(),
+ )
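As the `_create_model` branches above show, fitting a `QuantileDiscretizer` yields a `Bucketizer`. A minimal single-column sketch, assuming a local SparkSession and an invented "hour" column:
```python
from pyspark.sql import SparkSession
from pyspark.ml.feature import QuantileDiscretizer

spark = SparkSession.builder.master("local[2]").appName("qd-sketch").getOrCreate()
df = spark.createDataFrame([(float(h),) for h in range(24)], ["hour"])

discretizer = QuantileDiscretizer(
    numBuckets=4,
    inputCol="hour",
    outputCol="hour_bucket",
    relativeError=0.001,
    handleInvalid="error",
)
bucketizer = discretizer.fit(df)  # single-column branch: fit() returns a Bucketizer
bucketizer.transform(df).show()
```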
class _RobustScalerParams(HasInputCol, HasOutputCol, HasRelativeError):
@@ -2991,19 +3277,36 @@ class _RobustScalerParams(HasInputCol, HasOutputCol, HasRelativeError):
.. versionadded:: 3.0.0
"""
- lower = Param(Params._dummy(), "lower", "Lower quantile to calculate quantile range",
- typeConverter=TypeConverters.toFloat)
- upper = Param(Params._dummy(), "upper", "Upper quantile to calculate quantile range",
- typeConverter=TypeConverters.toFloat)
- withCentering = Param(Params._dummy(), "withCentering", "Whether to center data with median",
- typeConverter=TypeConverters.toBoolean)
- withScaling = Param(Params._dummy(), "withScaling", "Whether to scale the data to "
- "quantile range", typeConverter=TypeConverters.toBoolean)
+ lower = Param(
+ Params._dummy(),
+ "lower",
+ "Lower quantile to calculate quantile range",
+ typeConverter=TypeConverters.toFloat,
+ )
+ upper = Param(
+ Params._dummy(),
+ "upper",
+ "Upper quantile to calculate quantile range",
+ typeConverter=TypeConverters.toFloat,
+ )
+ withCentering = Param(
+ Params._dummy(),
+ "withCentering",
+ "Whether to center data with median",
+ typeConverter=TypeConverters.toBoolean,
+ )
+ withScaling = Param(
+ Params._dummy(),
+ "withScaling",
+ "Whether to scale the data to " "quantile range",
+ typeConverter=TypeConverters.toBoolean,
+ )
def __init__(self, *args):
super(_RobustScalerParams, self).__init__(*args)
- self._setDefault(lower=0.25, upper=0.75, withCentering=False, withScaling=True,
- relativeError=0.001)
+ self._setDefault(
+ lower=0.25, upper=0.75, withCentering=False, withScaling=True, relativeError=0.001
+ )
@since("3.0.0")
def getLower(self):
@@ -3089,8 +3392,17 @@ class RobustScaler(JavaEstimator, _RobustScalerParams, JavaMLReadable, JavaMLWri
"""
@keyword_only
- def __init__(self, *, lower=0.25, upper=0.75, withCentering=False, withScaling=True,
- inputCol=None, outputCol=None, relativeError=0.001):
+ def __init__(
+ self,
+ *,
+ lower=0.25,
+ upper=0.75,
+ withCentering=False,
+ withScaling=True,
+ inputCol=None,
+ outputCol=None,
+ relativeError=0.001,
+ ):
"""
__init__(self, \\*, lower=0.25, upper=0.75, withCentering=False, withScaling=True, \
inputCol=None, outputCol=None, relativeError=0.001)
@@ -3102,8 +3414,17 @@ class RobustScaler(JavaEstimator, _RobustScalerParams, JavaMLReadable, JavaMLWri
@keyword_only
@since("3.0.0")
- def setParams(self, *, lower=0.25, upper=0.75, withCentering=False, withScaling=True,
- inputCol=None, outputCol=None, relativeError=0.001):
+ def setParams(
+ self,
+ *,
+ lower=0.25,
+ upper=0.75,
+ withCentering=False,
+ withScaling=True,
+ inputCol=None,
+ outputCol=None,
+ relativeError=0.001,
+ ):
"""
setParams(self, \\*, lower=0.25, upper=0.75, withCentering=False, withScaling=True, \
inputCol=None, outputCol=None, relativeError=0.001)
@@ -3249,18 +3570,41 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
True
"""
- minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)",
- typeConverter=TypeConverters.toInt)
- gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens " +
- "(False)")
- pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing",
- typeConverter=TypeConverters.toString)
- toLowercase = Param(Params._dummy(), "toLowercase", "whether to convert all characters to " +
- "lowercase before tokenizing", typeConverter=TypeConverters.toBoolean)
+ minTokenLength = Param(
+ Params._dummy(),
+ "minTokenLength",
+ "minimum token length (>= 0)",
+ typeConverter=TypeConverters.toInt,
+ )
+ gaps = Param(
+ Params._dummy(),
+ "gaps",
+ "whether regex splits on gaps (True) or matches tokens " + "(False)",
+ )
+ pattern = Param(
+ Params._dummy(),
+ "pattern",
+ "regex pattern (Java dialect) used for tokenizing",
+ typeConverter=TypeConverters.toString,
+ )
+ toLowercase = Param(
+ Params._dummy(),
+ "toLowercase",
+ "whether to convert all characters to " + "lowercase before tokenizing",
+ typeConverter=TypeConverters.toBoolean,
+ )
@keyword_only
- def __init__(self, *, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None,
- outputCol=None, toLowercase=True):
+ def __init__(
+ self,
+ *,
+ minTokenLength=1,
+ gaps=True,
+ pattern="\\s+",
+ inputCol=None,
+ outputCol=None,
+ toLowercase=True,
+ ):
"""
__init__(self, \\*, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, \
outputCol=None, toLowercase=True)
@@ -3273,8 +3617,16 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
@keyword_only
@since("1.4.0")
- def setParams(self, *, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None,
- outputCol=None, toLowercase=True):
+ def setParams(
+ self,
+ *,
+ minTokenLength=1,
+ gaps=True,
+ pattern="\\s+",
+ inputCol=None,
+ outputCol=None,
+ toLowercase=True,
+ ):
"""
setParams(self, \\*, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, \
outputCol=None, toLowercase=True)
@@ -3377,8 +3729,9 @@ class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable):
True
"""
- statement = Param(Params._dummy(), "statement", "SQL statement",
- typeConverter=TypeConverters.toString)
+ statement = Param(
+ Params._dummy(), "statement", "SQL statement", typeConverter=TypeConverters.toString
+ )
@keyword_only
def __init__(self, *, statement=None):
@@ -3422,10 +3775,15 @@ class _StandardScalerParams(HasInputCol, HasOutputCol):
.. versionadded:: 3.0.0
"""
- withMean = Param(Params._dummy(), "withMean", "Center data with mean",
- typeConverter=TypeConverters.toBoolean)
- withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation",
- typeConverter=TypeConverters.toBoolean)
+ withMean = Param(
+ Params._dummy(), "withMean", "Center data with mean", typeConverter=TypeConverters.toBoolean
+ )
+ withStd = Param(
+ Params._dummy(),
+ "withStd",
+ "Scale to unit standard deviation",
+ typeConverter=TypeConverters.toBoolean,
+ )
def __init__(self, *args):
super(_StandardScalerParams, self).__init__(*args)
@@ -3582,27 +3940,35 @@ class StandardScalerModel(JavaModel, _StandardScalerParams, JavaMLReadable, Java
return self._call_java("mean")
-class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol,
- HasInputCols, HasOutputCols):
+class _StringIndexerParams(
+ JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols
+):
"""
Params for :py:class:`StringIndexer` and :py:class:`StringIndexerModel`.
"""
- stringOrderType = Param(Params._dummy(), "stringOrderType",
- "How to order labels of string column. The first label after " +
- "ordering is assigned an index of 0. Supported options: " +
- "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. " +
- "Default is frequencyDesc. In case of equal frequency when " +
- "under frequencyDesc/Asc, the strings are further sorted " +
- "alphabetically",
- typeConverter=TypeConverters.toString)
-
- handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
- "or NULL values) in features and label column of string type. " +
- "Options are 'skip' (filter out rows with invalid data), " +
- "error (throw an error), or 'keep' (put invalid data " +
- "in a special additional bucket, at index numLabels).",
- typeConverter=TypeConverters.toString)
+ stringOrderType = Param(
+ Params._dummy(),
+ "stringOrderType",
+ "How to order labels of string column. The first label after "
+ + "ordering is assigned an index of 0. Supported options: "
+ + "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc. "
+ + "Default is frequencyDesc. In case of equal frequency when "
+ + "under frequencyDesc/Asc, the strings are further sorted "
+ + "alphabetically",
+ typeConverter=TypeConverters.toString,
+ )
+
+ handleInvalid = Param(
+ Params._dummy(),
+ "handleInvalid",
+ "how to handle invalid data (unseen "
+ + "or NULL values) in features and label column of string type. "
+ + "Options are 'skip' (filter out rows with invalid data), "
+ + "error (throw an error), or 'keep' (put invalid data "
+ + "in a special additional bucket, at index numLabels).",
+ typeConverter=TypeConverters.toString,
+ )
def __init__(self, *args):
super(_StringIndexerParams, self).__init__(*args)
@@ -3701,8 +4067,16 @@ class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLW
"""
@keyword_only
- def __init__(self, *, inputCol=None, outputCol=None, inputCols=None, outputCols=None,
- handleInvalid="error", stringOrderType="frequencyDesc"):
+ def __init__(
+ self,
+ *,
+ inputCol=None,
+ outputCol=None,
+ inputCols=None,
+ outputCols=None,
+ handleInvalid="error",
+ stringOrderType="frequencyDesc",
+ ):
"""
__init__(self, \\*, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \
handleInvalid="error", stringOrderType="frequencyDesc")
@@ -3714,8 +4088,16 @@ class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLW
@keyword_only
@since("1.4.0")
- def setParams(self, *, inputCol=None, outputCol=None, inputCols=None, outputCols=None,
- handleInvalid="error", stringOrderType="frequencyDesc"):
+ def setParams(
+ self,
+ *,
+ inputCol=None,
+ outputCol=None,
+ inputCols=None,
+ outputCols=None,
+ handleInvalid="error",
+ stringOrderType="frequencyDesc",
+ ):
"""
setParams(self, \\*, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \
handleInvalid="error", stringOrderType="frequencyDesc")
@@ -3818,7 +4200,8 @@ class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaML
java_class = sc._gateway.jvm.java.lang.String
jlabels = StringIndexerModel._new_java_array(labels, java_class)
model = StringIndexerModel._create_from_java_class(
- "org.apache.spark.ml.feature.StringIndexerModel", jlabels)
+ "org.apache.spark.ml.feature.StringIndexerModel", jlabels
+ )
model.setInputCol(inputCol)
if outputCol is not None:
model.setOutputCol(outputCol)
@@ -3828,8 +4211,7 @@ class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaML
@classmethod
@since("3.0.0")
- def from_arrays_of_labels(cls, arrayOfLabels, inputCols, outputCols=None,
- handleInvalid=None):
+ def from_arrays_of_labels(cls, arrayOfLabels, inputCols, outputCols=None, handleInvalid=None):
"""
Construct the model directly from an array of array of label strings,
requires an active SparkContext.
@@ -3838,7 +4220,8 @@ class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaML
java_class = sc._gateway.jvm.java.lang.String
jlabels = StringIndexerModel._new_java_array(arrayOfLabels, java_class)
model = StringIndexerModel._create_from_java_class(
- "org.apache.spark.ml.feature.StringIndexerModel", jlabels)
+ "org.apache.spark.ml.feature.StringIndexerModel", jlabels
+ )
model.setInputCols(inputCols)
if outputCols is not None:
model.setOutputCols(outputCols)
@@ -3882,10 +4265,13 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
StringIndexer : for converting categorical values into category indices
"""
- labels = Param(Params._dummy(), "labels",
- "Optional array of labels specifying index-string mapping." +
- " If not provided or if empty, then metadata from inputCol is used instead.",
- typeConverter=TypeConverters.toListString)
+ labels = Param(
+ Params._dummy(),
+ "labels",
+ "Optional array of labels specifying index-string mapping."
+ + " If not provided or if empty, then metadata from inputCol is used instead.",
+ typeConverter=TypeConverters.toListString,
+ )
@keyword_only
def __init__(self, *, inputCol=None, outputCol=None, labels=None):
@@ -3893,8 +4279,7 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
__init__(self, \\*, inputCol=None, outputCol=None, labels=None)
"""
super(IndexToString, self).__init__()
- self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString",
- self.uid)
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -3935,8 +4320,15 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
return self._set(outputCol=value)
-class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
- JavaMLReadable, JavaMLWritable):
+class StopWordsRemover(
+ JavaTransformer,
+ HasInputCol,
+ HasOutputCol,
+ HasInputCols,
+ HasOutputCols,
+ JavaMLReadable,
+ JavaMLWritable,
+):
"""
A feature transformer that filters out stop words from input.
Since 3.0.0, :py:class:`StopWordsRemover` can filter out multiple columns at once by setting
@@ -3981,32 +4373,66 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols,
...
"""
- stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out",
- typeConverter=TypeConverters.toListString)
- caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
- "comparison over the stop words", typeConverter=TypeConverters.toBoolean)
- locale = Param(Params._dummy(), "locale", "locale of the input. ignored when case sensitive " +
- "is true", typeConverter=TypeConverters.toString)
+ stopWords = Param(
+ Params._dummy(),
+ "stopWords",
+ "The words to be filtered out",
+ typeConverter=TypeConverters.toListString,
+ )
+ caseSensitive = Param(
+ Params._dummy(),
+ "caseSensitive",
+ "whether to do a case sensitive " + "comparison over the stop words",
+ typeConverter=TypeConverters.toBoolean,
+ )
+ locale = Param(
+ Params._dummy(),
+ "locale",
+ "locale of the input. ignored when case sensitive " + "is true",
+ typeConverter=TypeConverters.toString,
+ )
@keyword_only
- def __init__(self, *, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
- locale=None, inputCols=None, outputCols=None):
+ def __init__(
+ self,
+ *,
+ inputCol=None,
+ outputCol=None,
+ stopWords=None,
+ caseSensitive=False,
+ locale=None,
+ inputCols=None,
+ outputCols=None,
+ ):
"""
__init__(self, \\*, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
locale=None, inputCols=None, outputCols=None)
"""
super(StopWordsRemover, self).__init__()
- self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
- self.uid)
- self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"),
- caseSensitive=False, locale=self._java_obj.getLocale())
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.feature.StopWordsRemover", self.uid
+ )
+ self._setDefault(
+ stopWords=StopWordsRemover.loadDefaultStopWords("english"),
+ caseSensitive=False,
+ locale=self._java_obj.getLocale(),
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.6.0")
- def setParams(self, *, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
- locale=None, inputCols=None, outputCols=None):
+ def setParams(
+ self,
+ *,
+ inputCol=None,
+ outputCol=None,
+ stopWords=None,
+ caseSensitive=False,
+ locale=None,
+ inputCols=None,
+ outputCols=None,
+ ):
"""
setParams(self, \\*, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
locale=None, inputCols=None, outputCols=None)
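A small sketch of the multi-column `StopWordsRemover` path reformatted above. Illustrative only: it assumes a local SparkSession and invented token columns.
```python
from pyspark.sql import SparkSession
from pyspark.ml.feature import StopWordsRemover

spark = SparkSession.builder.master("local[2]").appName("swr-sketch").getOrCreate()
df = spark.createDataFrame(
    [(["a", "quick", "brown", "fox"], ["the", "lazy", "dog"])],
    ["tokens_a", "tokens_b"],
)

# One stop-word list is applied to every input column.
remover = StopWordsRemover(
    inputCols=["tokens_a", "tokens_b"],
    outputCols=["filtered_a", "filtered_b"],
    stopWords=StopWordsRemover.loadDefaultStopWords("english"),
    caseSensitive=False,
)
remover.transform(df).show(truncate=False)
```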
@@ -4165,8 +4591,9 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java
@inherit_doc
-class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable,
- JavaMLWritable):
+class VectorAssembler(
+ JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable, JavaMLWritable
+):
"""
A feature transformer that merges multiple columns into a vector column.
@@ -4212,15 +4639,19 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInva
...
"""
- handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data (NULL " +
- "and NaN values). Options are 'skip' (filter out rows with invalid " +
- "data), 'error' (throw an error), or 'keep' (return relevant number " +
- "of NaN in the output). Column lengths are taken from the size of ML " +
- "Attribute Group, which can be set using `VectorSizeHint` in a " +
- "pipeline before `VectorAssembler`. Column lengths can also be " +
- "inferred from first rows of the data since it is safe to do so but " +
- "only in case of 'error' or 'skip').",
- typeConverter=TypeConverters.toString)
+ handleInvalid = Param(
+ Params._dummy(),
+ "handleInvalid",
+ "How to handle invalid data (NULL "
+ + "and NaN values). Options are 'skip' (filter out rows with invalid "
+ + "data), 'error' (throw an error), or 'keep' (return relevant number "
+ + "of NaN in the output). Column lengths are taken from the size of ML "
+ + "Attribute Group, which can be set using `VectorSizeHint` in a "
+ + "pipeline before `VectorAssembler`. Column lengths can also be "
+ + "inferred from first rows of the data since it is safe to do so but "
+ + "only in case of 'error' or 'skip').",
+ typeConverter=TypeConverters.toString,
+ )
@keyword_only
def __init__(self, *, inputCols=None, outputCol=None, handleInvalid="error"):
@@ -4269,17 +4700,25 @@ class _VectorIndexerParams(HasInputCol, HasOutputCol, HasHandleInvalid):
.. versionadded:: 3.0.0
"""
- maxCategories = Param(Params._dummy(), "maxCategories",
- "Threshold for the number of values a categorical feature can take " +
- "(>= 2). If a feature is found to have > maxCategories values, then " +
- "it is declared continuous.", typeConverter=TypeConverters.toInt)
-
- handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data " +
- "(unseen labels or NULL values). Options are 'skip' (filter out " +
- "rows with invalid data), 'error' (throw an error), or 'keep' (put " +
- "invalid data in a special additional bucket, at index of the number " +
- "of categories of the feature).",
- typeConverter=TypeConverters.toString)
+ maxCategories = Param(
+ Params._dummy(),
+ "maxCategories",
+ "Threshold for the number of values a categorical feature can take "
+ + "(>= 2). If a feature is found to have > maxCategories values, then "
+ + "it is declared continuous.",
+ typeConverter=TypeConverters.toInt,
+ )
+
+ handleInvalid = Param(
+ Params._dummy(),
+ "handleInvalid",
+ "How to handle invalid data "
+ + "(unseen labels or NULL values). Options are 'skip' (filter out "
+ + "rows with invalid data), 'error' (throw an error), or 'keep' (put "
+ + "invalid data in a special additional bucket, at index of the number "
+ + "of categories of the feature).",
+ typeConverter=TypeConverters.toString,
+ )
def __init__(self, *args):
super(_VectorIndexerParams, self).__init__(*args)
@@ -4519,13 +4958,22 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, J
True
"""
- indices = Param(Params._dummy(), "indices", "An array of indices to select features from " +
- "a vector column. There can be no overlap with names.",
- typeConverter=TypeConverters.toListInt)
- names = Param(Params._dummy(), "names", "An array of feature names to select features from " +
- "a vector column. These names must be specified by ML " +
- "org.apache.spark.ml.attribute.Attribute. There can be no overlap with " +
- "indices.", typeConverter=TypeConverters.toListString)
+ indices = Param(
+ Params._dummy(),
+ "indices",
+ "An array of indices to select features from "
+ + "a vector column. There can be no overlap with names.",
+ typeConverter=TypeConverters.toListInt,
+ )
+ names = Param(
+ Params._dummy(),
+ "names",
+ "An array of feature names to select features from "
+ + "a vector column. These names must be specified by ML "
+ + "org.apache.spark.ml.attribute.Attribute. There can be no overlap with "
+ + "indices.",
+ typeConverter=TypeConverters.toListString,
+ )
@keyword_only
def __init__(self, *, inputCol=None, outputCol=None, indices=None, names=None):
@@ -4596,28 +5044,51 @@ class _Word2VecParams(HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCo
.. versionadded:: 3.0.0
"""
- vectorSize = Param(Params._dummy(), "vectorSize",
- "the dimension of codes after transforming from words",
- typeConverter=TypeConverters.toInt)
- numPartitions = Param(Params._dummy(), "numPartitions",
- "number of partitions for sentences of words",
- typeConverter=TypeConverters.toInt)
- minCount = Param(Params._dummy(), "minCount",
- "the minimum number of times a token must appear to be included in the " +
- "word2vec model's vocabulary", typeConverter=TypeConverters.toInt)
- windowSize = Param(Params._dummy(), "windowSize",
- "the window size (context words from [-window, window]). Default value is 5",
- typeConverter=TypeConverters.toInt)
- maxSentenceLength = Param(Params._dummy(), "maxSentenceLength",
- "Maximum length (in words) of each sentence in the input data. " +
- "Any sentence longer than this threshold will " +
- "be divided into chunks up to the size.",
- typeConverter=TypeConverters.toInt)
+ vectorSize = Param(
+ Params._dummy(),
+ "vectorSize",
+ "the dimension of codes after transforming from words",
+ typeConverter=TypeConverters.toInt,
+ )
+ numPartitions = Param(
+ Params._dummy(),
+ "numPartitions",
+ "number of partitions for sentences of words",
+ typeConverter=TypeConverters.toInt,
+ )
+ minCount = Param(
+ Params._dummy(),
+ "minCount",
+ "the minimum number of times a token must appear to be included in the "
+ + "word2vec model's vocabulary",
+ typeConverter=TypeConverters.toInt,
+ )
+ windowSize = Param(
+ Params._dummy(),
+ "windowSize",
+ "the window size (context words from [-window, window]). Default value is 5",
+ typeConverter=TypeConverters.toInt,
+ )
+ maxSentenceLength = Param(
+ Params._dummy(),
+ "maxSentenceLength",
+ "Maximum length (in words) of each sentence in the input data. "
+ + "Any sentence longer than this threshold will "
+ + "be divided into chunks up to the size.",
+ typeConverter=TypeConverters.toInt,
+ )
def __init__(self, *args):
super(_Word2VecParams, self).__init__(*args)
- self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
- windowSize=5, maxSentenceLength=1000)
+ self._setDefault(
+ vectorSize=100,
+ minCount=5,
+ numPartitions=1,
+ stepSize=0.025,
+ maxIter=1,
+ windowSize=5,
+ maxSentenceLength=1000,
+ )
@since("1.4.0")
def getVectorSize(self):
@@ -4721,9 +5192,20 @@ class Word2Vec(JavaEstimator, _Word2VecParams, JavaMLReadable, JavaMLWritable):
"""
@keyword_only
- def __init__(self, *, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025,
- maxIter=1, seed=None, inputCol=None, outputCol=None, windowSize=5,
- maxSentenceLength=1000):
+ def __init__(
+ self,
+ *,
+ vectorSize=100,
+ minCount=5,
+ numPartitions=1,
+ stepSize=0.025,
+ maxIter=1,
+ seed=None,
+ inputCol=None,
+ outputCol=None,
+ windowSize=5,
+ maxSentenceLength=1000,
+ ):
"""
__init__(self, \\*, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, \
maxIter=1, seed=None, inputCol=None, outputCol=None, windowSize=5, \
@@ -4736,9 +5218,20 @@ class Word2Vec(JavaEstimator, _Word2VecParams, JavaMLReadable, JavaMLWritable):
@keyword_only
@since("1.4.0")
- def setParams(self, *, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025,
- maxIter=1, seed=None, inputCol=None, outputCol=None, windowSize=5,
- maxSentenceLength=1000):
+ def setParams(
+ self,
+ *,
+ vectorSize=100,
+ minCount=5,
+ numPartitions=1,
+ stepSize=0.025,
+ maxIter=1,
+ seed=None,
+ inputCol=None,
+ outputCol=None,
+ windowSize=5,
+ maxSentenceLength=1000,
+ ):
"""
setParams(self, \\*, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, \
seed=None, inputCol=None, outputCol=None, windowSize=5, \
@@ -4878,8 +5371,12 @@ class _PCAParams(HasInputCol, HasOutputCol):
.. versionadded:: 3.0.0
"""
- k = Param(Params._dummy(), "k", "the number of principal components",
- typeConverter=TypeConverters.toInt)
+ k = Param(
+ Params._dummy(),
+ "k",
+ "the number of principal components",
+ typeConverter=TypeConverters.toInt,
+ )
@since("1.5.0")
def getK(self):
@@ -5022,32 +5519,44 @@ class _RFormulaParams(HasFeaturesCol, HasLabelCol, HasHandleInvalid):
.. versionadded:: 3.0.0
"""
- formula = Param(Params._dummy(), "formula", "R model formula",
- typeConverter=TypeConverters.toString)
-
- forceIndexLabel = Param(Params._dummy(), "forceIndexLabel",
- "Force to index label whether it is numeric or string",
- typeConverter=TypeConverters.toBoolean)
-
- stringIndexerOrderType = Param(Params._dummy(), "stringIndexerOrderType",
- "How to order categories of a string feature column used by " +
- "StringIndexer. The last category after ordering is dropped " +
- "when encoding strings. Supported options: frequencyDesc, " +
- "frequencyAsc, alphabetDesc, alphabetAsc. The default value " +
- "is frequencyDesc. When the ordering is set to alphabetDesc, " +
- "RFormula drops the same category as R when encoding strings.",
- typeConverter=TypeConverters.toString)
-
- handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
- "Options are 'skip' (filter out rows with invalid values), " +
- "'error' (throw an error), or 'keep' (put invalid data in a special " +
- "additional bucket, at index numLabels).",
- typeConverter=TypeConverters.toString)
+ formula = Param(
+ Params._dummy(), "formula", "R model formula", typeConverter=TypeConverters.toString
+ )
+
+ forceIndexLabel = Param(
+ Params._dummy(),
+ "forceIndexLabel",
+ "Force to index label whether it is numeric or string",
+ typeConverter=TypeConverters.toBoolean,
+ )
+
+ stringIndexerOrderType = Param(
+ Params._dummy(),
+ "stringIndexerOrderType",
+ "How to order categories of a string feature column used by "
+ + "StringIndexer. The last category after ordering is dropped "
+ + "when encoding strings. Supported options: frequencyDesc, "
+ + "frequencyAsc, alphabetDesc, alphabetAsc. The default value "
+ + "is frequencyDesc. When the ordering is set to alphabetDesc, "
+ + "RFormula drops the same category as R when encoding strings.",
+ typeConverter=TypeConverters.toString,
+ )
+
+ handleInvalid = Param(
+ Params._dummy(),
+ "handleInvalid",
+ "how to handle invalid entries. "
+ + "Options are 'skip' (filter out rows with invalid values), "
+ + "'error' (throw an error), or 'keep' (put invalid data in a special "
+ + "additional bucket, at index numLabels).",
+ typeConverter=TypeConverters.toString,
+ )
def __init__(self, *args):
super(_RFormulaParams, self).__init__(*args)
- self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
- handleInvalid="error")
+ self._setDefault(
+ forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", handleInvalid="error"
+ )
@since("1.5.0")
def getFormula(self):
@@ -5146,9 +5655,16 @@ class RFormula(JavaEstimator, _RFormulaParams, JavaMLReadable, JavaMLWritable):
"""
@keyword_only
- def __init__(self, *, formula=None, featuresCol="features", labelCol="label",
- forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
- handleInvalid="error"):
+ def __init__(
+ self,
+ *,
+ formula=None,
+ featuresCol="features",
+ labelCol="label",
+ forceIndexLabel=False,
+ stringIndexerOrderType="frequencyDesc",
+ handleInvalid="error",
+ ):
"""
__init__(self, \\*, formula=None, featuresCol="features", labelCol="label", \
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \
@@ -5161,9 +5677,16 @@ class RFormula(JavaEstimator, _RFormulaParams, JavaMLReadable, JavaMLWritable):
@keyword_only
@since("1.5.0")
- def setParams(self, *, formula=None, featuresCol="features", labelCol="label",
- forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
- handleInvalid="error"):
+ def setParams(
+ self,
+ *,
+ formula=None,
+ featuresCol="features",
+ labelCol="label",
+ forceIndexLabel=False,
+ stringIndexerOrderType="frequencyDesc",
+ handleInvalid="error",
+ ):
"""
setParams(self, \\*, formula=None, featuresCol="features", labelCol="label", \
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \
@@ -5240,34 +5763,61 @@ class _SelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol):
.. versionadded:: 3.1.0
"""
- selectorType = Param(Params._dummy(), "selectorType",
- "The selector type. " +
- "Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.",
- typeConverter=TypeConverters.toString)
-
- numTopFeatures = \
- Param(Params._dummy(), "numTopFeatures",
- "Number of features that selector will select, ordered by ascending p-value. " +
- "If the number of features is < numTopFeatures, then this will select " +
- "all features.", typeConverter=TypeConverters.toInt)
-
- percentile = Param(Params._dummy(), "percentile", "Percentile of features that selector " +
- "will select, ordered by ascending p-value.",
- typeConverter=TypeConverters.toFloat)
-
- fpr = Param(Params._dummy(), "fpr", "The highest p-value for features to be kept.",
- typeConverter=TypeConverters.toFloat)
-
- fdr = Param(Params._dummy(), "fdr", "The upper bound of the expected false discovery rate.",
- typeConverter=TypeConverters.toFloat)
-
- fwe = Param(Params._dummy(), "fwe", "The upper bound of the expected family-wise error rate.",
- typeConverter=TypeConverters.toFloat)
+ selectorType = Param(
+ Params._dummy(),
+ "selectorType",
+ "The selector type. "
+ + "Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.",
+ typeConverter=TypeConverters.toString,
+ )
+
+ numTopFeatures = Param(
+ Params._dummy(),
+ "numTopFeatures",
+ "Number of features that selector will select, ordered by ascending p-value. "
+ + "If the number of features is < numTopFeatures, then this will select "
+ + "all features.",
+ typeConverter=TypeConverters.toInt,
+ )
+
+ percentile = Param(
+ Params._dummy(),
+ "percentile",
+ "Percentile of features that selector " + "will select, ordered by ascending p-value.",
+ typeConverter=TypeConverters.toFloat,
+ )
+
+ fpr = Param(
+ Params._dummy(),
+ "fpr",
+ "The highest p-value for features to be kept.",
+ typeConverter=TypeConverters.toFloat,
+ )
+
+ fdr = Param(
+ Params._dummy(),
+ "fdr",
+ "The upper bound of the expected false discovery rate.",
+ typeConverter=TypeConverters.toFloat,
+ )
+
+ fwe = Param(
+ Params._dummy(),
+ "fwe",
+ "The upper bound of the expected family-wise error rate.",
+ typeConverter=TypeConverters.toFloat,
+ )
def __init__(self, *args):
super(_SelectorParams, self).__init__(*args)
- self._setDefault(numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1,
- fpr=0.05, fdr=0.05, fwe=0.05)
+ self._setDefault(
+ numTopFeatures=50,
+ selectorType="numTopFeatures",
+ percentile=0.1,
+ fpr=0.05,
+ fdr=0.05,
+ fwe=0.05,
+ )
@since("2.1.0")
def getSelectorType(self):
@@ -5475,9 +6025,19 @@ class ChiSqSelector(_Selector, JavaMLReadable, JavaMLWritable):
"""
@keyword_only
- def __init__(self, *, numTopFeatures=50, featuresCol="features", outputCol=None,
- labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05,
- fdr=0.05, fwe=0.05):
+ def __init__(
+ self,
+ *,
+ numTopFeatures=50,
+ featuresCol="features",
+ outputCol=None,
+ labelCol="label",
+ selectorType="numTopFeatures",
+ percentile=0.1,
+ fpr=0.05,
+ fdr=0.05,
+ fwe=0.05,
+ ):
"""
__init__(self, \\*, numTopFeatures=50, featuresCol="features", outputCol=None, \
labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, \
@@ -5490,9 +6050,19 @@ class ChiSqSelector(_Selector, JavaMLReadable, JavaMLWritable):
@keyword_only
@since("2.0.0")
- def setParams(self, *, numTopFeatures=50, featuresCol="features", outputCol=None,
- labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05,
- fdr=0.05, fwe=0.05):
+ def setParams(
+ self,
+ *,
+ numTopFeatures=50,
+ featuresCol="features",
+ outputCol=None,
+ labelCol="labels",
+ selectorType="numTopFeatures",
+ percentile=0.1,
+ fpr=0.05,
+ fdr=0.05,
+ fwe=0.05,
+ ):
"""
setParams(self, \\*, numTopFeatures=50, featuresCol="features", outputCol=None, \
labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, \
@@ -5515,8 +6085,9 @@ class ChiSqSelectorModel(_SelectorModel, JavaMLReadable, JavaMLWritable):
@inherit_doc
-class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReadable,
- JavaMLWritable):
+class VectorSizeHint(
+ JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReadable, JavaMLWritable
+):
"""
A feature transformer that adds size information to the metadata of a vector column.
VectorAssembler needs size information for its input columns and cannot be used on streaming
@@ -5551,16 +6122,20 @@ class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReada
True
"""
- size = Param(Params._dummy(), "size", "Size of vectors in column.",
- typeConverter=TypeConverters.toInt)
+ size = Param(
+ Params._dummy(), "size", "Size of vectors in column.", typeConverter=TypeConverters.toInt
+ )
- handleInvalid = Param(Params._dummy(), "handleInvalid",
- "How to handle invalid vectors in inputCol. Invalid vectors include "
- "nulls and vectors with the wrong size. The options are `skip` (filter "
- "out rows with invalid vectors), `error` (throw an error) and "
- "`optimistic` (do not check the vector size, and keep all rows). "
- "`error` by default.",
- TypeConverters.toString)
+ handleInvalid = Param(
+ Params._dummy(),
+ "handleInvalid",
+ "How to handle invalid vectors in inputCol. Invalid vectors include "
+ "nulls and vectors with the wrong size. The options are `skip` (filter "
+ "out rows with invalid vectors), `error` (throw an error) and "
+ "`optimistic` (do not check the vector size, and keep all rows). "
+ "`error` by default.",
+ TypeConverters.toString,
+ )
@keyword_only
def __init__(self, *, inputCol=None, size=None, handleInvalid="error"):
@@ -5584,12 +6159,12 @@ class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReada
@since("2.3.0")
def getSize(self):
- """ Gets size param, the size of vectors in `inputCol`."""
+ """Gets size param, the size of vectors in `inputCol`."""
return self.getOrDefault(self.size)
@since("2.3.0")
def setSize(self, value):
- """ Sets size param, the size of vectors in `inputCol`."""
+ """Sets size param, the size of vectors in `inputCol`."""
return self._set(size=value)
def setInputCol(self, value):
@@ -5613,10 +6188,14 @@ class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol):
.. versionadded:: 3.1.0
"""
- varianceThreshold = Param(Params._dummy(), "varianceThreshold",
- "Param for variance threshold. Features with a variance not " +
- "greater than this threshold will be removed. The default value " +
- "is 0.0.", typeConverter=TypeConverters.toFloat)
+ varianceThreshold = Param(
+ Params._dummy(),
+ "varianceThreshold",
+ "Param for variance threshold. Features with a variance not "
+ + "greater than this threshold will be removed. The default value "
+ + "is 0.0.",
+ typeConverter=TypeConverters.toFloat,
+ )
@since("3.1.0")
def getVarianceThreshold(self):
@@ -5627,8 +6206,9 @@ class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol):
@inherit_doc
-class VarianceThresholdSelector(JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable,
- JavaMLWritable):
+class VarianceThresholdSelector(
+ JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
+):
"""
Feature selector that removes all low-variance features. Features with a
variance not greater than the threshold will be removed. The default is to keep
@@ -5679,7 +6259,8 @@ class VarianceThresholdSelector(JavaEstimator, _VarianceThresholdSelectorParams,
"""
super(VarianceThresholdSelector, self).__init__()
self._java_obj = self._new_java_obj(
- "org.apache.spark.ml.feature.VarianceThresholdSelector", self.uid)
+ "org.apache.spark.ml.feature.VarianceThresholdSelector", self.uid
+ )
self._setDefault(varianceThreshold=0.0)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -5719,8 +6300,9 @@ class VarianceThresholdSelector(JavaEstimator, _VarianceThresholdSelectorParams,
return VarianceThresholdSelectorModel(java_model)
-class VarianceThresholdSelectorModel(JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable,
- JavaMLWritable):
+class VarianceThresholdSelectorModel(
+ JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable
+):
"""
Model fitted by :py:class:`VarianceThresholdSelector`.
@@ -5758,25 +6340,35 @@ class _UnivariateFeatureSelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol
.. versionadded:: 3.1.0
"""
- featureType = Param(Params._dummy(), "featureType",
- "The feature type. " +
- "Supported options: categorical, continuous.",
- typeConverter=TypeConverters.toString)
-
- labelType = Param(Params._dummy(), "labelType",
- "The label type. " +
- "Supported options: categorical, continuous.",
- typeConverter=TypeConverters.toString)
-
- selectionMode = Param(Params._dummy(), "selectionMode",
- "The selection mode. " +
- "Supported options: numTopFeatures (default), percentile, fpr, " +
- "fdr, fwe.",
- typeConverter=TypeConverters.toString)
-
- selectionThreshold = Param(Params._dummy(), "selectionThreshold", "The upper bound of the " +
- "features that selector will select.",
- typeConverter=TypeConverters.toFloat)
+ featureType = Param(
+ Params._dummy(),
+ "featureType",
+ "The feature type. " + "Supported options: categorical, continuous.",
+ typeConverter=TypeConverters.toString,
+ )
+
+ labelType = Param(
+ Params._dummy(),
+ "labelType",
+ "The label type. " + "Supported options: categorical, continuous.",
+ typeConverter=TypeConverters.toString,
+ )
+
+ selectionMode = Param(
+ Params._dummy(),
+ "selectionMode",
+ "The selection mode. "
+ + "Supported options: numTopFeatures (default), percentile, fpr, "
+ + "fdr, fwe.",
+ typeConverter=TypeConverters.toString,
+ )
+
+ selectionThreshold = Param(
+ Params._dummy(),
+ "selectionThreshold",
+ "The upper bound of the " + "features that selector will select.",
+ typeConverter=TypeConverters.toFloat,
+ )
def __init__(self, *args):
super(_UnivariateFeatureSelectorParams, self).__init__(*args)
@@ -5812,8 +6404,9 @@ class _UnivariateFeatureSelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol
@inherit_doc
-class UnivariateFeatureSelector(JavaEstimator, _UnivariateFeatureSelectorParams, JavaMLReadable,
- JavaMLWritable):
+class UnivariateFeatureSelector(
+ JavaEstimator, _UnivariateFeatureSelectorParams, JavaMLReadable, JavaMLWritable
+):
"""
UnivariateFeatureSelector
Feature selector based on univariate statistical tests against labels. Currently, Spark
@@ -5887,22 +6480,35 @@ class UnivariateFeatureSelector(JavaEstimator, _UnivariateFeatureSelectorParams,
"""
@keyword_only
- def __init__(self, *, featuresCol="features", outputCol=None,
- labelCol="label", selectionMode="numTopFeatures"):
+ def __init__(
+ self,
+ *,
+ featuresCol="features",
+ outputCol=None,
+ labelCol="label",
+ selectionMode="numTopFeatures",
+ ):
"""
__init__(self, \\*, featuresCol="features", outputCol=None, \
labelCol="label", selectionMode="numTopFeatures")
"""
super(UnivariateFeatureSelector, self).__init__()
- self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.UnivariateFeatureSelector",
- self.uid)
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.feature.UnivariateFeatureSelector", self.uid
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("3.1.1")
- def setParams(self, *, featuresCol="features", outputCol=None,
- labelCol="labels", selectionMode="numTopFeatures"):
+ def setParams(
+ self,
+ *,
+ featuresCol="features",
+ outputCol=None,
+ labelCol="labels",
+ selectionMode="numTopFeatures",
+ ):
"""
setParams(self, \\*, featuresCol="features", outputCol=None, \
labelCol="labels", selectionMode="numTopFeatures")
@@ -5961,8 +6567,9 @@ class UnivariateFeatureSelector(JavaEstimator, _UnivariateFeatureSelectorParams,
return UnivariateFeatureSelectorModel(java_model)
-class UnivariateFeatureSelectorModel(JavaModel, _UnivariateFeatureSelectorParams, JavaMLReadable,
- JavaMLWritable):
+class UnivariateFeatureSelectorModel(
+ JavaModel, _UnivariateFeatureSelectorParams, JavaMLReadable, JavaMLWritable
+):
"""
Model fitted by :py:class:`UnivariateFeatureSelector`.
@@ -6006,24 +6613,30 @@ if __name__ == "__main__":
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
- spark = SparkSession.builder\
- .master("local[2]")\
- .appName("ml.feature tests")\
- .getOrCreate()
+ spark = SparkSession.builder.master("local[2]").appName("ml.feature tests").getOrCreate()
sc = spark.sparkContext
- globs['sc'] = sc
- globs['spark'] = spark
- testData = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="b"),
- Row(id=2, label="c"), Row(id=3, label="a"),
- Row(id=4, label="a"), Row(id=5, label="c")], 2)
- globs['stringIndDf'] = spark.createDataFrame(testData)
+ globs["sc"] = sc
+ globs["spark"] = spark
+ testData = sc.parallelize(
+ [
+ Row(id=0, label="a"),
+ Row(id=1, label="b"),
+ Row(id=2, label="c"),
+ Row(id=3, label="a"),
+ Row(id=4, label="a"),
+ Row(id=5, label="c"),
+ ],
+ 2,
+ )
+ globs["stringIndDf"] = spark.createDataFrame(testData)
temp_path = tempfile.mkdtemp()
- globs['temp_path'] = temp_path
+ globs["temp_path"] = temp_path
try:
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
finally:
from shutil import rmtree
+
try:
rmtree(temp_path)
except OSError:
diff --git a/python/pyspark/ml/feature.pyi b/python/pyspark/ml/feature.pyi
index 33e4691..6efc304 100644
--- a/python/pyspark/ml/feature.pyi
+++ b/python/pyspark/ml/feature.pyi
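The stub changes below are almost entirely trailing commas added after the last parameter. Roughly, when black cannot fit a signature on one line it first tries to pack the parameters onto a single indented continuation line; only if that still exceeds the 100-character limit does it put one parameter per line, and in that exploded form it appends a trailing comma (which also keeps the signature exploded on future runs). A schematic sketch of the two shapes, in the stub style, with abridged hypothetical signatures:
```python
from typing import List, Optional

# Exploded: the parameters do not fit on one continuation line, so black keeps
# one parameter per line and appends the trailing comma seen throughout this file.
def set_exploded(
    *,
    thresholds: Optional[List[float]] = None,
    inputCols: Optional[List[str]] = None,
    outputCols: Optional[List[str]] = None,
) -> None: ...

# Packed: everything fits on a single continuation line within the 100-character
# limit, so black joins it (compare the DCT/NGram/Normalizer hunks below).
def set_packed(
    *, n: int = 1, inputCol: Optional[str] = None, outputCol: Optional[str] = None
) -> None: ...
```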
@@ -61,7 +61,7 @@ class Binarizer(
*,
threshold: float = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> None: ...
@overload
def __init__(
@@ -69,7 +69,7 @@ class Binarizer(
*,
thresholds: Optional[List[float]] = ...,
inputCols: Optional[List[str]] = ...,
- outputCols: Optional[List[str]] = ...
+ outputCols: Optional[List[str]] = ...,
) -> None: ...
@overload
def setParams(
@@ -77,7 +77,7 @@ class Binarizer(
*,
threshold: float = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> Binarizer: ...
@overload
def setParams(
@@ -85,7 +85,7 @@ class Binarizer(
*,
thresholds: Optional[List[float]] = ...,
inputCols: Optional[List[str]] = ...,
- outputCols: Optional[List[str]] = ...
+ outputCols: Optional[List[str]] = ...,
) -> Binarizer: ...
def setThreshold(self, value: float) -> Binarizer: ...
def setThresholds(self, value: List[float]) -> Binarizer: ...
@@ -140,7 +140,7 @@ class BucketedRandomProjectionLSH(
outputCol: Optional[str] = ...,
seed: Optional[int] = ...,
numHashTables: int = ...,
- bucketLength: Optional[float] = ...
+ bucketLength: Optional[float] = ...,
) -> None: ...
def setParams(
self,
@@ -149,7 +149,7 @@ class BucketedRandomProjectionLSH(
outputCol: Optional[str] = ...,
seed: Optional[int] = ...,
numHashTables: int = ...,
- bucketLength: Optional[float] = ...
+ bucketLength: Optional[float] = ...,
) -> BucketedRandomProjectionLSH: ...
def setBucketLength(self, value: float) -> BucketedRandomProjectionLSH: ...
def setSeed(self, value: int) -> BucketedRandomProjectionLSH: ...
@@ -181,7 +181,7 @@ class Bucketizer(
splits: Optional[List[float]] = ...,
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
- handleInvalid: str = ...
+ handleInvalid: str = ...,
) -> None: ...
@overload
def __init__(
@@ -190,7 +190,7 @@ class Bucketizer(
handleInvalid: str = ...,
splitsArray: Optional[List[List[float]]] = ...,
inputCols: Optional[List[str]] = ...,
- outputCols: Optional[List[str]] = ...
+ outputCols: Optional[List[str]] = ...,
) -> None: ...
@overload
def setParams(
@@ -199,7 +199,7 @@ class Bucketizer(
splits: Optional[List[float]] = ...,
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
- handleInvalid: str = ...
+ handleInvalid: str = ...,
) -> Bucketizer: ...
@overload
def setParams(
@@ -208,7 +208,7 @@ class Bucketizer(
handleInvalid: str = ...,
splitsArray: Optional[List[List[float]]] = ...,
inputCols: Optional[List[str]] = ...,
- outputCols: Optional[List[str]] = ...
+ outputCols: Optional[List[str]] = ...,
) -> Bucketizer: ...
def setSplits(self, value: List[float]) -> Bucketizer: ...
def getSplits(self) -> List[float]: ...
@@ -248,7 +248,7 @@ class CountVectorizer(
vocabSize: int = ...,
binary: bool = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
@@ -259,7 +259,7 @@ class CountVectorizer(
vocabSize: int = ...,
binary: bool = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> CountVectorizer: ...
def setMinTF(self, value: float) -> CountVectorizer: ...
def setMinDF(self, value: float) -> CountVectorizer: ...
@@ -269,9 +269,7 @@ class CountVectorizer(
def setInputCol(self, value: str) -> CountVectorizer: ...
def setOutputCol(self, value: str) -> CountVectorizer: ...
-class CountVectorizerModel(
- JavaModel, JavaMLReadable[CountVectorizerModel], JavaMLWritable
-):
+class CountVectorizerModel(JavaModel, JavaMLReadable[CountVectorizerModel], JavaMLWritable):
def setInputCol(self, value: str) -> CountVectorizerModel: ...
def setOutputCol(self, value: str) -> CountVectorizerModel: ...
def setMinTF(self, value: float) -> CountVectorizerModel: ...
@@ -288,23 +286,13 @@ class CountVectorizerModel(
@property
def vocabulary(self) -> List[str]: ...
-class DCT(
- JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable[DCT], JavaMLWritable
-):
+class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable[DCT], JavaMLWritable):
inverse: Param[bool]
def __init__(
- self,
- *,
- inverse: bool = ...,
- inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ self, *, inverse: bool = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
) -> None: ...
def setParams(
- self,
- *,
- inverse: bool = ...,
- inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ self, *, inverse: bool = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
) -> DCT: ...
def setInverse(self, value: bool) -> DCT: ...
def getInverse(self) -> bool: ...
@@ -324,14 +312,14 @@ class ElementwiseProduct(
*,
scalingVec: Optional[Vector] = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
*,
scalingVec: Optional[Vector] = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> ElementwiseProduct: ...
def setScalingVec(self, value: Vector) -> ElementwiseProduct: ...
def getScalingVec(self) -> Vector: ...
@@ -353,7 +341,7 @@ class FeatureHasher(
numFeatures: int = ...,
inputCols: Optional[List[str]] = ...,
outputCol: Optional[str] = ...,
- categoricalCols: Optional[List[str]] = ...
+ categoricalCols: Optional[List[str]] = ...,
) -> None: ...
def setParams(
self,
@@ -361,7 +349,7 @@ class FeatureHasher(
numFeatures: int = ...,
inputCols: Optional[List[str]] = ...,
outputCol: Optional[str] = ...,
- categoricalCols: Optional[List[str]] = ...
+ categoricalCols: Optional[List[str]] = ...,
) -> FeatureHasher: ...
def setCategoricalCols(self, value: List[str]) -> FeatureHasher: ...
def getCategoricalCols(self) -> List[str]: ...
@@ -384,7 +372,7 @@ class HashingTF(
numFeatures: int = ...,
binary: bool = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
@@ -392,7 +380,7 @@ class HashingTF(
numFeatures: int = ...,
binary: bool = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> HashingTF: ...
def setBinary(self, value: bool) -> HashingTF: ...
def getBinary(self) -> bool: ...
@@ -412,14 +400,14 @@ class IDF(JavaEstimator[IDFModel], _IDFParams, JavaMLReadable[IDF], JavaMLWritab
*,
minDocFreq: int = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
*,
minDocFreq: int = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> IDF: ...
def setMinDocFreq(self, value: int) -> IDF: ...
def setInputCol(self, value: str) -> IDF: ...
@@ -435,17 +423,13 @@ class IDFModel(JavaModel, _IDFParams, JavaMLReadable[IDFModel], JavaMLWritable):
@property
def numDocs(self) -> int: ...
-class _ImputerParams(
- HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, HasRelativeError
-):
+class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, HasRelativeError):
strategy: Param[str]
missingValue: Param[float]
def getStrategy(self) -> str: ...
def getMissingValue(self) -> float: ...
-class Imputer(
- JavaEstimator[ImputerModel], _ImputerParams, JavaMLReadable[Imputer], JavaMLWritable
-):
+class Imputer(JavaEstimator[ImputerModel], _ImputerParams, JavaMLReadable[Imputer], JavaMLWritable):
@overload
def __init__(
self,
@@ -454,7 +438,7 @@ class Imputer(
missingValue: float = ...,
inputCols: Optional[List[str]] = ...,
outputCols: Optional[List[str]] = ...,
- relativeError: float = ...
+ relativeError: float = ...,
) -> None: ...
@overload
def __init__(
@@ -464,7 +448,7 @@ class Imputer(
missingValue: float = ...,
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
- relativeError: float = ...
+ relativeError: float = ...,
) -> None: ...
@overload
def setParams(
@@ -474,7 +458,7 @@ class Imputer(
missingValue: float = ...,
inputCols: Optional[List[str]] = ...,
outputCols: Optional[List[str]] = ...,
- relativeError: float = ...
+ relativeError: float = ...,
) -> Imputer: ...
@overload
def setParams(
@@ -484,7 +468,7 @@ class Imputer(
missingValue: float = ...,
inputCol: Optional[str] = ...,
outputCols: Optional[str] = ...,
- relativeError: float = ...
+ relativeError: float = ...,
) -> Imputer: ...
def setStrategy(self, value: str) -> Imputer: ...
def setMissingValue(self, value: float) -> Imputer: ...
@@ -494,9 +478,7 @@ class Imputer(
def setOutputCol(self, value: str) -> Imputer: ...
def setRelativeError(self, value: float) -> Imputer: ...
-class ImputerModel(
- JavaModel, _ImputerParams, JavaMLReadable[ImputerModel], JavaMLWritable
-):
+class ImputerModel(JavaModel, _ImputerParams, JavaMLReadable[ImputerModel], JavaMLWritable):
def setInputCols(self, value: List[str]) -> ImputerModel: ...
def setOutputCols(self, value: List[str]) -> ImputerModel: ...
def setInputCol(self, value: str) -> ImputerModel: ...
@@ -559,7 +541,7 @@ class MinHashLSH(
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
seed: Optional[int] = ...,
- numHashTables: int = ...
+ numHashTables: int = ...,
) -> None: ...
def setParams(
self,
@@ -567,7 +549,7 @@ class MinHashLSH(
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
seed: Optional[int] = ...,
- numHashTables: int = ...
+ numHashTables: int = ...,
) -> MinHashLSH: ...
def setSeed(self, value: int) -> MinHashLSH: ...
@@ -592,7 +574,7 @@ class MinMaxScaler(
min: float = ...,
max: float = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
@@ -600,7 +582,7 @@ class MinMaxScaler(
min: float = ...,
max: float = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> MinMaxScaler: ...
def setMin(self, value: float) -> MinMaxScaler: ...
def setMax(self, value: float) -> MinMaxScaler: ...
@@ -619,23 +601,13 @@ class MinMaxScalerModel(
@property
def originalMax(self) -> Vector: ...
-class NGram(
- JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable[NGram], JavaMLWritable
-):
+class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable[NGram], JavaMLWritable):
n: Param[int]
def __init__(
- self,
- *,
- n: int = ...,
- inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ self, *, n: int = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
) -> None: ...
def setParams(
- self,
- *,
- n: int = ...,
- inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ self, *, n: int = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
) -> NGram: ...
def setN(self, value: int) -> NGram: ...
def getN(self) -> int: ...
@@ -651,18 +623,10 @@ class Normalizer(
):
p: Param[float]
def __init__(
- self,
- *,
- p: float = ...,
- inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ self, *, p: float = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
) -> None: ...
def setParams(
- self,
- *,
- p: float = ...,
- inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ self, *, p: float = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
) -> Normalizer: ...
def setP(self, value: float) -> Normalizer: ...
def getP(self) -> float: ...
@@ -688,7 +652,7 @@ class OneHotEncoder(
inputCols: Optional[List[str]] = ...,
outputCols: Optional[List[str]] = ...,
handleInvalid: str = ...,
- dropLast: bool = ...
+ dropLast: bool = ...,
) -> None: ...
@overload
def __init__(
@@ -697,7 +661,7 @@ class OneHotEncoder(
handleInvalid: str = ...,
dropLast: bool = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> None: ...
@overload
def setParams(
@@ -706,7 +670,7 @@ class OneHotEncoder(
inputCols: Optional[List[str]] = ...,
outputCols: Optional[List[str]] = ...,
handleInvalid: str = ...,
- dropLast: bool = ...
+ dropLast: bool = ...,
) -> OneHotEncoder: ...
@overload
def setParams(
@@ -715,7 +679,7 @@ class OneHotEncoder(
handleInvalid: str = ...,
dropLast: bool = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> OneHotEncoder: ...
def setDropLast(self, value: bool) -> OneHotEncoder: ...
def setInputCols(self, value: List[str]) -> OneHotEncoder: ...
@@ -745,18 +709,10 @@ class PolynomialExpansion(
):
degree: Param[int]
def __init__(
- self,
- *,
- degree: int = ...,
- inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ self, *, degree: int = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
) -> None: ...
def setParams(
- self,
- *,
- degree: int = ...,
- inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ self, *, degree: int = ..., inputCol: Optional[str] = ..., outputCol: Optional[str] = ...
) -> PolynomialExpansion: ...
def setDegree(self, value: int) -> PolynomialExpansion: ...
def getDegree(self) -> int: ...
@@ -785,7 +741,7 @@ class QuantileDiscretizer(
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
relativeError: float = ...,
- handleInvalid: str = ...
+ handleInvalid: str = ...,
) -> None: ...
@overload
def __init__(
@@ -795,7 +751,7 @@ class QuantileDiscretizer(
handleInvalid: str = ...,
numBucketsArray: Optional[List[int]] = ...,
inputCols: Optional[List[str]] = ...,
- outputCols: Optional[List[str]] = ...
+ outputCols: Optional[List[str]] = ...,
) -> None: ...
@overload
def setParams(
@@ -805,7 +761,7 @@ class QuantileDiscretizer(
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
relativeError: float = ...,
- handleInvalid: str = ...
+ handleInvalid: str = ...,
) -> QuantileDiscretizer: ...
@overload
def setParams(
@@ -815,7 +771,7 @@ class QuantileDiscretizer(
handleInvalid: str = ...,
numBucketsArray: Optional[List[int]] = ...,
inputCols: Optional[List[str]] = ...,
- outputCols: Optional[List[str]] = ...
+ outputCols: Optional[List[str]] = ...,
) -> QuantileDiscretizer: ...
def setNumBuckets(self, value: int) -> QuantileDiscretizer: ...
def getNumBuckets(self) -> int: ...
@@ -851,7 +807,7 @@ class RobustScaler(
withScaling: bool = ...,
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
- relativeError: float = ...
+ relativeError: float = ...,
) -> None: ...
def setParams(
self,
@@ -862,7 +818,7 @@ class RobustScaler(
withScaling: bool = ...,
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
- relativeError: float = ...
+ relativeError: float = ...,
) -> RobustScaler: ...
def setLower(self, value: float) -> RobustScaler: ...
def setUpper(self, value: float) -> RobustScaler: ...
@@ -901,7 +857,7 @@ class RegexTokenizer(
pattern: str = ...,
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
- toLowercase: bool = ...
+ toLowercase: bool = ...,
) -> None: ...
def setParams(
self,
@@ -911,7 +867,7 @@ class RegexTokenizer(
pattern: str = ...,
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
- toLowercase: bool = ...
+ toLowercase: bool = ...,
) -> RegexTokenizer: ...
def setMinTokenLength(self, value: int) -> RegexTokenizer: ...
def getMinTokenLength(self) -> int: ...
@@ -950,7 +906,7 @@ class StandardScaler(
withMean: bool = ...,
withStd: bool = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
@@ -958,7 +914,7 @@ class StandardScaler(
withMean: bool = ...,
withStd: bool = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> StandardScaler: ...
def setWithMean(self, value: bool) -> StandardScaler: ...
def setWithStd(self, value: bool) -> StandardScaler: ...
@@ -999,7 +955,7 @@ class StringIndexer(
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
handleInvalid: str = ...,
- stringOrderType: str = ...
+ stringOrderType: str = ...,
) -> None: ...
@overload
def __init__(
@@ -1008,7 +964,7 @@ class StringIndexer(
inputCols: Optional[List[str]] = ...,
outputCols: Optional[List[str]] = ...,
handleInvalid: str = ...,
- stringOrderType: str = ...
+ stringOrderType: str = ...,
) -> None: ...
@overload
def setParams(
@@ -1017,7 +973,7 @@ class StringIndexer(
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
handleInvalid: str = ...,
- stringOrderType: str = ...
+ stringOrderType: str = ...,
) -> StringIndexer: ...
@overload
def setParams(
@@ -1026,7 +982,7 @@ class StringIndexer(
inputCols: Optional[List[str]] = ...,
outputCols: Optional[List[str]] = ...,
handleInvalid: str = ...,
- stringOrderType: str = ...
+ stringOrderType: str = ...,
) -> StringIndexer: ...
def setStringOrderType(self, value: str) -> StringIndexer: ...
def setInputCol(self, value: str) -> StringIndexer: ...
@@ -1075,14 +1031,14 @@ class IndexToString(
*,
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
- labels: Optional[List[str]] = ...
+ labels: Optional[List[str]] = ...,
) -> None: ...
def setParams(
self,
*,
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
- labels: Optional[List[str]] = ...
+ labels: Optional[List[str]] = ...,
) -> IndexToString: ...
def setLabels(self, value: List[str]) -> IndexToString: ...
def getLabels(self) -> List[str]: ...
@@ -1109,7 +1065,7 @@ class StopWordsRemover(
outputCol: Optional[str] = ...,
stopWords: Optional[List[str]] = ...,
caseSensitive: bool = ...,
- locale: Optional[str] = ...
+ locale: Optional[str] = ...,
) -> None: ...
@overload
def __init__(
@@ -1119,7 +1075,7 @@ class StopWordsRemover(
caseSensitive: bool = ...,
locale: Optional[str] = ...,
inputCols: Optional[List[str]] = ...,
- outputCols: Optional[List[str]] = ...
+ outputCols: Optional[List[str]] = ...,
) -> None: ...
@overload
def setParams(
@@ -1129,7 +1085,7 @@ class StopWordsRemover(
outputCol: Optional[str] = ...,
stopWords: Optional[List[str]] = ...,
caseSensitive: bool = ...,
- locale: Optional[str] = ...
+ locale: Optional[str] = ...,
) -> StopWordsRemover: ...
@overload
def setParams(
@@ -1139,7 +1095,7 @@ class StopWordsRemover(
caseSensitive: bool = ...,
locale: Optional[str] = ...,
inputCols: Optional[List[str]] = ...,
- outputCols: Optional[List[str]] = ...
+ outputCols: Optional[List[str]] = ...,
) -> StopWordsRemover: ...
def setStopWords(self, value: List[str]) -> StopWordsRemover: ...
def getStopWords(self) -> List[str]: ...
@@ -1184,14 +1140,14 @@ class VectorAssembler(
*,
inputCols: Optional[List[str]] = ...,
outputCol: Optional[str] = ...,
- handleInvalid: str = ...
+ handleInvalid: str = ...,
) -> None: ...
def setParams(
self,
*,
inputCols: Optional[List[str]] = ...,
outputCol: Optional[str] = ...,
- handleInvalid: str = ...
+ handleInvalid: str = ...,
) -> VectorAssembler: ...
def setInputCols(self, value: List[str]) -> VectorAssembler: ...
def setOutputCol(self, value: str) -> VectorAssembler: ...
@@ -1216,7 +1172,7 @@ class VectorIndexer(
maxCategories: int = ...,
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
- handleInvalid: str = ...
+ handleInvalid: str = ...,
) -> None: ...
def setParams(
self,
@@ -1224,7 +1180,7 @@ class VectorIndexer(
maxCategories: int = ...,
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
- handleInvalid: str = ...
+ handleInvalid: str = ...,
) -> VectorIndexer: ...
def setMaxCategories(self, value: int) -> VectorIndexer: ...
def setInputCol(self, value: str) -> VectorIndexer: ...
@@ -1256,7 +1212,7 @@ class VectorSlicer(
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
indices: Optional[List[int]] = ...,
- names: Optional[List[str]] = ...
+ names: Optional[List[str]] = ...,
) -> None: ...
def setParams(
self,
@@ -1264,7 +1220,7 @@ class VectorSlicer(
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
indices: Optional[List[int]] = ...,
- names: Optional[List[str]] = ...
+ names: Optional[List[str]] = ...,
) -> VectorSlicer: ...
def setIndices(self, value: List[int]) -> VectorSlicer: ...
def getIndices(self) -> List[int]: ...
@@ -1304,7 +1260,7 @@ class Word2Vec(
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
windowSize: int = ...,
- maxSentenceLength: int = ...
+ maxSentenceLength: int = ...,
) -> None: ...
def setParams(
self,
@@ -1318,7 +1274,7 @@ class Word2Vec(
inputCol: Optional[str] = ...,
outputCol: Optional[str] = ...,
windowSize: int = ...,
- maxSentenceLength: int = ...
+ maxSentenceLength: int = ...,
) -> Word2Vec: ...
def setVectorSize(self, value: int) -> Word2Vec: ...
def setNumPartitions(self, value: int) -> Word2Vec: ...
@@ -1331,9 +1287,7 @@ class Word2Vec(
def setSeed(self, value: int) -> Word2Vec: ...
def setStepSize(self, value: float) -> Word2Vec: ...
-class Word2VecModel(
- JavaModel, _Word2VecParams, JavaMLReadable[Word2VecModel], JavaMLWritable
-):
+class Word2VecModel(JavaModel, _Word2VecParams, JavaMLReadable[Word2VecModel], JavaMLWritable):
def getVectors(self) -> DataFrame: ...
def setInputCol(self, value: str) -> Word2VecModel: ...
def setOutputCol(self, value: str) -> Word2VecModel: ...
@@ -1356,14 +1310,14 @@ class PCA(JavaEstimator[PCAModel], _PCAParams, JavaMLReadable[PCA], JavaMLWritab
*,
k: Optional[int] = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> None: ...
def setParams(
self,
*,
k: Optional[int] = ...,
inputCol: Optional[str] = ...,
- outputCol: Optional[str] = ...
+ outputCol: Optional[str] = ...,
) -> PCA: ...
def setK(self, value: int) -> PCA: ...
def setInputCol(self, value: str) -> PCA: ...
@@ -1401,7 +1355,7 @@ class RFormula(
labelCol: str = ...,
forceIndexLabel: bool = ...,
stringIndexerOrderType: str = ...,
- handleInvalid: str = ...
+ handleInvalid: str = ...,
) -> None: ...
def setParams(
self,
@@ -1411,7 +1365,7 @@ class RFormula(
labelCol: str = ...,
forceIndexLabel: bool = ...,
stringIndexerOrderType: str = ...,
- handleInvalid: str = ...
+ handleInvalid: str = ...,
) -> RFormula: ...
def setFormula(self, value: str) -> RFormula: ...
def setForceIndexLabel(self, value: bool) -> RFormula: ...
@@ -1420,9 +1374,7 @@ class RFormula(
def setLabelCol(self, value: str) -> RFormula: ...
def setHandleInvalid(self, value: str) -> RFormula: ...
-class RFormulaModel(
- JavaModel, _RFormulaParams, JavaMLReadable[RFormulaModel], JavaMLWritable
-): ...
+class RFormulaModel(JavaModel, _RFormulaParams, JavaMLReadable[RFormulaModel], JavaMLWritable): ...
class _SelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol):
selectorType: Param[str]
@@ -1472,7 +1424,7 @@ class ChiSqSelector(
percentile: float = ...,
fpr: float = ...,
fdr: float = ...,
- fwe: float = ...
+ fwe: float = ...,
) -> None: ...
def setParams(
self,
@@ -1485,7 +1437,7 @@ class ChiSqSelector(
percentile: float = ...,
fpr: float = ...,
fdr: float = ...,
- fwe: float = ...
+ fwe: float = ...,
) -> ChiSqSelector: ...
def setSelectorType(self, value: str) -> ChiSqSelector: ...
def setNumTopFeatures(self, value: int) -> ChiSqSelector: ...
@@ -1497,9 +1449,7 @@ class ChiSqSelector(
def setOutputCol(self, value: str) -> ChiSqSelector: ...
def setLabelCol(self, value: str) -> ChiSqSelector: ...
-class ChiSqSelectorModel(
- _SelectorModel, JavaMLReadable[ChiSqSelectorModel], JavaMLWritable
-):
+class ChiSqSelectorModel(_SelectorModel, JavaMLReadable[ChiSqSelectorModel], JavaMLWritable):
def setFeaturesCol(self, value: str) -> ChiSqSelectorModel: ...
def setOutputCol(self, value: str) -> ChiSqSelectorModel: ...
@property
@@ -1515,18 +1465,10 @@ class VectorSizeHint(
size: Param[int]
handleInvalid: Param[str]
def __init__(
- self,
- *,
- inputCol: Optional[str] = ...,
- size: Optional[int] = ...,
- handleInvalid: str = ...
+ self, *, inputCol: Optional[str] = ..., size: Optional[int] = ..., handleInvalid: str = ...
) -> None: ...
def setParams(
- self,
- *,
- inputCol: Optional[str] = ...,
- size: Optional[int] = ...,
- handleInvalid: str = ...
+ self, *, inputCol: Optional[str] = ..., size: Optional[int] = ..., handleInvalid: str = ...
) -> VectorSizeHint: ...
def setSize(self, value: int) -> VectorSizeHint: ...
def getSize(self) -> int: ...
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 37dac48..9cfd3af 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -33,34 +33,39 @@ class _FPGrowthParams(HasPredictionCol):
.. versionadded:: 3.0.0
"""
- itemsCol = Param(Params._dummy(), "itemsCol",
- "items column name", typeConverter=TypeConverters.toString)
+ itemsCol = Param(
+ Params._dummy(), "itemsCol", "items column name", typeConverter=TypeConverters.toString
+ )
minSupport = Param(
Params._dummy(),
"minSupport",
- "Minimal support level of the frequent pattern. [0.0, 1.0]. " +
- "Any pattern that appears more than (minSupport * size-of-the-dataset) " +
- "times will be output in the frequent itemsets.",
- typeConverter=TypeConverters.toFloat)
+ "Minimal support level of the frequent pattern. [0.0, 1.0]. "
+ + "Any pattern that appears more than (minSupport * size-of-the-dataset) "
+ + "times will be output in the frequent itemsets.",
+ typeConverter=TypeConverters.toFloat,
+ )
numPartitions = Param(
Params._dummy(),
"numPartitions",
- "Number of partitions (at least 1) used by parallel FP-growth. " +
- "By default the param is not set, " +
- "and partition number of the input dataset is used.",
- typeConverter=TypeConverters.toInt)
+ "Number of partitions (at least 1) used by parallel FP-growth. "
+ + "By default the param is not set, "
+ + "and partition number of the input dataset is used.",
+ typeConverter=TypeConverters.toInt,
+ )
minConfidence = Param(
Params._dummy(),
"minConfidence",
- "Minimal confidence for generating Association Rule. [0.0, 1.0]. " +
- "minConfidence will not affect the mining for frequent itemsets, " +
- "but will affect the association rules generation.",
- typeConverter=TypeConverters.toFloat)
+ "Minimal confidence for generating Association Rule. [0.0, 1.0]. "
+ + "minConfidence will not affect the mining for frequent itemsets, "
+ + "but will affect the association rules generation.",
+ typeConverter=TypeConverters.toFloat,
+ )
def __init__(self, *args):
super(_FPGrowthParams, self).__init__(*args)
- self._setDefault(minSupport=0.3, minConfidence=0.8,
- itemsCol="items", predictionCol="prediction")
+ self._setDefault(
+ minSupport=0.3, minConfidence=0.8, itemsCol="items", predictionCol="prediction"
+ )
def getItemsCol(self):
"""
@@ -224,9 +229,17 @@ class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable):
>>> fpm.transform(data).take(1) == model2.transform(data).take(1)
True
"""
+
@keyword_only
- def __init__(self, *, minSupport=0.3, minConfidence=0.8, itemsCol="items",
- predictionCol="prediction", numPartitions=None):
+ def __init__(
+ self,
+ *,
+ minSupport=0.3,
+ minConfidence=0.8,
+ itemsCol="items",
+ predictionCol="prediction",
+ numPartitions=None,
+ ):
"""
__init__(self, \\*, minSupport=0.3, minConfidence=0.8, itemsCol="items", \
predictionCol="prediction", numPartitions=None)
@@ -238,8 +251,15 @@ class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable):
@keyword_only
@since("2.2.0")
- def setParams(self, *, minSupport=0.3, minConfidence=0.8, itemsCol="items",
- predictionCol="prediction", numPartitions=None):
+ def setParams(
+ self,
+ *,
+ minSupport=0.3,
+ minConfidence=0.8,
+ itemsCol="items",
+ predictionCol="prediction",
+ numPartitions=None,
+ ):
"""
setParams(self, \\*, minSupport=0.3, minConfidence=0.8, itemsCol="items", \
predictionCol="prediction", numPartitions=None)
@@ -327,45 +347,72 @@ class PrefixSpan(JavaParams):
...
"""
- minSupport = Param(Params._dummy(), "minSupport", "The minimal support level of the " +
- "sequential pattern. Sequential pattern that appears more than " +
- "(minSupport * size-of-the-dataset) times will be output. Must be >= 0.",
- typeConverter=TypeConverters.toFloat)
-
- maxPatternLength = Param(Params._dummy(), "maxPatternLength",
- "The maximal length of the sequential pattern. Must be > 0.",
- typeConverter=TypeConverters.toInt)
+ minSupport = Param(
+ Params._dummy(),
+ "minSupport",
+ "The minimal support level of the "
+ + "sequential pattern. Sequential pattern that appears more than "
+ + "(minSupport * size-of-the-dataset) times will be output. Must be >= 0.",
+ typeConverter=TypeConverters.toFloat,
+ )
- maxLocalProjDBSize = Param(Params._dummy(), "maxLocalProjDBSize",
- "The maximum number of items (including delimiters used in the " +
- "internal storage format) allowed in a projected database before " +
- "local processing. If a projected database exceeds this size, " +
- "another iteration of distributed prefix growth is run. " +
- "Must be > 0.",
- typeConverter=TypeConverters.toInt)
+ maxPatternLength = Param(
+ Params._dummy(),
+ "maxPatternLength",
+ "The maximal length of the sequential pattern. Must be > 0.",
+ typeConverter=TypeConverters.toInt,
+ )
- sequenceCol = Param(Params._dummy(), "sequenceCol", "The name of the sequence column in " +
- "dataset, rows with nulls in this column are ignored.",
- typeConverter=TypeConverters.toString)
+ maxLocalProjDBSize = Param(
+ Params._dummy(),
+ "maxLocalProjDBSize",
+ "The maximum number of items (including delimiters used in the "
+ + "internal storage format) allowed in a projected database before "
+ + "local processing. If a projected database exceeds this size, "
+ + "another iteration of distributed prefix growth is run. "
+ + "Must be > 0.",
+ typeConverter=TypeConverters.toInt,
+ )
+
+ sequenceCol = Param(
+ Params._dummy(),
+ "sequenceCol",
+ "The name of the sequence column in "
+ + "dataset, rows with nulls in this column are ignored.",
+ typeConverter=TypeConverters.toString,
+ )
@keyword_only
- def __init__(self, *, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
- sequenceCol="sequence"):
+ def __init__(
+ self,
+ *,
+ minSupport=0.1,
+ maxPatternLength=10,
+ maxLocalProjDBSize=32000000,
+ sequenceCol="sequence",
+ ):
"""
__init__(self, \\*, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \
sequenceCol="sequence")
"""
super(PrefixSpan, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid)
- self._setDefault(minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
- sequenceCol="sequence")
+ self._setDefault(
+ minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, sequenceCol="sequence"
+ )
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("2.4.0")
- def setParams(self, *, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
- sequenceCol="sequence"):
+ def setParams(
+ self,
+ *,
+ minSupport=0.1,
+ maxPatternLength=10,
+ maxLocalProjDBSize=32000000,
+ sequenceCol="sequence",
+ ):
"""
setParams(self, \\*, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \
sequenceCol="sequence")
@@ -460,24 +507,24 @@ if __name__ == "__main__":
import doctest
import pyspark.ml.fpm
from pyspark.sql import SparkSession
+
globs = pyspark.ml.fpm.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
- spark = SparkSession.builder\
- .master("local[2]")\
- .appName("ml.fpm tests")\
- .getOrCreate()
+ spark = SparkSession.builder.master("local[2]").appName("ml.fpm tests").getOrCreate()
sc = spark.sparkContext
- globs['sc'] = sc
- globs['spark'] = spark
+ globs["sc"] = sc
+ globs["spark"] = spark
import tempfile
+
temp_path = tempfile.mkdtemp()
- globs['temp_path'] = temp_path
+ globs["temp_path"] = temp_path
try:
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
finally:
from shutil import rmtree
+
try:
rmtree(temp_path)
except OSError:
diff --git a/python/pyspark/ml/fpm.pyi b/python/pyspark/ml/fpm.pyi
index 7cc304a..609bc44 100644
--- a/python/pyspark/ml/fpm.pyi
+++ b/python/pyspark/ml/fpm.pyi
@@ -36,9 +36,7 @@ class _FPGrowthParams(HasPredictionCol):
def getNumPartitions(self) -> int: ...
def getMinConfidence(self) -> float: ...
-class FPGrowthModel(
- JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable[FPGrowthModel]
-):
+class FPGrowthModel(JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable[FPGrowthModel]):
def setItemsCol(self, value: str) -> FPGrowthModel: ...
def setMinConfidence(self, value: float) -> FPGrowthModel: ...
def setPredictionCol(self, value: str) -> FPGrowthModel: ...
@@ -60,7 +58,7 @@ class FPGrowth(
minConfidence: float = ...,
itemsCol: str = ...,
predictionCol: str = ...,
- numPartitions: Optional[int] = ...
+ numPartitions: Optional[int] = ...,
) -> None: ...
def setParams(
self,
@@ -69,7 +67,7 @@ class FPGrowth(
minConfidence: float = ...,
itemsCol: str = ...,
predictionCol: str = ...,
- numPartitions: Optional[int] = ...
+ numPartitions: Optional[int] = ...,
) -> FPGrowth: ...
def setItemsCol(self, value: str) -> FPGrowth: ...
def setMinSupport(self, value: float) -> FPGrowth: ...
@@ -88,7 +86,7 @@ class PrefixSpan(JavaParams):
minSupport: float = ...,
maxPatternLength: int = ...,
maxLocalProjDBSize: int = ...,
- sequenceCol: str = ...
+ sequenceCol: str = ...,
) -> None: ...
def setParams(
self,
@@ -96,7 +94,7 @@ class PrefixSpan(JavaParams):
minSupport: float = ...,
maxPatternLength: int = ...,
maxLocalProjDBSize: int = ...,
- sequenceCol: str = ...
+ sequenceCol: str = ...,
) -> PrefixSpan: ...
def setMinSupport(self, value: float) -> PrefixSpan: ...
def getMinSupport(self) -> float: ...
diff --git a/python/pyspark/ml/functions.py b/python/pyspark/ml/functions.py
index 1eadbd6..64b5948 100644
--- a/python/pyspark/ml/functions.py
+++ b/python/pyspark/ml/functions.py
@@ -66,7 +66,8 @@ def vector_to_array(col, dtype="float64"):
"""
sc = SparkContext._active_spark_context
return Column(
- sc._jvm.org.apache.spark.ml.functions.vector_to_array(_to_java_column(col), dtype))
+ sc._jvm.org.apache.spark.ml.functions.vector_to_array(_to_java_column(col), dtype)
+ )
def array_to_vector(col):
@@ -100,8 +101,7 @@ def array_to_vector(col):
[Row(vec1=DenseVector([1.0, 3.0]))]
"""
sc = SparkContext._active_spark_context
- return Column(
- sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
+ return Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
def _test():
@@ -109,18 +109,18 @@ def _test():
from pyspark.sql import SparkSession
import pyspark.ml.functions
import sys
+
globs = pyspark.ml.functions.__dict__.copy()
- spark = SparkSession.builder \
- .master("local[2]") \
- .appName("ml.functions tests") \
- .getOrCreate()
+ spark = SparkSession.builder.master("local[2]").appName("ml.functions tests").getOrCreate()
sc = spark.sparkContext
- globs['sc'] = sc
- globs['spark'] = spark
+ globs["sc"] = sc
+ globs["spark"] = spark
(failure_count, test_count) = doctest.testmod(
- pyspark.ml.functions, globs=globs,
- optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+ pyspark.ml.functions,
+ globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
+ )
spark.stop()
if failure_count:
sys.exit(-1)
diff --git a/python/pyspark/ml/functions.pyi b/python/pyspark/ml/functions.pyi
index 12b44fc..cb08398 100644
--- a/python/pyspark/ml/functions.pyi
+++ b/python/pyspark/ml/functions.pyi
@@ -20,5 +20,4 @@ from pyspark import SparkContext as SparkContext, since as since # noqa: F401
from pyspark.sql.column import Column as Column
def vector_to_array(col: Column) -> Column: ...
-
def array_to_vector(col: Column) -> Column: ...
diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py
index 728e9a3..bdd1c9d 100644
--- a/python/pyspark/ml/image.py
+++ b/python/pyspark/ml/image.py
@@ -136,8 +136,9 @@ class _ImageSchema(object):
if self._undefinedImageType is None:
ctx = SparkContext._active_spark_context
- self._undefinedImageType = \
+ self._undefinedImageType = (
ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType()
+ )
return self._undefinedImageType
def toNDArray(self, image):
@@ -161,12 +162,14 @@ class _ImageSchema(object):
if not isinstance(image, Row):
raise TypeError(
"image argument should be pyspark.sql.types.Row; however, "
- "it got [%s]." % type(image))
+ "it got [%s]." % type(image)
+ )
if any(not hasattr(image, f) for f in self.imageFields):
raise ValueError(
"image argument should have attributes specified in "
- "ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields))
+ "ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields)
+ )
height = image.height
width = image.width
@@ -175,7 +178,8 @@ class _ImageSchema(object):
shape=(height, width, nChannels),
dtype=np.uint8,
buffer=image.data,
- strides=(width * nChannels, nChannels, 1))
+ strides=(width * nChannels, nChannels, 1),
+ )
def toImage(self, array, origin=""):
"""
@@ -198,7 +202,8 @@ class _ImageSchema(object):
if not isinstance(array, np.ndarray):
raise TypeError(
- "array argument should be numpy.ndarray; however, it got [%s]." % type(array))
+ "array argument should be numpy.ndarray; however, it got [%s]." % type(array)
+ )
if array.ndim != 3:
raise ValueError("Invalid array shape")
@@ -217,7 +222,7 @@ class _ImageSchema(object):
# Running `bytearray(numpy.array([1]))` fails in specific Python versions
# with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3.
# Here, it avoids it by converting it to bytes.
- if LooseVersion(np.__version__) >= LooseVersion('1.9'):
+ if LooseVersion(np.__version__) >= LooseVersion("1.9"):
data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
else:
# Numpy prior to 1.9 don't have `tobytes` method.
@@ -226,8 +231,7 @@ class _ImageSchema(object):
# Creating new Row with _create_row(), because Row(name = value, ... )
# orders fields by name, which conflicts with expected schema order
# when the new DataFrame is created by UDF
- return _create_row(self.imageFields,
- [origin, height, width, nChannels, mode, data])
+ return _create_row(self.imageFields, [origin, height, width, nChannels, mode, data])
ImageSchema = _ImageSchema()
@@ -236,22 +240,22 @@ ImageSchema = _ImageSchema()
# Monkey patch to disallow instantiation of this class.
def _disallow_instance(_):
raise RuntimeError("Creating instance of _ImageSchema class is disallowed.")
+
+
_ImageSchema.__init__ = _disallow_instance
def _test():
import doctest
import pyspark.ml.image
+
globs = pyspark.ml.image.__dict__.copy()
- spark = SparkSession.builder\
- .master("local[2]")\
- .appName("ml.image tests")\
- .getOrCreate()
- globs['spark'] = spark
+ spark = SparkSession.builder.master("local[2]").appName("ml.image tests").getOrCreate()
+ globs["spark"] = spark
(failure_count, test_count) = doctest.testmod(
- pyspark.ml.image, globs=globs,
- optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+ pyspark.ml.image, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
+ )
spark.stop()
if failure_count:
sys.exit(-1)
diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py
index 2f50bea..46a8b97 100644
--- a/python/pyspark/ml/linalg/__init__.py
+++ b/python/pyspark/ml/linalg/__init__.py
@@ -29,12 +29,28 @@ import struct
import numpy as np
-from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
- IntegerType, ByteType, BooleanType
-
-
-__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors',
- 'Matrix', 'DenseMatrix', 'SparseMatrix', 'Matrices']
+from pyspark.sql.types import (
+ UserDefinedType,
+ StructField,
+ StructType,
+ ArrayType,
+ DoubleType,
+ IntegerType,
+ ByteType,
+ BooleanType,
+)
+
+
+__all__ = [
+ "Vector",
+ "DenseVector",
+ "SparseVector",
+ "Vectors",
+ "Matrix",
+ "DenseMatrix",
+ "SparseMatrix",
+ "Matrices",
+]
# Check whether we have SciPy. MLlib works without it too, but if we have it, some methods,
@@ -42,6 +58,7 @@ __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors',
try:
import scipy.sparse
+
_have_scipy = True
except:
# No SciPy in environment, but that's okay
@@ -103,8 +120,8 @@ def _vector_size(v):
def _format_float(f, digits=4):
s = str(round(f, digits))
- if '.' in s:
- s = s[:s.index('.') + 1 + digits]
+ if "." in s:
+ s = s[: s.index(".") + 1 + digits]
return s
@@ -114,9 +131,9 @@ def _format_float_list(l):
def _double_to_long_bits(value):
if np.isnan(value):
- value = float('nan')
+ value = float("nan")
# pack double into 64 bits, then unpack as long int
- return struct.unpack('Q', struct.pack('d', value))[0]
+ return struct.unpack("Q", struct.pack("d", value))[0]
class VectorUDT(UserDefinedType):
@@ -126,11 +143,14 @@ class VectorUDT(UserDefinedType):
@classmethod
def sqlType(cls):
- return StructType([
- StructField("type", ByteType(), False),
- StructField("size", IntegerType(), True),
- StructField("indices", ArrayType(IntegerType(), False), True),
- StructField("values", ArrayType(DoubleType(), False), True)])
+ return StructType(
+ [
+ StructField("type", ByteType(), False),
+ StructField("size", IntegerType(), True),
+ StructField("indices", ArrayType(IntegerType(), False), True),
+ StructField("values", ArrayType(DoubleType(), False), True),
+ ]
+ )
@classmethod
def module(cls):
@@ -152,8 +172,9 @@ class VectorUDT(UserDefinedType):
raise TypeError("cannot serialize %r of type %r" % (obj, type(obj)))
def deserialize(self, datum):
- assert len(datum) == 4, \
- "VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
+ assert (
+ len(datum) == 4
+ ), "VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
tpe = datum[0]
if tpe == 0:
return SparseVector(datum[1], datum[2], datum[3])
@@ -173,14 +194,17 @@ class MatrixUDT(UserDefinedType):
@classmethod
def sqlType(cls):
- return StructType([
- StructField("type", ByteType(), False),
- StructField("numRows", IntegerType(), False),
- StructField("numCols", IntegerType(), False),
- StructField("colPtrs", ArrayType(IntegerType(), False), True),
- StructField("rowIndices", ArrayType(IntegerType(), False), True),
- StructField("values", ArrayType(DoubleType(), False), True),
- StructField("isTransposed", BooleanType(), False)])
+ return StructType(
+ [
+ StructField("type", ByteType(), False),
+ StructField("numRows", IntegerType(), False),
+ StructField("numCols", IntegerType(), False),
+ StructField("colPtrs", ArrayType(IntegerType(), False), True),
+ StructField("rowIndices", ArrayType(IntegerType(), False), True),
+ StructField("values", ArrayType(DoubleType(), False), True),
+ StructField("isTransposed", BooleanType(), False),
+ ]
+ )
@classmethod
def module(cls):
@@ -195,18 +219,25 @@ class MatrixUDT(UserDefinedType):
colPtrs = [int(i) for i in obj.colPtrs]
rowIndices = [int(i) for i in obj.rowIndices]
values = [float(v) for v in obj.values]
- return (0, obj.numRows, obj.numCols, colPtrs,
- rowIndices, values, bool(obj.isTransposed))
+ return (
+ 0,
+ obj.numRows,
+ obj.numCols,
+ colPtrs,
+ rowIndices,
+ values,
+ bool(obj.isTransposed),
+ )
elif isinstance(obj, DenseMatrix):
values = [float(v) for v in obj.values]
- return (1, obj.numRows, obj.numCols, None, None, values,
- bool(obj.isTransposed))
+ return (1, obj.numRows, obj.numCols, None, None, values, bool(obj.isTransposed))
else:
raise TypeError("cannot serialize type %r" % (type(obj)))
def deserialize(self, datum):
- assert len(datum) == 7, \
- "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum)
+ assert (
+ len(datum) == 7
+ ), "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum)
tpe = datum[0]
if tpe == 0:
return SparseMatrix(*datum[1:])
@@ -226,6 +257,7 @@ class Vector(object):
"""
Abstract class for DenseVector and SparseVector
"""
+
def toArray(self):
"""
Convert the vector into an numpy.ndarray
@@ -260,6 +292,7 @@ class DenseVector(Vector):
>>> -v
DenseVector([-1.0, -2.0])
"""
+
def __init__(self, ar):
if isinstance(ar, bytes):
ar = np.frombuffer(ar, dtype=np.float64)
@@ -400,7 +433,7 @@ class DenseVector(Vector):
return "[" + ",".join([str(v) for v in self.array]) + "]"
def __repr__(self):
- return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array))
+ return "DenseVector([%s])" % (", ".join(_format_float(i) for i in self.array))
def __eq__(self, other):
if isinstance(other, DenseVector):
@@ -439,6 +472,7 @@ class DenseVector(Vector):
if isinstance(other, DenseVector):
other = other.array
return DenseVector(getattr(self.array, op)(other))
+
return func
__add__ = _delegate("__add__")
@@ -460,6 +494,7 @@ class SparseVector(Vector):
A simple sparse vector class for passing data to MLlib. Users may
alternatively pass SciPy's {scipy.sparse} data types.
"""
+
def __init__(self, size, *args):
"""
Create a sparse vector, using either a dictionary, a list of
@@ -523,14 +558,17 @@ class SparseVector(Vector):
if self.indices[i] >= self.indices[i + 1]:
raise TypeError(
"Indices %s and %s are not strictly increasing"
- % (self.indices[i], self.indices[i + 1]))
+ % (self.indices[i], self.indices[i + 1])
+ )
if self.indices.size > 0:
- assert np.max(self.indices) < self.size, \
- "Index %d is out of the size of vector with size=%d" \
- % (np.max(self.indices), self.size)
- assert np.min(self.indices) >= 0, \
- "Contains negative index %d" % (np.min(self.indices))
+ assert (
+ np.max(self.indices) < self.size
+ ), "Index %d is out of the size of vector with size=%d" % (
+ np.max(self.indices),
+ self.size,
+ )
+ assert np.min(self.indices) >= 0, "Contains negative index %d" % (np.min(self.indices))
def numNonzeros(self):
"""
@@ -553,9 +591,7 @@ class SparseVector(Vector):
return np.linalg.norm(self.values, p)
def __reduce__(self):
- return (
- SparseVector,
- (self.size, self.indices.tostring(), self.values.tostring()))
+ return (SparseVector, (self.size, self.indices.tostring(), self.values.tostring()))
def dot(self, other):
"""
@@ -646,8 +682,9 @@ class SparseVector(Vector):
if isinstance(other, np.ndarray) or isinstance(other, DenseVector):
if isinstance(other, np.ndarray) and other.ndim != 1:
- raise ValueError("Cannot call squared_distance with %d-dimensional array" %
- other.ndim)
+ raise ValueError(
+ "Cannot call squared_distance with %d-dimensional array" % other.ndim
+ )
if isinstance(other, DenseVector):
other = other.array
sparse_ind = np.zeros(other.size, dtype=bool)
@@ -703,14 +740,18 @@ class SparseVector(Vector):
def __repr__(self):
inds = self.indices
vals = self.values
- entries = ", ".join(["{0}: {1}".format(inds[i], _format_float(vals[i]))
- for i in range(len(inds))])
+ entries = ", ".join(
+ ["{0}: {1}".format(inds[i], _format_float(vals[i])) for i in range(len(inds))]
+ )
return "SparseVector({0}, {{{1}}})".format(self.size, entries)
def __eq__(self, other):
if isinstance(other, SparseVector):
- return other.size == self.size and np.array_equal(other.indices, self.indices) \
+ return (
+ other.size == self.size
+ and np.array_equal(other.indices, self.indices)
and np.array_equal(other.values, self.values)
+ )
elif isinstance(other, DenseVector):
if self.size != len(other):
return False
@@ -721,8 +762,7 @@ class SparseVector(Vector):
inds = self.indices
vals = self.values
if not isinstance(index, int):
- raise TypeError(
- "Indices must be of type integer, got type %s" % type(index))
+ raise TypeError("Indices must be of type integer, got type %s" % type(index))
if index >= self.size or index < -self.size:
raise IndexError("Index %d out of bounds." % index)
@@ -730,13 +770,13 @@ class SparseVector(Vector):
index += self.size
if (inds.size == 0) or (index > inds.item(-1)):
- return 0.
+ return 0.0
insert_index = np.searchsorted(inds, index)
row_ind = inds[insert_index]
if row_ind == index:
return vals[insert_index]
- return 0.
+ return 0.0
def __ne__(self, other):
return not self.__eq__(other)
@@ -872,6 +912,7 @@ class Matrix(object):
"""
Represents a local matrix.
"""
+
def __init__(self, numRows, numCols, isTransposed=False):
self.numRows = numRows
self.numCols = numCols
@@ -897,6 +938,7 @@ class DenseMatrix(Matrix):
"""
Column-major dense matrix.
"""
+
def __init__(self, numRows, numCols, values, isTransposed=False):
Matrix.__init__(self, numRows, numCols, isTransposed)
values = self._convert_to_array(values, np.float64)
@@ -905,8 +947,11 @@ class DenseMatrix(Matrix):
def __reduce__(self):
return DenseMatrix, (
- self.numRows, self.numCols, self.values.tostring(),
- int(self.isTransposed))
+ self.numRows,
+ self.numCols,
+ self.values.tostring(),
+ int(self.isTransposed),
+ )
def __str__(self):
"""
@@ -928,7 +973,7 @@ class DenseMatrix(Matrix):
# We need to adjust six spaces which is the difference in number
# of letters between "DenseMatrix" and "array"
- x = '\n'.join([(" " * 6 + line) for line in array_lines[1:]])
+ x = "\n".join([(" " * 6 + line) for line in array_lines[1:]])
return array_lines[0].replace("array", "DenseMatrix") + "\n" + x
def __repr__(self):
@@ -947,14 +992,13 @@ class DenseMatrix(Matrix):
entries = _format_float_list(self.values)
else:
entries = (
- _format_float_list(self.values[:8]) +
- ["..."] +
- _format_float_list(self.values[-8:])
+ _format_float_list(self.values[:8]) + ["..."] + _format_float_list(self.values[-8:])
)
entries = ", ".join(entries)
return "DenseMatrix({0}, {1}, [{2}], {3})".format(
- self.numRows, self.numCols, entries, self.isTransposed)
+ self.numRows, self.numCols, entries, self.isTransposed
+ )
def toArray(self):
"""
@@ -968,21 +1012,19 @@ class DenseMatrix(Matrix):
[ 1., 3.]])
"""
if self.isTransposed:
- return np.asfortranarray(
- self.values.reshape((self.numRows, self.numCols)))
+ return np.asfortranarray(self.values.reshape((self.numRows, self.numCols)))
else:
- return self.values.reshape((self.numRows, self.numCols), order='F')
+ return self.values.reshape((self.numRows, self.numCols), order="F")
def toSparse(self):
"""Convert to SparseMatrix"""
if self.isTransposed:
- values = np.ravel(self.toArray(), order='F')
+ values = np.ravel(self.toArray(), order="F")
else:
values = self.values
indices = np.nonzero(values)[0]
colCounts = np.bincount(indices // self.numRows)
- colPtrs = np.cumsum(np.hstack(
- (0, colCounts, np.zeros(self.numCols - colCounts.size))))
+ colPtrs = np.cumsum(np.hstack((0, colCounts, np.zeros(self.numCols - colCounts.size))))
values = values[indices]
rowIndices = indices % self.numRows
@@ -991,11 +1033,9 @@ class DenseMatrix(Matrix):
def __getitem__(self, indices):
i, j = indices
if i < 0 or i >= self.numRows:
- raise IndexError("Row index %d is out of range [0, %d)"
- % (i, self.numRows))
+ raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows))
if j >= self.numCols or j < 0:
- raise IndexError("Column index %d is out of range [0, %d)"
- % (j, self.numCols))
+ raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols))
if self.isTransposed:
return self.values[i * self.numCols + j]
@@ -1003,20 +1043,20 @@ class DenseMatrix(Matrix):
return self.values[i + j * self.numRows]
def __eq__(self, other):
- if (self.numRows != other.numRows or self.numCols != other.numCols):
+ if self.numRows != other.numRows or self.numCols != other.numCols:
return False
if isinstance(other, SparseMatrix):
return np.all(self.toArray() == other.toArray())
- self_values = np.ravel(self.toArray(), order='F')
- other_values = np.ravel(other.toArray(), order='F')
+ self_values = np.ravel(self.toArray(), order="F")
+ other_values = np.ravel(other.toArray(), order="F")
return np.all(self_values == other_values)
class SparseMatrix(Matrix):
"""Sparse Matrix stored in CSC format."""
- def __init__(self, numRows, numCols, colPtrs, rowIndices, values,
- isTransposed=False):
+
+ def __init__(self, numRows, numCols, colPtrs, rowIndices, values, isTransposed=False):
Matrix.__init__(self, numRows, numCols, isTransposed)
self.colPtrs = self._convert_to_array(colPtrs, np.int32)
self.rowIndices = self._convert_to_array(rowIndices, np.int32)
@@ -1024,15 +1064,19 @@ class SparseMatrix(Matrix):
if self.isTransposed:
if self.colPtrs.size != numRows + 1:
- raise ValueError("Expected colPtrs of size %d, got %d."
- % (numRows + 1, self.colPtrs.size))
+ raise ValueError(
+ "Expected colPtrs of size %d, got %d." % (numRows + 1, self.colPtrs.size)
+ )
else:
if self.colPtrs.size != numCols + 1:
- raise ValueError("Expected colPtrs of size %d, got %d."
- % (numCols + 1, self.colPtrs.size))
+ raise ValueError(
+ "Expected colPtrs of size %d, got %d." % (numCols + 1, self.colPtrs.size)
+ )
if self.rowIndices.size != self.values.size:
- raise ValueError("Expected rowIndices of length %d, got %d."
- % (self.rowIndices.size, self.values.size))
+ raise ValueError(
+ "Expected rowIndices of length %d, got %d."
+ % (self.rowIndices.size, self.values.size)
+ )
def __str__(self):
"""
@@ -1071,11 +1115,9 @@ class SparseMatrix(Matrix):
if self.colPtrs[cur_col + 1] <= i:
cur_col += 1
if self.isTransposed:
- smlist.append('({0},{1}) {2}'.format(
- cur_col, rowInd, _format_float(value)))
+ smlist.append("({0},{1}) {2}".format(cur_col, rowInd, _format_float(value)))
else:
- smlist.append('({0},{1}) {2}'.format(
- rowInd, cur_col, _format_float(value)))
+ smlist.append("({0},{1}) {2}".format(rowInd, cur_col, _format_float(value)))
spstr += "\n".join(smlist)
if len(self.values) > 16:
@@ -1100,9 +1142,7 @@ class SparseMatrix(Matrix):
else:
values = (
- _format_float_list(self.values[:8]) +
- ["..."] +
- _format_float_list(self.values[-8:])
+ _format_float_list(self.values[:8]) + ["..."] + _format_float_list(self.values[-8:])
)
rowIndices = rowIndices[:8] + ["..."] + rowIndices[-8:]
@@ -1113,23 +1153,25 @@ class SparseMatrix(Matrix):
rowIndices = ", ".join([str(ind) for ind in rowIndices])
colPtrs = ", ".join([str(ptr) for ptr in colPtrs])
return "SparseMatrix({0}, {1}, [{2}], [{3}], [{4}], {5})".format(
- self.numRows, self.numCols, colPtrs, rowIndices,
- values, self.isTransposed)
+ self.numRows, self.numCols, colPtrs, rowIndices, values, self.isTransposed
+ )
def __reduce__(self):
return SparseMatrix, (
- self.numRows, self.numCols, self.colPtrs.tostring(),
- self.rowIndices.tostring(), self.values.tostring(),
- int(self.isTransposed))
+ self.numRows,
+ self.numCols,
+ self.colPtrs.tostring(),
+ self.rowIndices.tostring(),
+ self.values.tostring(),
+ int(self.isTransposed),
+ )
def __getitem__(self, indices):
i, j = indices
if i < 0 or i >= self.numRows:
- raise IndexError("Row index %d is out of range [0, %d)"
- % (i, self.numRows))
+ raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows))
if j < 0 or j >= self.numCols:
- raise IndexError("Column index %d is out of range [0, %d)"
- % (j, self.numCols))
+ raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols))
# If a CSR matrix is given, then the row index should be searched
# for in ColPtrs, and the column index should be searched for in the
@@ -1139,7 +1181,7 @@ class SparseMatrix(Matrix):
colStart = self.colPtrs[j]
colEnd = self.colPtrs[j + 1]
- nz = self.rowIndices[colStart: colEnd]
+ nz = self.rowIndices[colStart:colEnd]
ind = np.searchsorted(nz, i) + colStart
if ind < colEnd and self.rowIndices[ind] == i:
return self.values[ind]
@@ -1150,7 +1192,7 @@ class SparseMatrix(Matrix):
"""
Return a numpy.ndarray
"""
- A = np.zeros((self.numRows, self.numCols), dtype=np.float64, order='F')
+ A = np.zeros((self.numRows, self.numCols), dtype=np.float64, order="F")
for k in range(self.colPtrs.size - 1):
startptr = self.colPtrs[k]
endptr = self.colPtrs[k + 1]
@@ -1161,7 +1203,7 @@ class SparseMatrix(Matrix):
return A
def toDense(self):
- densevals = np.ravel(self.toArray(), order='F')
+ densevals = np.ravel(self.toArray(), order="F")
return DenseMatrix(self.numRows, self.numCols, densevals)
# TODO: More efficient implementation:
@@ -1187,14 +1229,16 @@ class Matrices(object):
def _test():
import doctest
+
try:
# Numpy 1.14+ changed it's string format.
- np.set_printoptions(legacy='1.13')
+ np.set_printoptions(legacy="1.13")
except TypeError:
pass
(failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
if failure_count:
sys.exit(-1)
+
if __name__ == "__main__":
_test()
diff --git a/python/pyspark/ml/linalg/__init__.pyi b/python/pyspark/ml/linalg/__init__.pyi
index 46bd812..bb09397 100644
--- a/python/pyspark/ml/linalg/__init__.pyi
+++ b/python/pyspark/ml/linalg/__init__.pyi
@@ -46,9 +46,7 @@ class MatrixUDT(UserDefinedType):
def scalaUDT(cls) -> str: ...
def serialize(
self, obj: Matrix
- ) -> Tuple[
- int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool
- ]: ...
+ ) -> Tuple[int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool]: ...
def deserialize(self, datum: Any) -> Matrix: ...
def simpleString(self) -> str: ...
@@ -101,9 +99,7 @@ class SparseVector(Vector):
@overload
def __init__(self, size: int, __indices: bytes, __values: bytes) -> None: ...
@overload
- def __init__(
- self, size: int, __indices: Iterable[int], __values: Iterable[float]
- ) -> None: ...
+ def __init__(self, size: int, __indices: Iterable[int], __values: Iterable[float]) -> None: ...
@overload
def __init__(self, size: int, __pairs: Iterable[Tuple[int, float]]) -> None: ...
@overload
@@ -129,9 +125,7 @@ class Vectors:
def sparse(size: int, __indices: bytes, __values: bytes) -> SparseVector: ...
@overload
@staticmethod
- def sparse(
- size: int, __indices: Iterable[int], __values: Iterable[float]
- ) -> SparseVector: ...
+ def sparse(size: int, __indices: Iterable[int], __values: Iterable[float]) -> SparseVector: ...
@overload
@staticmethod
def sparse(size: int, __pairs: Iterable[Tuple[int, float]]) -> SparseVector: ...
@@ -161,9 +155,7 @@ class Matrix:
numRows: int
numCols: int
isTransposed: bool
- def __init__(
- self, numRows: int, numCols: int, isTransposed: bool = ...
- ) -> None: ...
+ def __init__(self, numRows: int, numCols: int, isTransposed: bool = ...) -> None: ...
def toArray(self) -> ndarray: ...
class DenseMatrix(Matrix):
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index ab3491c..e011c50 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -25,7 +25,7 @@ from pyspark.ml.linalg import DenseVector, Vector, Matrix
from pyspark.ml.util import Identifiable
-__all__ = ['Param', 'Params', 'TypeConverters']
+__all__ = ["Param", "Params", "TypeConverters"]
class Param(object):
@@ -78,7 +78,7 @@ class TypeConverters(object):
@staticmethod
def _is_numeric(value):
vtype = type(value)
- return vtype in [int, float, np.float64, np.int64] or vtype.__name__ == 'long'
+ return vtype in [int, float, np.float64, np.int64] or vtype.__name__ == "long"
@staticmethod
def _is_integer(value):
@@ -263,9 +263,16 @@ class Params(Identifiable, metaclass=ABCMeta):
:py:class:`Param`.
"""
if self._params is None:
- self._params = list(filter(lambda attr: isinstance(attr, Param),
- [getattr(self, x) for x in dir(self) if x != "params" and
- not isinstance(getattr(type(self), x, None), property)]))
+ self._params = list(
+ filter(
+ lambda attr: isinstance(attr, Param),
+ [
+ getattr(self, x)
+ for x in dir(self)
+ if x != "params" and not isinstance(getattr(type(self), x, None), property)
+ ],
+ )
+ )
return self._params
def explainParam(self, param):
@@ -484,8 +491,9 @@ class Params(Identifiable, metaclass=ABCMeta):
try:
value = p.typeConverter(value)
except TypeError as e:
- raise TypeError('Invalid default param value given for param "%s". %s'
- % (p.name, e))
+ raise TypeError(
+ 'Invalid default param value given for param "%s". %s' % (p.name, e)
+ )
self._defaultParamMap[p] = value
return self
@@ -512,11 +520,13 @@ class Params(Identifiable, metaclass=ABCMeta):
if isinstance(param, Param):
paramMap[param] = value
else:
- raise TypeError("Expecting a valid instance of Param, but received: {}"
- .format(param))
+ raise TypeError(
+ "Expecting a valid instance of Param, but received: {}".format(param)
+ )
elif extra is not None:
- raise TypeError("Expecting a dict, but received an object of type {}."
- .format(type(extra)))
+ raise TypeError(
+ "Expecting a dict, but received an object of type {}.".format(type(extra))
+ )
for param in self.params:
# copy default params
if param in self._defaultParamMap and to.hasParam(param.name):
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index bcab51f..f8b3d1f 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -54,18 +54,19 @@ def _gen_param_header(name, doc, defaultValueStr, typeConverter):
super(Has$Name, self).__init__()'''
if defaultValueStr is not None:
- template += '''
- self._setDefault($name=$defaultValueStr)'''
+ template += """
+ self._setDefault($name=$defaultValueStr)"""
Name = name[0].upper() + name[1:]
if typeConverter is None:
typeConverter = str(None)
- return template \
- .replace("$name", name) \
- .replace("$Name", Name) \
- .replace("$doc", doc) \
- .replace("$defaultValueStr", str(defaultValueStr)) \
+ return (
+ template.replace("$name", name)
+ .replace("$Name", Name)
+ .replace("$doc", doc)
+ .replace("$defaultValueStr", str(defaultValueStr))
.replace("$typeConverter", typeConverter)
+ )
def _gen_param_code(name, doc, defaultValueStr):
@@ -86,11 +87,13 @@ def _gen_param_code(name, doc, defaultValueStr):
return self.getOrDefault(self.$name)'''
Name = name[0].upper() + name[1:]
- return template \
- .replace("$name", name) \
- .replace("$Name", Name) \
- .replace("$doc", doc) \
+ return (
+ template.replace("$name", name)
+ .replace("$Name", Name)
+ .replace("$doc", doc)
.replace("$defaultValueStr", str(defaultValueStr))
+ )
+
if __name__ == "__main__":
print(header)
@@ -102,74 +105,169 @@ if __name__ == "__main__":
("featuresCol", "features column name.", "'features'", "TypeConverters.toString"),
("labelCol", "label column name.", "'label'", "TypeConverters.toString"),
("predictionCol", "prediction column name.", "'prediction'", "TypeConverters.toString"),
- ("probabilityCol", "Column name for predicted class conditional probabilities. " +
- "Note: Not all models output well-calibrated probability estimates! These probabilities " +
- "should be treated as confidences, not precise probabilities.", "'probability'",
- "TypeConverters.toString"),
... 39109 lines suppressed ...