Posted to commits@spark.apache.org by ze...@apache.org on 2022/01/21 11:01:14 UTC
[spark] branch master updated: [SPARK-37972][PYTHON][MLLIB] Address typing incompatibilities with numpy==1.22.x
This is an automated email from the ASF dual-hosted git repository.
zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1aa6652 [SPARK-37972][PYTHON][MLLIB] Address typing incompatibilities with numpy==1.22.x
1aa6652 is described below
commit 1aa665239876b32ccf81c9d170e17368c6b44c61
Author: zero323 <ms...@gmail.com>
AuthorDate: Fri Jan 21 12:00:16 2022 +0100
[SPARK-37972][PYTHON][MLLIB] Address typing incompatibilities with numpy==1.22.x
### What changes were proposed in this pull request?
This PR:
- Updates the `norm` annotations on `DenseVector` and `SparseVector` to match the numpy counterpart.
- Adds a `cast` for the second argument passed to numpy `dot`.
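As a rough illustration of the annotation change (a sketch, not the pyspark source): a `Literal`-based alias narrows the accepted strings to the two names listed in numpy's stubs, so a type checker rejects an arbitrary `str`. The alias below mirrors `NormType` from this commit, but it imports `Literal` from `typing` (an assumption: Python 3.8+) rather than `typing_extensions`. `allowed_norm_names` is a hypothetical helper used only to inspect the alias.

```python
from typing import Literal, Tuple, Union, get_args, get_origin

# Mirrors the alias added to python/pyspark/mllib/_typing.pyi in this commit.
NormType = Union[None, float, Literal["fro"], Literal["nuc"]]


def allowed_norm_names() -> Tuple[str, ...]:
    """Hypothetical helper: collect the string literals that NormType permits."""
    names = []
    for member in get_args(NormType):
        # Literal["fro"] has origin Literal; None and float have no origin.
        if get_origin(member) is Literal:
            names.extend(get_args(member))
    return tuple(names)


print(allowed_norm_names())  # prints ('fro', 'nuc')
```

The alias is enforced only by static checkers such as mypy; nothing validates the argument at runtime.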
### Why are the changes needed?
To resolve typing incompatibilities between `pyspark.mllib.linalg` and numpy 1.22.
```
python/pyspark/mllib/linalg/__init__.py:412: error: Argument 2 to "norm" has incompatible type "Union[float, str]"; expected "Union[None, float, Literal['fro'], Literal['nuc']]" [arg-type]
python/pyspark/mllib/linalg/__init__.py:457: error: No overload variant of "dot" matches argument types "ndarray[Any, Any]", "Iterable[float]" [call-overload]
python/pyspark/mllib/linalg/__init__.py:457: note: Possible overload variant:
python/pyspark/mllib/linalg/__init__.py:457: note: def dot(a: Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]], b: Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]], out: None = ...) -> Any
python/pyspark/mllib/linalg/__init__.py:457: note: <1 more non-matching overload not shown>
python/pyspark/mllib/linalg/__init__.py:707: error: Argument 2 to "norm" has incompatible type "Union[float, str]"; expected "Union[None, float, Literal['fro'], Literal['nuc']]" [arg-type]
```
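The `dot` error above arises because `Iterable[float]` is not among the argument types listed in numpy's overloads. `typing.cast` changes only what the type checker sees: it returns its argument unchanged at runtime, and because the target type is given as a string, `ArrayLike` only needs to be importable under `TYPE_CHECKING`. A minimal sketch of that behaviour, where `as_array_like` is a hypothetical helper and not part of pyspark:

```python
from typing import TYPE_CHECKING, Iterable, List, cast

if TYPE_CHECKING:
    # Evaluated by type checkers only; never imported at runtime.
    from numpy.typing import ArrayLike


def as_array_like(other: Iterable[float]) -> "ArrayLike":
    """Hypothetical helper: narrow the static type without touching the value."""
    # The string form of the type is not evaluated, so no runtime import is needed.
    return cast("ArrayLike", other)


values: List[float] = [1.0, 2.0, 3.0]
result = as_array_like(values)
print(result is values)  # prints True: cast returns the very same object
```

Since the cast performs no conversion, the runtime behaviour of `dot` is the same as before this commit.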
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
`dev/lint-python`.
Closes #35261 from zero323/SPARK-37972.
Authored-by: zero323 <ms...@gmail.com>
Signed-off-by: zero323 <ms...@gmail.com>
---
python/pyspark/mllib/_typing.pyi | 2 ++
python/pyspark/mllib/linalg/__init__.py | 9 +++++----
2 files changed, 7 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/mllib/_typing.pyi b/python/pyspark/mllib/_typing.pyi
index 51a98cb..6a1a0f5 100644
--- a/python/pyspark/mllib/_typing.pyi
+++ b/python/pyspark/mllib/_typing.pyi
@@ -17,6 +17,7 @@
# under the License.
from typing import List, Tuple, TypeVar, Union
+from typing_extensions import Literal
from pyspark.mllib.linalg import Vector
from numpy import ndarray # noqa: F401
from py4j.java_gateway import JavaObject
@@ -24,3 +25,4 @@ from py4j.java_gateway import JavaObject
VectorLike = Union[ndarray, Vector, List[float], Tuple[float, ...]]
C = TypeVar("C", bound=type)
JavaObjectOrPickleDump = Union[JavaObject, bytearray, bytes]
+NormType = Union[None, float, Literal["fro"], Literal["nuc"]]
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index bbe8728..30fa84c 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -61,8 +61,9 @@ from typing import (
 )
 if TYPE_CHECKING:
-    from pyspark.mllib._typing import VectorLike
+    from pyspark.mllib._typing import VectorLike, NormType
     from scipy.sparse import spmatrix
+    from numpy.typing import ArrayLike
 QT = TypeVar("QT")
@@ -397,7 +398,7 @@ class DenseVector(Vector):
         """
         return np.count_nonzero(self.array)
-    def norm(self, p: Union[float, str]) -> np.float64:
+    def norm(self, p: "NormType") -> np.float64:
         """
         Calculates the norm of a DenseVector.
@@ -454,7 +455,7 @@ class DenseVector(Vector):
             elif isinstance(other, Vector):
                 return np.dot(self.toArray(), other.toArray())
             else:
-                return np.dot(self.toArray(), other)
+                return np.dot(self.toArray(), cast("ArrayLike", other))
     def squared_distance(self, other: Iterable[float]) -> np.float64:
         """
@@ -692,7 +693,7 @@ class SparseVector(Vector):
         """
         return np.count_nonzero(self.values)
-    def norm(self, p: Union[float, str]) -> np.float64:
+    def norm(self, p: "NormType") -> np.float64:
         """
         Calculates the norm of a SparseVector.