Posted to commits@spark.apache.org by ze...@apache.org on 2022/01/21 11:01:14 UTC

[spark] branch master updated: [SPARK-37972][PYTHON][MLLIB] Address typing incompatibilities with numpy==1.22.x

This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1aa6652  [SPARK-37972][PYTHON][MLLIB] Address typing incompatibilities with numpy==1.22.x
1aa6652 is described below

commit 1aa665239876b32ccf81c9d170e17368c6b44c61
Author: zero323 <ms...@gmail.com>
AuthorDate: Fri Jan 21 12:00:16 2022 +0100

    [SPARK-37972][PYTHON][MLLIB] Address typing incompatibilities with numpy==1.22.x
    
    ### What changes were proposed in this pull request?
    
    This PR:
    
    - Updates the `Vector.norm` annotation to match its numpy counterpart (sketched below).
    - Adds a cast for the iterable argument passed to numpy's `dot`.
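
    As a rough sketch (illustrative only, and simplified relative to the actual patch below), the norm parameter is narrowed to the values accepted by the numpy 1.22 stubs:

    ```python
    from typing import Union

    from typing_extensions import Literal

    import numpy as np

    # Alias mirroring the typing of np.linalg.norm's `ord` parameter in numpy 1.22.
    NormType = Union[None, float, Literal["fro"], Literal["nuc"]]

    def norm(arr: np.ndarray, p: NormType) -> np.float64:
        # Forwarding a NormType value now matches the stub signature of np.linalg.norm.
        return np.float64(np.linalg.norm(arr, p))
    ```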
    
    ### Why are the changes needed?
    
    To resolve typing incompatibilities between `pyspark.mllib.linalg` and numpy 1.22.
    
    ```
    python/pyspark/mllib/linalg/__init__.py:412: error: Argument 2 to "norm" has incompatible type "Union[float, str]"; expected "Union[None, float, Literal['fro'], Literal['nuc']]"  [arg-type]
    python/pyspark/mllib/linalg/__init__.py:457: error: No overload variant of "dot" matches argument types "ndarray[Any, Any]", "Iterable[float]"  [call-overload]
    python/pyspark/mllib/linalg/__init__.py:457: note: Possible overload variant:
    python/pyspark/mllib/linalg/__init__.py:457: note:     def dot(a: Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]], b: Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]], out: None = ...) -> Any
    python/pyspark/mllib/linalg/__init__.py:457: note:     <1 more non-matching overload not shown>
    python/pyspark/mllib/linalg/__init__.py:707: error: Argument 2 to "norm" has incompatible type "Union[float, str]"; expected "Union[None, float, Literal['fro'], Literal['nuc']]"  [arg-type]
    ```
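
    The `dot` failure is handled with a `typing.cast` to numpy's `ArrayLike`, which satisfies mypy without changing runtime behaviour. A minimal sketch of the pattern (the function name here is made up for illustration):

    ```python
    from typing import Iterable, cast

    import numpy as np
    from numpy.typing import ArrayLike

    def dot_with_iterable(arr: np.ndarray, other: Iterable[float]) -> np.float64:
        # Under the numpy 1.22 stubs, Iterable[float] does not satisfy the
        # ArrayLike protocol, so cast() is used purely for the type checker;
        # the runtime call is unchanged.
        return np.dot(arr, cast(ArrayLike, other))
    ```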
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    `dev/lint-python`.
    
    Closes #35261 from zero323/SPARK-37972.
    
    Authored-by: zero323 <ms...@gmail.com>
    Signed-off-by: zero323 <ms...@gmail.com>
---
 python/pyspark/mllib/_typing.pyi        | 2 ++
 python/pyspark/mllib/linalg/__init__.py | 9 +++++----
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/mllib/_typing.pyi b/python/pyspark/mllib/_typing.pyi
index 51a98cb..6a1a0f5 100644
--- a/python/pyspark/mllib/_typing.pyi
+++ b/python/pyspark/mllib/_typing.pyi
@@ -17,6 +17,7 @@
 # under the License.
 
 from typing import List, Tuple, TypeVar, Union
+from typing_extensions import Literal
 from pyspark.mllib.linalg import Vector
 from numpy import ndarray  # noqa: F401
 from py4j.java_gateway import JavaObject
@@ -24,3 +25,4 @@ from py4j.java_gateway import JavaObject
 VectorLike = Union[ndarray, Vector, List[float], Tuple[float, ...]]
 C = TypeVar("C", bound=type)
 JavaObjectOrPickleDump = Union[JavaObject, bytearray, bytes]
+NormType = Union[None, float, Literal["fro"], Literal["nuc"]]
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index bbe8728..30fa84c 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -61,8 +61,9 @@ from typing import (
 )
 
 if TYPE_CHECKING:
-    from pyspark.mllib._typing import VectorLike
+    from pyspark.mllib._typing import VectorLike, NormType
     from scipy.sparse import spmatrix
+    from numpy.typing import ArrayLike
 
 
 QT = TypeVar("QT")
@@ -397,7 +398,7 @@ class DenseVector(Vector):
         """
         return np.count_nonzero(self.array)
 
-    def norm(self, p: Union[float, str]) -> np.float64:
+    def norm(self, p: "NormType") -> np.float64:
         """
         Calculates the norm of a DenseVector.
 
@@ -454,7 +455,7 @@ class DenseVector(Vector):
             elif isinstance(other, Vector):
                 return np.dot(self.toArray(), other.toArray())
             else:
-                return np.dot(self.toArray(), other)
+                return np.dot(self.toArray(), cast("ArrayLike", other))
 
     def squared_distance(self, other: Iterable[float]) -> np.float64:
         """
@@ -692,7 +693,7 @@ class SparseVector(Vector):
         """
         return np.count_nonzero(self.values)
 
-    def norm(self, p: Union[float, str]) -> np.float64:
+    def norm(self, p: "NormType") -> np.float64:
         """
         Calculates the norm of a SparseVector.
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org