This is an automated email from the ASF dual-hosted git repository.
zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1aa6652 [SPARK-37972][PYTHON][MLLIB] Address typing incompatibilities with numpy==1.22.x
1aa6652 is described below
commit 1aa665239876b32ccf81c9d170e17368c6b44c61
Author: zero323 <[email protected]>
AuthorDate: Fri Jan 21 12:00:16 2022 +0100
[SPARK-37972][PYTHON][MLLIB] Address typing incompatibilities with numpy==1.22.x
### What changes were proposed in this pull request?
This PR:
- Updates the `Vector.norm` annotation to match its numpy counterpart.
- Adds a cast for the numpy `dot` arguments.
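For illustration, a minimal sketch of what the new annotation accepts. It assumes Python 3.8+ (so `typing.Literal` stands in for `typing_extensions.Literal`) and uses a hypothetical helper, `dense_norm`, assuming the vector method simply forwards `p` to `np.linalg.norm` as its `ord` argument:

```python
from typing import Literal, Union

import numpy as np

# Same shape as the NormType alias added to pyspark/mllib/_typing.pyi;
# typing.Literal (Python 3.8+) is assumed here instead of typing_extensions.
NormType = Union[None, float, Literal["fro"], Literal["nuc"]]


def dense_norm(values: np.ndarray, p: NormType) -> np.float64:
    """Hypothetical helper: forwards `p` to np.linalg.norm as `ord`."""
    # With the old Union[float, str] annotation any string (e.g. "foo")
    # passed type checking; NormType accepts only the orders numpy's stubs list.
    return np.linalg.norm(values, p)
```

For `np.array([3.0, 4.0])`, `dense_norm(..., 2)` returns `5.0`. Note that `"fro"` and `"nuc"` are matrix norms: passing them for a 1-D array type-checks, but numpy raises `ValueError` at runtime.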
### Why are the changes needed?
To resolve typing incompatibilities between `pyspark.mllib.linalg` and
numpy 1.22.
```
python/pyspark/mllib/linalg/__init__.py:412: error: Argument 2 to "norm" has incompatible type "Union[float, str]"; expected "Union[None, float, Literal['fro'], Literal['nuc']]" [arg-type]
python/pyspark/mllib/linalg/__init__.py:457: error: No overload variant of "dot" matches argument types "ndarray[Any, Any]", "Iterable[float]" [call-overload]
python/pyspark/mllib/linalg/__init__.py:457: note: Possible overload variant:
python/pyspark/mllib/linalg/__init__.py:457: note: def dot(a: Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]], b: Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]], out: None = ...) -> Any
python/pyspark/mllib/linalg/__init__.py:457: note: <1 more non-matching overload not shown>
python/pyspark/mllib/linalg/__init__.py:707: error: Argument 2 to "norm" has incompatible type "Union[float, str]"; expected "Union[None, float, Literal['fro'], Literal['nuc']]" [arg-type]
```
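The `dot` error arises because `other` is annotated as `Iterable[float]`, which numpy 1.22's `ArrayLike` does not include. A minimal sketch of the fix, using a hypothetical `dot_with_iterable` function as a stand-in for the `else` branch of `DenseVector.dot`; `typing.cast` only informs the type checker and returns its argument unchanged at runtime:

```python
from typing import TYPE_CHECKING, Iterable, cast

import numpy as np

if TYPE_CHECKING:
    # Imported only for the type checker, as in the commit.
    from numpy.typing import ArrayLike


def dot_with_iterable(values: np.ndarray, other: Iterable[float]) -> np.float64:
    """Hypothetical stand-in for the `else` branch of DenseVector.dot."""
    # Iterable[float] is not part of numpy's ArrayLike, so mypy with the
    # numpy 1.22 stubs rejects np.dot(values, other). The string form of
    # cast() works even though ArrayLike is imported only under TYPE_CHECKING.
    return np.dot(values, cast("ArrayLike", other))
```

For example, `dot_with_iterable(np.array([1.0, 2.0, 3.0]), [4.0, 5.0, 6.0])` returns `32.0`. The cast does not convert anything, so inputs still need to be array-like at runtime (lists and tuples work; a bare generator does not).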
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
`dev/lint-python`.
Closes #35261 from zero323/SPARK-37972.
Authored-by: zero323 <[email protected]>
Signed-off-by: zero323 <[email protected]>
---
python/pyspark/mllib/_typing.pyi | 2 ++
python/pyspark/mllib/linalg/__init__.py | 9 +++++----
2 files changed, 7 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/mllib/_typing.pyi b/python/pyspark/mllib/_typing.pyi
index 51a98cb..6a1a0f5 100644
--- a/python/pyspark/mllib/_typing.pyi
+++ b/python/pyspark/mllib/_typing.pyi
@@ -17,6 +17,7 @@
# under the License.
from typing import List, Tuple, TypeVar, Union
+from typing_extensions import Literal
from pyspark.mllib.linalg import Vector
from numpy import ndarray # noqa: F401
from py4j.java_gateway import JavaObject
@@ -24,3 +25,4 @@ from py4j.java_gateway import JavaObject
VectorLike = Union[ndarray, Vector, List[float], Tuple[float, ...]]
C = TypeVar("C", bound=type)
JavaObjectOrPickleDump = Union[JavaObject, bytearray, bytes]
+NormType = Union[None, float, Literal["fro"], Literal["nuc"]]
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index bbe8728..30fa84c 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -61,8 +61,9 @@ from typing import (
)
if TYPE_CHECKING:
-    from pyspark.mllib._typing import VectorLike
+    from pyspark.mllib._typing import VectorLike, NormType
    from scipy.sparse import spmatrix
+    from numpy.typing import ArrayLike
QT = TypeVar("QT")
@@ -397,7 +398,7 @@ class DenseVector(Vector):
        """
        return np.count_nonzero(self.array)
-    def norm(self, p: Union[float, str]) -> np.float64:
+    def norm(self, p: "NormType") -> np.float64:
        """
        Calculates the norm of a DenseVector.
@@ -454,7 +455,7 @@ class DenseVector(Vector):
        elif isinstance(other, Vector):
            return np.dot(self.toArray(), other.toArray())
        else:
-            return np.dot(self.toArray(), other)
+            return np.dot(self.toArray(), cast("ArrayLike", other))
    def squared_distance(self, other: Iterable[float]) -> np.float64:
        """
@@ -692,7 +693,7 @@ class SparseVector(Vector):
        """
        return np.count_nonzero(self.values)
-    def norm(self, p: Union[float, str]) -> np.float64:
+    def norm(self, p: "NormType") -> np.float64:
        """
        Calculates the norm of a SparseVector.