This is an automated email from the ASF dual-hosted git repository.
zero323 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
new 60ce69df029 [SPARK-37234][PYTHON] Inline type hints for
python/pyspark/mllib/stat/_statistics.py
60ce69df029 is described below
commit 60ce69df029b1e1d7cf7f7eece02e668de24cca8
Author: dch nguyen <[email protected]>
AuthorDate: Sun Apr 10 14:14:33 2022 +0200
[SPARK-37234][PYTHON] Inline type hints for
python/pyspark/mllib/stat/_statistics.py
### What changes were proposed in this pull request?
Inline type hints for python/pyspark/mllib/stat/_statistics.py
### Why are the changes needed?
We can take advantage of static type checking within the functions by
inlining the type hints.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing tests
Closes #34513 from dchvn/SPARK-37234.
Lead-authored-by: dch nguyen <[email protected]>
Co-authored-by: dch nguyen <[email protected]>
Signed-off-by: zero323 <[email protected]>
(cherry picked from commit c3dcdb118ca403a8fbefc3308a116d9e12a1f038)
Signed-off-by: zero323 <[email protected]>
---
python/pyspark/mllib/_typing.pyi | 5 ++
python/pyspark/mllib/stat/_statistics.py | 94 +++++++++++++++++++++++--------
python/pyspark/mllib/stat/_statistics.pyi | 63 ---------------------
3 files changed, 74 insertions(+), 88 deletions(-)
diff --git a/python/pyspark/mllib/_typing.pyi b/python/pyspark/mllib/_typing.pyi
index 6a1a0f53a59..4fbaeca39be 100644
--- a/python/pyspark/mllib/_typing.pyi
+++ b/python/pyspark/mllib/_typing.pyi
@@ -17,7 +17,9 @@
# under the License.
from typing import List, Tuple, TypeVar, Union
+
from typing_extensions import Literal
+
from pyspark.mllib.linalg import Vector
from numpy import ndarray # noqa: F401
from py4j.java_gateway import JavaObject
@@ -25,4 +27,7 @@ from py4j.java_gateway import JavaObject
VectorLike = Union[ndarray, Vector, List[float], Tuple[float, ...]]
C = TypeVar("C", bound=type)
JavaObjectOrPickleDump = Union[JavaObject, bytearray, bytes]
+
+CorrelationMethod = Union[Literal["spearman"], Literal["pearson"]]
+DistName = Literal["norm"]
NormType = Union[None, float, Literal["fro"], Literal["nuc"]]
diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py
index 34a373d5358..25095d99dd9 100644
--- a/python/pyspark/mllib/stat/_statistics.py
+++ b/python/pyspark/mllib/stat/_statistics.py
@@ -16,13 +16,19 @@
#
import sys
+from typing import cast, overload, List, Optional, TYPE_CHECKING, Union
+
+from numpy import ndarray
+from py4j.java_gateway import JavaObject
from pyspark.rdd import RDD
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import Matrix, _convert_to_vector
+from pyspark.mllib.linalg import Matrix, Vector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.stat.test import ChiSqTestResult, KolmogorovSmirnovTestResult
+if TYPE_CHECKING:
+ from pyspark.mllib._typing import CorrelationMethod, DistName
__all__ = ["MultivariateStatisticalSummary", "Statistics"]
@@ -33,34 +39,34 @@ class MultivariateStatisticalSummary(JavaModelWrapper):
Trait for multivariate statistical summary of a data matrix.
"""
- def mean(self):
- return self.call("mean").toArray()
+ def mean(self) -> ndarray:
+ return cast(JavaObject, self.call("mean")).toArray()
- def variance(self):
- return self.call("variance").toArray()
+ def variance(self) -> ndarray:
+ return cast(JavaObject, self.call("variance")).toArray()
- def count(self):
+ def count(self) -> int:
return int(self.call("count"))
- def numNonzeros(self):
- return self.call("numNonzeros").toArray()
+ def numNonzeros(self) -> ndarray:
+ return cast(JavaObject, self.call("numNonzeros")).toArray()
- def max(self):
- return self.call("max").toArray()
+ def max(self) -> ndarray:
+ return cast(JavaObject, self.call("max")).toArray()
- def min(self):
- return self.call("min").toArray()
+ def min(self) -> ndarray:
+ return cast(JavaObject, self.call("min")).toArray()
- def normL1(self):
- return self.call("normL1").toArray()
+ def normL1(self) -> ndarray:
+ return cast(JavaObject, self.call("normL1")).toArray()
- def normL2(self):
- return self.call("normL2").toArray()
+ def normL2(self) -> ndarray:
+ return cast(JavaObject, self.call("normL2")).toArray()
class Statistics:
@staticmethod
- def colStats(rdd):
+ def colStats(rdd: RDD[Vector]) -> MultivariateStatisticalSummary:
"""
Computes column-wise summary statistics for the input RDD[Vector].
@@ -98,8 +104,22 @@ class Statistics:
cStats = callMLlibFunc("colStats", rdd.map(_convert_to_vector))
return MultivariateStatisticalSummary(cStats)
+ @overload
+ @staticmethod
+ def corr(x: RDD[Vector], *, method: Optional["CorrelationMethod"] = ...) -> Matrix:
+ ...
+
+ @overload
@staticmethod
- def corr(x, y=None, method=None):
+ def corr(x: RDD[float], y: RDD[float], method: Optional["CorrelationMethod"] = ...) -> float:
+ ...
+
+ @staticmethod
+ def corr(
+ x: Union[RDD[Vector], RDD[float]],
+ y: Optional[RDD[float]] = None,
+ method: Optional["CorrelationMethod"] = None,
+ ) -> Union[float, Matrix]:
"""
Compute the correlation (matrix) for the input RDD(s) using the
specified method.
@@ -168,12 +188,34 @@ class Statistics:
raise TypeError("Use 'method=' to specify method name.")
if not y:
- return callMLlibFunc("corr", x.map(_convert_to_vector), method).toArray()
+ return cast(
+ JavaObject, callMLlibFunc("corr", x.map(_convert_to_vector), method)
+ ).toArray()
else:
- return callMLlibFunc("corr", x.map(float), y.map(float), method)
+ return cast(
+ float,
+ callMLlibFunc("corr", cast(RDD[float], x).map(float), y.map(float), method),
+ )
+
+ @overload
+ @staticmethod
+ def chiSqTest(observed: Matrix) -> ChiSqTestResult:
+ ...
+
+ @overload
+ @staticmethod
+ def chiSqTest(observed: Vector, expected: Optional[Vector] = ...) -> ChiSqTestResult:
+ ...
+
+ @overload
+ @staticmethod
+ def chiSqTest(observed: RDD[LabeledPoint]) -> List[ChiSqTestResult]:
+ ...
@staticmethod
- def chiSqTest(observed, expected=None):
+ def chiSqTest(
+ observed: Union[Matrix, RDD[LabeledPoint], Vector], expected: Optional[Vector] = None
+ ) -> Union[ChiSqTestResult, List[ChiSqTestResult]]:
"""
If `observed` is Vector, conduct Pearson's chi-squared goodness
of fit test of the observed data against the expected distribution,
@@ -270,7 +312,9 @@ class Statistics:
return ChiSqTestResult(jmodel)
@staticmethod
- def kolmogorovSmirnovTest(data, distName="norm", *params):
+ def kolmogorovSmirnovTest(
+ data: RDD[float], distName: "DistName" = "norm", *params: float
+ ) -> KolmogorovSmirnovTestResult:
"""
Performs the Kolmogorov-Smirnov (KS) test for data sampled from
a continuous distribution. It tests the null hypothesis that
@@ -334,13 +378,13 @@ class Statistics:
if not isinstance(distName, str):
raise TypeError("distName should be a string, got %s." % type(distName))
- params = [float(param) for param in params]
+ param_list = [float(param) for param in params]
return KolmogorovSmirnovTestResult(
- callMLlibFunc("kolmogorovSmirnovTest", data, distName, params)
+ callMLlibFunc("kolmogorovSmirnovTest", data, distName, param_list)
)
-def _test():
+def _test() -> None:
import doctest
import numpy
from pyspark.sql import SparkSession
diff --git a/python/pyspark/mllib/stat/_statistics.pyi b/python/pyspark/mllib/stat/_statistics.pyi
deleted file mode 100644
index 1bf76dd3af0..00000000000
--- a/python/pyspark/mllib/stat/_statistics.pyi
+++ /dev/null
@@ -1,63 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import List, Optional, overload, Union
-from typing_extensions import Literal
-
-from numpy import ndarray
-
-from pyspark.mllib.common import JavaModelWrapper
-from pyspark.mllib.linalg import Vector, Matrix
-from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib.stat.test import ChiSqTestResult, KolmogorovSmirnovTestResult
-from pyspark.rdd import RDD
-
-CorrelationMethod = Union[Literal["spearman"], Literal["pearson"]]
-
-class MultivariateStatisticalSummary(JavaModelWrapper):
- def mean(self) -> ndarray: ...
- def variance(self) -> ndarray: ...
- def count(self) -> int: ...
- def numNonzeros(self) -> ndarray: ...
- def max(self) -> ndarray: ...
- def min(self) -> ndarray: ...
- def normL1(self) -> ndarray: ...
- def normL2(self) -> ndarray: ...
-
-class Statistics:
- @staticmethod
- def colStats(rdd: RDD[Vector]) -> MultivariateStatisticalSummary: ...
- @overload
- @staticmethod
- def corr(x: RDD[Vector], *, method: Optional[CorrelationMethod] = ...) -> Matrix: ...
- @overload
- @staticmethod
- def corr(x: RDD[float], y: RDD[float], method: Optional[CorrelationMethod] = ...) -> float: ...
- @overload
- @staticmethod
- def chiSqTest(observed: Matrix) -> ChiSqTestResult: ...
- @overload
- @staticmethod
- def chiSqTest(observed: Vector, expected: Optional[Vector] = ...) -> ChiSqTestResult: ...
- @overload
- @staticmethod
- def chiSqTest(observed: RDD[LabeledPoint]) -> List[ChiSqTestResult]: ...
- @staticmethod
- def kolmogorovSmirnovTest(
- data: RDD[float], distName: Literal["norm"] = ..., *params: float
- ) -> KolmogorovSmirnovTestResult: ...
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]