This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 711d0b4  [SPARK-37416][PYTHON][ML] Inline type hints for pyspark.ml.wrapper
711d0b4 is described below

commit 711d0b4aac45f5bb69f91a774e94063a769d31d0
Author: zero323 <[email protected]>
AuthorDate: Mon Feb 7 00:02:33 2022 +0100

    [SPARK-37416][PYTHON][ML] Inline type hints for pyspark.ml.wrapper
    
    ### What changes were proposed in this pull request?
    
    This PR migrates `pyspark.ml.wrapper` type annotations from the stub file
    to inline type hints.
    
    ### Why are the changes needed?
    
    Part of the ongoing migration of type hints.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #35399 from zero323/SPARK-37416.
    
    Authored-by: zero323 <[email protected]>
    Signed-off-by: zero323 <[email protected]>
---
 python/pyspark/ml/classification.pyi |  16 +++++-
 python/pyspark/ml/clustering.pyi     |   6 ++
 python/pyspark/ml/feature.pyi        |  23 +++++++-
 python/pyspark/ml/fpm.pyi            |   3 +
 python/pyspark/ml/recommendation.pyi |   3 +
 python/pyspark/ml/regression.pyi     |  14 ++++-
 python/pyspark/ml/wrapper.py         | 108 ++++++++++++++++++++++++-----------
 python/pyspark/ml/wrapper.pyi        |  51 -----------------
 8 files changed, 133 insertions(+), 91 deletions(-)

diff --git a/python/pyspark/ml/classification.pyi 
b/python/pyspark/ml/classification.pyi
index 4170a8c..89089a4 100644
--- a/python/pyspark/ml/classification.pyi
+++ b/python/pyspark/ml/classification.pyi
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Any, List, Optional, Type
+from typing import Any, Generic, List, Optional, Type
 from pyspark.ml._typing import JM, M, P, T, ParamMap
 
 import abc
@@ -69,6 +69,8 @@ from pyspark.ml.param import Param
 from pyspark.ml.regression import DecisionTreeRegressionModel
 from pyspark.sql.dataframe import DataFrame
 
+from py4j.java_gateway import JavaObject  # type: ignore[import]
+
 class _ClassifierParams(HasRawPredictionCol, _PredictorParams): ...
 
 class Classifier(Predictor, _ClassifierParams, metaclass=abc.ABCMeta):
@@ -96,7 +98,7 @@ class ProbabilisticClassificationModel(
     @abstractmethod
     def predictProbability(self, value: Vector) -> Vector: ...
 
-class _JavaClassifier(Classifier, JavaPredictor[JM], metaclass=abc.ABCMeta):
+class _JavaClassifier(Classifier, JavaPredictor[JM], Generic[JM], 
metaclass=abc.ABCMeta):
     def setRawPredictionCol(self: P, value: str) -> P: ...
 
 class _JavaClassificationModel(ClassificationModel, JavaPredictionModel[T]):
@@ -105,7 +107,7 @@ class _JavaClassificationModel(ClassificationModel, 
JavaPredictionModel[T]):
     def predictRaw(self, value: Vector) -> Vector: ...
 
 class _JavaProbabilisticClassifier(
-    ProbabilisticClassifier, _JavaClassifier[JM], metaclass=abc.ABCMeta
+    ProbabilisticClassifier, _JavaClassifier[JM], Generic[JM], 
metaclass=abc.ABCMeta
 ): ...
 
 class _JavaProbabilisticClassificationModel(
@@ -231,6 +233,7 @@ class LinearSVC(
     def setWeightCol(self, value: str) -> LinearSVC: ...
     def setAggregationDepth(self, value: int) -> LinearSVC: ...
     def setMaxBlockSizeInMB(self, value: float) -> LinearSVC: ...
+    def _create_model(self, java_model: JavaObject) -> LinearSVCModel: ...
 
 class LinearSVCModel(
     _JavaClassificationModel[Vector],
@@ -350,6 +353,7 @@ class LogisticRegression(
     def setWeightCol(self, value: str) -> LogisticRegression: ...
     def setAggregationDepth(self, value: int) -> LogisticRegression: ...
     def setMaxBlockSizeInMB(self, value: float) -> LogisticRegression: ...
+    def _create_model(self, java_model: JavaObject) -> 
LogisticRegressionModel: ...
 
 class LogisticRegressionModel(
     _JavaProbabilisticClassificationModel[Vector],
@@ -444,6 +448,7 @@ class DecisionTreeClassifier(
     def setCheckpointInterval(self, value: int) -> DecisionTreeClassifier: ...
     def setSeed(self, value: int) -> DecisionTreeClassifier: ...
     def setWeightCol(self, value: str) -> DecisionTreeClassifier: ...
+    def _create_model(self, java_model: JavaObject) -> 
DecisionTreeClassificationModel: ...
 
 class DecisionTreeClassificationModel(
     _DecisionTreeModel,
@@ -529,6 +534,7 @@ class RandomForestClassifier(
     def setCheckpointInterval(self, value: int) -> RandomForestClassifier: ...
     def setWeightCol(self, value: str) -> RandomForestClassifier: ...
     def setMinWeightFractionPerNode(self, value: float) -> 
RandomForestClassifier: ...
+    def _create_model(self, java_model: JavaObject) -> 
RandomForestClassificationModel: ...
 
 class RandomForestClassificationModel(
     _TreeEnsembleModel,
@@ -633,6 +639,7 @@ class GBTClassifier(
     def setStepSize(self, value: float) -> GBTClassifier: ...
     def setWeightCol(self, value: str) -> GBTClassifier: ...
     def setMinWeightFractionPerNode(self, value: float) -> GBTClassifier: ...
+    def _create_model(self, java_model: JavaObject) -> GBTClassificationModel: 
...
 
 class GBTClassificationModel(
     _TreeEnsembleModel,
@@ -691,6 +698,7 @@ class NaiveBayes(
     def setSmoothing(self, value: float) -> NaiveBayes: ...
     def setModelType(self, value: str) -> NaiveBayes: ...
     def setWeightCol(self, value: str) -> NaiveBayes: ...
+    def _create_model(self, java_model: JavaObject) -> NaiveBayesModel: ...
 
 class NaiveBayesModel(
     _JavaProbabilisticClassificationModel[Vector],
@@ -769,6 +777,7 @@ class MultilayerPerceptronClassifier(
     def setTol(self, value: float) -> MultilayerPerceptronClassifier: ...
     def setStepSize(self, value: float) -> MultilayerPerceptronClassifier: ...
     def setSolver(self, value: str) -> MultilayerPerceptronClassifier: ...
+    def _create_model(self, java_model: JavaObject) -> 
MultilayerPerceptronClassificationModel: ...
 
 class MultilayerPerceptronClassificationModel(
     _JavaProbabilisticClassificationModel[Vector],
@@ -921,6 +930,7 @@ class FMClassifier(
     def setSeed(self, value: int) -> FMClassifier: ...
     def setFitIntercept(self, value: bool) -> FMClassifier: ...
     def setRegParam(self, value: float) -> FMClassifier: ...
+    def _create_model(self, java_model: JavaObject) -> FMClassificationModel: 
...
 
 class FMClassificationModel(
     _JavaProbabilisticClassificationModel[Vector],
diff --git a/python/pyspark/ml/clustering.pyi b/python/pyspark/ml/clustering.pyi
index 81074fc..e0ee3d6 100644
--- a/python/pyspark/ml/clustering.pyi
+++ b/python/pyspark/ml/clustering.pyi
@@ -45,6 +45,8 @@ from pyspark.sql.dataframe import DataFrame
 
 from numpy import ndarray
 
+from py4j.java_gateway import JavaObject  # type: ignore[import]
+
 class ClusteringSummary(JavaWrapper):
     @property
     def predictionCol(self) -> str: ...
@@ -137,6 +139,7 @@ class GaussianMixture(
     def setSeed(self, value: int) -> GaussianMixture: ...
     def setTol(self, value: float) -> GaussianMixture: ...
     def setAggregationDepth(self, value: int) -> GaussianMixture: ...
+    def _create_model(self, java_model: JavaObject) -> GaussianMixtureModel: 
...
 
 class GaussianMixtureSummary(ClusteringSummary):
     @property
@@ -219,6 +222,7 @@ class KMeans(JavaEstimator[KMeansModel], _KMeansParams, 
JavaMLWritable, JavaMLRe
     def setSeed(self, value: int) -> KMeans: ...
     def setTol(self, value: float) -> KMeans: ...
     def setWeightCol(self, value: str) -> KMeans: ...
+    def _create_model(self, java_model: JavaObject) -> KMeansModel: ...
 
 class _BisectingKMeansParams(
     HasMaxIter,
@@ -287,6 +291,7 @@ class BisectingKMeans(
     def setPredictionCol(self, value: str) -> BisectingKMeans: ...
     def setSeed(self, value: int) -> BisectingKMeans: ...
     def setWeightCol(self, value: str) -> BisectingKMeans: ...
+    def _create_model(self, java_model: JavaObject) -> BisectingKMeansModel: 
...
 
 class BisectingKMeansSummary(ClusteringSummary):
     @property
@@ -386,6 +391,7 @@ class LDA(JavaEstimator[LDAModel], _LDAParams, 
JavaMLReadable[LDA], JavaMLWritab
     def setKeepLastCheckpoint(self, value: bool) -> LDA: ...
     def setMaxIter(self, value: int) -> LDA: ...
     def setFeaturesCol(self, value: str) -> LDA: ...
+    def _create_model(self, java_model: JavaObject) -> LDAModel: ...
 
 class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol):
     k: Param[int]
diff --git a/python/pyspark/ml/feature.pyi b/python/pyspark/ml/feature.pyi
index 6efc304..ecfd26e 100644
--- a/python/pyspark/ml/feature.pyi
+++ b/python/pyspark/ml/feature.pyi
@@ -42,6 +42,8 @@ from pyspark.ml.linalg import Vector, DenseVector, DenseMatrix
 from pyspark.sql.dataframe import DataFrame
 from pyspark.ml.param import Param
 
+from py4j.java_gateway import JavaObject  # type: ignore[import]
+
 class Binarizer(
     JavaTransformer,
     HasThreshold,
@@ -103,6 +105,7 @@ class _LSH(Generic[JM], JavaEstimator[JM], _LSHParams, 
JavaMLReadable, JavaMLWri
     def setNumHashTables(self: P, value: int) -> P: ...
     def setInputCol(self: P, value: str) -> P: ...
     def setOutputCol(self: P, value: str) -> P: ...
+    def _create_model(self, java_model: JavaObject) -> JM: ...
 
 class _LSHModel(JavaModel, _LSHParams):
     def setInputCol(self: P, value: str) -> P: ...
@@ -268,6 +271,7 @@ class CountVectorizer(
     def setBinary(self, value: bool) -> CountVectorizer: ...
     def setInputCol(self, value: str) -> CountVectorizer: ...
     def setOutputCol(self, value: str) -> CountVectorizer: ...
+    def _create_model(self, java_model: JavaObject) -> CountVectorizerModel: 
...
 
 class CountVectorizerModel(JavaModel, JavaMLReadable[CountVectorizerModel], 
JavaMLWritable):
     def setInputCol(self, value: str) -> CountVectorizerModel: ...
@@ -412,6 +416,7 @@ class IDF(JavaEstimator[IDFModel], _IDFParams, 
JavaMLReadable[IDF], JavaMLWritab
     def setMinDocFreq(self, value: int) -> IDF: ...
     def setInputCol(self, value: str) -> IDF: ...
     def setOutputCol(self, value: str) -> IDF: ...
+    def _create_model(self, java_model: JavaObject) -> IDFModel: ...
 
 class IDFModel(JavaModel, _IDFParams, JavaMLReadable[IDFModel], 
JavaMLWritable):
     def setInputCol(self, value: str) -> IDFModel: ...
@@ -477,6 +482,7 @@ class Imputer(JavaEstimator[ImputerModel], _ImputerParams, 
JavaMLReadable[Impute
     def setInputCol(self, value: str) -> Imputer: ...
     def setOutputCol(self, value: str) -> Imputer: ...
     def setRelativeError(self, value: float) -> Imputer: ...
+    def _create_model(self, java_model: JavaObject) -> ImputerModel: ...
 
 class ImputerModel(JavaModel, _ImputerParams, JavaMLReadable[ImputerModel], 
JavaMLWritable):
     def setInputCols(self, value: List[str]) -> ImputerModel: ...
@@ -518,6 +524,7 @@ class MaxAbsScaler(
     ) -> MaxAbsScaler: ...
     def setInputCol(self, value: str) -> MaxAbsScaler: ...
     def setOutputCol(self, value: str) -> MaxAbsScaler: ...
+    def _create_model(self, java_model: JavaObject) -> MaxAbsScalerModel: ...
 
 class MaxAbsScalerModel(
     JavaModel, _MaxAbsScalerParams, JavaMLReadable[MaxAbsScalerModel], 
JavaMLWritable
@@ -588,6 +595,7 @@ class MinMaxScaler(
     def setMax(self, value: float) -> MinMaxScaler: ...
     def setInputCol(self, value: str) -> MinMaxScaler: ...
     def setOutputCol(self, value: str) -> MinMaxScaler: ...
+    def _create_model(self, java_model: JavaObject) -> MinMaxScalerModel: ...
 
 class MinMaxScalerModel(
     JavaModel, _MinMaxScalerParams, JavaMLReadable[MinMaxScalerModel], 
JavaMLWritable
@@ -687,6 +695,7 @@ class OneHotEncoder(
     def setHandleInvalid(self, value: str) -> OneHotEncoder: ...
     def setInputCol(self, value: str) -> OneHotEncoder: ...
     def setOutputCol(self, value: str) -> OneHotEncoder: ...
+    def _create_model(self, java_model: JavaObject) -> OneHotEncoderModel: ...
 
 class OneHotEncoderModel(
     JavaModel, _OneHotEncoderParams, JavaMLReadable[OneHotEncoderModel], 
JavaMLWritable
@@ -783,6 +792,7 @@ class QuantileDiscretizer(
     def setOutputCol(self, value: str) -> QuantileDiscretizer: ...
     def setOutputCols(self, value: List[str]) -> QuantileDiscretizer: ...
     def setHandleInvalid(self, value: str) -> QuantileDiscretizer: ...
+    def _create_model(self, java_model: JavaObject) -> Bucketizer: ...
 
 class _RobustScalerParams(HasInputCol, HasOutputCol, HasRelativeError):
     lower: Param[float]
@@ -827,6 +837,7 @@ class RobustScaler(
     def setInputCol(self, value: str) -> RobustScaler: ...
     def setOutputCol(self, value: str) -> RobustScaler: ...
     def setRelativeError(self, value: float) -> RobustScaler: ...
+    def _create_model(self, java_model: JavaObject) -> RobustScalerModel: ...
 
 class RobustScalerModel(
     JavaModel, _RobustScalerParams, JavaMLReadable[RobustScalerModel], 
JavaMLWritable
@@ -920,6 +931,7 @@ class StandardScaler(
     def setWithStd(self, value: bool) -> StandardScaler: ...
     def setInputCol(self, value: str) -> StandardScaler: ...
     def setOutputCol(self, value: str) -> StandardScaler: ...
+    def _create_model(self, java_model: JavaObject) -> StandardScalerModel: ...
 
 class StandardScalerModel(
     JavaModel,
@@ -990,6 +1002,7 @@ class StringIndexer(
     def setOutputCol(self, value: str) -> StringIndexer: ...
     def setOutputCols(self, value: List[str]) -> StringIndexer: ...
     def setHandleInvalid(self, value: str) -> StringIndexer: ...
+    def _create_model(self, java_model: JavaObject) -> StringIndexerModel: ...
 
 class StringIndexerModel(
     JavaModel, _StringIndexerParams, JavaMLReadable[StringIndexerModel], 
JavaMLWritable
@@ -1186,6 +1199,7 @@ class VectorIndexer(
     def setInputCol(self, value: str) -> VectorIndexer: ...
     def setOutputCol(self, value: str) -> VectorIndexer: ...
     def setHandleInvalid(self, value: str) -> VectorIndexer: ...
+    def _create_model(self, java_model: JavaObject) -> VectorIndexerModel: ...
 
 class VectorIndexerModel(
     JavaModel, _VectorIndexerParams, JavaMLReadable[VectorIndexerModel], 
JavaMLWritable
@@ -1286,6 +1300,7 @@ class Word2Vec(
     def setOutputCol(self, value: str) -> Word2Vec: ...
     def setSeed(self, value: int) -> Word2Vec: ...
     def setStepSize(self, value: float) -> Word2Vec: ...
+    def _create_model(self, java_model: JavaObject) -> Word2VecModel: ...
 
 class Word2VecModel(JavaModel, _Word2VecParams, JavaMLReadable[Word2VecModel], 
JavaMLWritable):
     def getVectors(self) -> DataFrame: ...
@@ -1322,6 +1337,7 @@ class PCA(JavaEstimator[PCAModel], _PCAParams, 
JavaMLReadable[PCA], JavaMLWritab
     def setK(self, value: int) -> PCA: ...
     def setInputCol(self, value: str) -> PCA: ...
     def setOutputCol(self, value: str) -> PCA: ...
+    def _create_model(self, java_model: JavaObject) -> PCAModel: ...
 
 class PCAModel(JavaModel, _PCAParams, JavaMLReadable[PCAModel], 
JavaMLWritable):
     def setInputCol(self, value: str) -> PCAModel: ...
@@ -1373,6 +1389,7 @@ class RFormula(
     def setFeaturesCol(self, value: str) -> RFormula: ...
     def setLabelCol(self, value: str) -> RFormula: ...
     def setHandleInvalid(self, value: str) -> RFormula: ...
+    def _create_model(self, java_model: JavaObject) -> RFormulaModel: ...
 
 class RFormulaModel(JavaModel, _RFormulaParams, JavaMLReadable[RFormulaModel], 
JavaMLWritable): ...
 
@@ -1391,7 +1408,7 @@ class _SelectorParams(HasFeaturesCol, HasOutputCol, 
HasLabelCol):
     def getFdr(self) -> float: ...
     def getFwe(self) -> float: ...
 
-class _Selector(JavaEstimator[JM], _SelectorParams, JavaMLReadable, 
JavaMLWritable):
+class _Selector(JavaEstimator[JM], _SelectorParams, JavaMLReadable, 
JavaMLWritable, Generic[JM]):
     def setSelectorType(self: P, value: str) -> P: ...
     def setNumTopFeatures(self: P, value: int) -> P: ...
     def setPercentile(self: P, value: float) -> P: ...
@@ -1401,6 +1418,7 @@ class _Selector(JavaEstimator[JM], _SelectorParams, 
JavaMLReadable, JavaMLWritab
     def setFeaturesCol(self: P, value: str) -> P: ...
     def setOutputCol(self: P, value: str) -> P: ...
     def setLabelCol(self: P, value: str) -> P: ...
+    def _create_model(self, java_model: JavaObject) -> JM: ...
 
 class _SelectorModel(JavaModel, _SelectorParams):
     def setFeaturesCol(self: P, value: str) -> P: ...
@@ -1448,6 +1466,7 @@ class ChiSqSelector(
     def setFeaturesCol(self, value: str) -> ChiSqSelector: ...
     def setOutputCol(self, value: str) -> ChiSqSelector: ...
     def setLabelCol(self, value: str) -> ChiSqSelector: ...
+    def _create_model(self, java_model: JavaObject) -> ChiSqSelectorModel: ...
 
 class ChiSqSelectorModel(_SelectorModel, JavaMLReadable[ChiSqSelectorModel], 
JavaMLWritable):
     def setFeaturesCol(self, value: str) -> ChiSqSelectorModel: ...
@@ -1500,6 +1519,7 @@ class VarianceThresholdSelector(
     def setVarianceThreshold(self, value: float) -> VarianceThresholdSelector: 
...
     def setFeaturesCol(self, value: str) -> VarianceThresholdSelector: ...
     def setOutputCol(self, value: str) -> VarianceThresholdSelector: ...
+    def _create_model(self, java_model: JavaObject) -> 
VarianceThresholdSelectorModel: ...
 
 class VarianceThresholdSelectorModel(
     JavaModel,
@@ -1552,6 +1572,7 @@ class UnivariateFeatureSelector(
     def setFeaturesCol(self, value: str) -> UnivariateFeatureSelector: ...
     def setOutputCol(self, value: str) -> UnivariateFeatureSelector: ...
     def setLabelCol(self, value: str) -> UnivariateFeatureSelector: ...
+    def _create_model(self, java_model: JavaObject) -> 
UnivariateFeatureSelectorModel: ...
 
 class UnivariateFeatureSelectorModel(
     JavaModel,
diff --git a/python/pyspark/ml/fpm.pyi b/python/pyspark/ml/fpm.pyi
index 609bc44..00d5c5f 100644
--- a/python/pyspark/ml/fpm.pyi
+++ b/python/pyspark/ml/fpm.pyi
@@ -25,6 +25,8 @@ from pyspark.sql.dataframe import DataFrame
 
 from pyspark.ml.param import Param
 
+from py4j.java_gateway import JavaObject  # type: ignore[import]
+
 class _FPGrowthParams(HasPredictionCol):
     itemsCol: Param[str]
     minSupport: Param[float]
@@ -74,6 +76,7 @@ class FPGrowth(
     def setNumPartitions(self, value: int) -> FPGrowth: ...
     def setMinConfidence(self, value: float) -> FPGrowth: ...
     def setPredictionCol(self, value: str) -> FPGrowth: ...
+    def _create_model(self, java_model: JavaObject) -> FPGrowthModel: ...
 
 class PrefixSpan(JavaParams):
     minSupport: Param[float]
diff --git a/python/pyspark/ml/recommendation.pyi 
b/python/pyspark/ml/recommendation.pyi
index 6ce178b..f7faaca 100644
--- a/python/pyspark/ml/recommendation.pyi
+++ b/python/pyspark/ml/recommendation.pyi
@@ -36,6 +36,8 @@ from pyspark.ml.util import JavaMLWritable, JavaMLReadable
 
 from pyspark.sql.dataframe import DataFrame
 
+from py4j.java_gateway import JavaObject  # type: ignore[import]
+
 class _ALSModelParams(HasPredictionCol, HasBlockSize):
     userCol: Param[str]
     itemCol: Param[str]
@@ -127,6 +129,7 @@ class ALS(JavaEstimator[ALSModel], _ALSParams, 
JavaMLWritable, JavaMLReadable[AL
     def setCheckpointInterval(self, value: int) -> ALS: ...
     def setSeed(self, value: int) -> ALS: ...
     def setBlockSize(self, value: int) -> ALS: ...
+    def _create_model(self, java_model: JavaObject) -> ALSModel: ...
 
 class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, 
JavaMLReadable[ALSModel]):
     def setUserCol(self, value: str) -> ALSModel: ...
diff --git a/python/pyspark/ml/regression.pyi b/python/pyspark/ml/regression.pyi
index 3b553b1..750e4c7 100644
--- a/python/pyspark/ml/regression.pyi
+++ b/python/pyspark/ml/regression.pyi
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Any, List, Optional
+from typing import Any, Generic, List, Optional
 from pyspark.ml._typing import JM, M, T
 
 import abc
@@ -67,9 +67,11 @@ from pyspark.ml.linalg import Matrix, Vector
 from pyspark.ml.param import Param
 from pyspark.sql.dataframe import DataFrame
 
+from py4j.java_gateway import JavaObject  # type: ignore[import]
+
 class Regressor(Predictor[M], _PredictorParams, metaclass=abc.ABCMeta): ...
 class RegressionModel(PredictionModel[T], _PredictorParams, 
metaclass=abc.ABCMeta): ...
-class _JavaRegressor(Regressor, JavaPredictor[JM], metaclass=abc.ABCMeta): ...
+class _JavaRegressor(Regressor, JavaPredictor[JM], Generic[JM], 
metaclass=abc.ABCMeta): ...
 class _JavaRegressionModel(RegressionModel, JavaPredictionModel[T], 
metaclass=abc.ABCMeta): ...
 
 class _LinearRegressionParams(
@@ -146,6 +148,7 @@ class LinearRegression(
     def setAggregationDepth(self, value: int) -> LinearRegression: ...
     def setLoss(self, value: str) -> LinearRegression: ...
     def setMaxBlockSizeInMB(self, value: float) -> LinearRegression: ...
+    def _create_model(self, java_model: JavaObject) -> LinearRegressionModel: 
...
 
 class LinearRegressionModel(
     _JavaRegressionModel[Vector],
@@ -241,6 +244,7 @@ class IsotonicRegression(
     def setPredictionCol(self, value: str) -> IsotonicRegression: ...
     def setLabelCol(self, value: str) -> IsotonicRegression: ...
     def setWeightCol(self, value: str) -> IsotonicRegression: ...
+    def _create_model(self, java_model: JavaObject) -> 
IsotonicRegressionModel: ...
 
 class IsotonicRegressionModel(
     JavaModel,
@@ -320,6 +324,7 @@ class DecisionTreeRegressor(
     def setSeed(self, value: int) -> DecisionTreeRegressor: ...
     def setWeightCol(self, value: str) -> DecisionTreeRegressor: ...
     def setVarianceCol(self, value: str) -> DecisionTreeRegressor: ...
+    def _create_model(self, java_model: JavaObject) -> 
DecisionTreeRegressionModel: ...
 
 class DecisionTreeRegressionModel(
     _JavaRegressionModel[Vector],
@@ -402,6 +407,7 @@ class RandomForestRegressor(
     def setSeed(self, value: int) -> RandomForestRegressor: ...
     def setWeightCol(self, value: str) -> RandomForestRegressor: ...
     def setMinWeightFractionPerNode(self, value: float) -> 
RandomForestRegressor: ...
+    def _create_model(self, java_model: JavaObject) -> 
RandomForestRegressionModel: ...
 
 class RandomForestRegressionModel(
     _JavaRegressionModel[Vector],
@@ -496,6 +502,7 @@ class GBTRegressor(
     def setStepSize(self, value: float) -> GBTRegressor: ...
     def setWeightCol(self, value: str) -> GBTRegressor: ...
     def setMinWeightFractionPerNode(self, value: float) -> GBTRegressor: ...
+    def _create_model(self, java_model: JavaObject) -> GBTRegressionModel: ...
 
 class GBTRegressionModel(
     _JavaRegressionModel[Vector],
@@ -570,6 +577,7 @@ class AFTSurvivalRegression(
     def setFitIntercept(self, value: bool) -> AFTSurvivalRegression: ...
     def setAggregationDepth(self, value: int) -> AFTSurvivalRegression: ...
     def setMaxBlockSizeInMB(self, value: float) -> AFTSurvivalRegression: ...
+    def _create_model(self, java_model: JavaObject) -> 
AFTSurvivalRegressionModel: ...
 
 class AFTSurvivalRegressionModel(
     _JavaRegressionModel[Vector],
@@ -672,6 +680,7 @@ class GeneralizedLinearRegression(
     def setWeightCol(self, value: str) -> GeneralizedLinearRegression: ...
     def setSolver(self, value: str) -> GeneralizedLinearRegression: ...
     def setAggregationDepth(self, value: int) -> GeneralizedLinearRegression: 
...
+    def _create_model(self, java_model: JavaObject) -> 
GeneralizedLinearRegressionModel: ...
 
 class GeneralizedLinearRegressionModel(
     _JavaRegressionModel[Vector],
@@ -802,6 +811,7 @@ class FMRegressor(
     def setSeed(self, value: int) -> FMRegressor: ...
     def setFitIntercept(self, value: bool) -> FMRegressor: ...
     def setRegParam(self, value: float) -> FMRegressor: ...
+    def _create_model(self, java_model: JavaObject) -> FMRegressionModel: ...
 
 class FMRegressionModel(
     _JavaRegressionModel,
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index c35df2e..7f03f64 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -17,49 +17,68 @@
 
 from abc import ABCMeta, abstractmethod
 
+from typing import Any, Generic, Optional, List, Type, TypeVar, TYPE_CHECKING
+
 from pyspark import since
 from pyspark import SparkContext
 from pyspark.sql import DataFrame
 from pyspark.ml import Estimator, Predictor, PredictionModel, Transformer, 
Model
 from pyspark.ml.base import _PredictorParams
-from pyspark.ml.param import Params
-from pyspark.ml.util import _jvm
+from pyspark.ml.param import Param, Params
+from pyspark.ml.util import _jvm  # type: ignore[attr-defined]
 from pyspark.ml.common import inherit_doc, _java2py, _py2java
 
 
+if TYPE_CHECKING:
+    from pyspark.ml._typing import ParamMap
+    from py4j.java_gateway import JavaObject, JavaClass
+
+
+T = TypeVar("T")
+JW = TypeVar("JW", bound="JavaWrapper")
+JM = TypeVar("JM", bound="JavaTransformer")
+JP = TypeVar("JP", bound="JavaParams")
+
+
 class JavaWrapper:
     """
     Wrapper class for a Java companion object
     """
 
-    def __init__(self, java_obj=None):
+    def __init__(self, java_obj: Optional["JavaObject"] = None):
         super(JavaWrapper, self).__init__()
         self._java_obj = java_obj
 
-    def __del__(self):
+    def __del__(self) -> None:
         if SparkContext._active_spark_context and self._java_obj is not None:
-            SparkContext._active_spark_context._gateway.detach(self._java_obj)
+            SparkContext._active_spark_context._gateway.detach(  # type: 
ignore[union-attr]
+                self._java_obj
+            )
 
     @classmethod
-    def _create_from_java_class(cls, java_class, *args):
+    def _create_from_java_class(cls: Type[JW], java_class: str, *args: Any) -> 
JW:
         """
         Construct this object from given Java classname and arguments
         """
         java_obj = JavaWrapper._new_java_obj(java_class, *args)
         return cls(java_obj)
 
-    def _call_java(self, name, *args):
+    def _call_java(self, name: str, *args: Any) -> Any:
         m = getattr(self._java_obj, name)
         sc = SparkContext._active_spark_context
+        assert sc is not None
+
         java_args = [_py2java(sc, arg) for arg in args]
         return _java2py(sc, m(*java_args))
 
     @staticmethod
-    def _new_java_obj(java_class, *args):
+    def _new_java_obj(java_class: str, *args: Any) -> "JavaObject":
         """
         Returns a new Java object.
         """
         sc = SparkContext._active_spark_context
+        assert sc is not None
+
         java_obj = _jvm()
         for name in java_class.split("."):
             java_obj = getattr(java_obj, name)
@@ -67,7 +86,7 @@ class JavaWrapper:
         return java_obj(*java_args)
 
     @staticmethod
-    def _new_java_array(pylist, java_class):
+    def _new_java_array(pylist: List[Any], java_class: "JavaClass") -> 
"JavaObject":
         """
         Create a Java array of given java_class type. Useful for
         calling a method with a Scala Array from Python with Py4J.
@@ -97,6 +116,9 @@ class JavaWrapper:
           Java Array of converted pylist.
         """
         sc = SparkContext._active_spark_context
+        assert sc is not None
+        assert sc._gateway is not None
+
         java_array = None
         if len(pylist) > 0 and isinstance(pylist[0], list):
             # If pylist is a 2D array, then a 2D java array will be created.
@@ -125,20 +147,24 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
     #: The param values in the Java object should be
     #: synced with the Python wrapper in fit/transform/evaluate/copy.
 
-    def _make_java_param_pair(self, param, value):
+    def _make_java_param_pair(self, param: Param[T], value: T) -> "JavaObject":
         """
         Makes a Java param pair.
         """
         sc = SparkContext._active_spark_context
+        assert sc is not None and self._java_obj is not None
+
         param = self._resolveParam(param)
         java_param = self._java_obj.getParam(param.name)
         java_value = _py2java(sc, value)
         return java_param.w(java_value)
 
-    def _transfer_params_to_java(self):
+    def _transfer_params_to_java(self) -> None:
         """
         Transforms the embedded params to the companion Java object.
         """
+        assert self._java_obj is not None
+
         pair_defaults = []
         for param in self.params:
             if self.isSet(param):
@@ -149,10 +175,12 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
                 pair_defaults.append(pair)
         if len(pair_defaults) > 0:
             sc = SparkContext._active_spark_context
+            assert sc is not None and sc._jvm is not None
+
             pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults)
             self._java_obj.setDefault(pair_defaults_seq)
 
-    def _transfer_param_map_to_java(self, pyParamMap):
+    def _transfer_param_map_to_java(self, pyParamMap: "ParamMap") -> 
"JavaObject":
         """
         Transforms a Python ParamMap into a Java ParamMap.
         """
@@ -163,26 +191,30 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
                 paramMap.put([pair])
         return paramMap
 
-    def _create_params_from_java(self):
+    def _create_params_from_java(self) -> None:
         """
         SPARK-10931: Temporary fix to create params that are defined in the 
Java obj but not here
         """
+        assert self._java_obj is not None
+
         java_params = list(self._java_obj.params())
         from pyspark.ml.param import Param
 
         for java_param in java_params:
             java_param_name = java_param.name()
             if not hasattr(self, java_param_name):
-                param = Param(self, java_param_name, java_param.doc())
+                param: Param[Any] = Param(self, java_param_name, 
java_param.doc())
                 setattr(param, "created_from_java_param", True)
                 setattr(self, java_param_name, param)
                 self._params = None  # need to reset so self.params will 
discover new params
 
-    def _transfer_params_from_java(self):
+    def _transfer_params_from_java(self) -> None:
         """
         Transforms the embedded params from the companion Java object.
         """
         sc = SparkContext._active_spark_context
+        assert sc is not None and self._java_obj is not None
+
         for param in self.params:
             if self._java_obj.hasParam(param.name):
                 java_param = self._java_obj.getParam(param.name)
@@ -195,11 +227,13 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
                     value = _java2py(sc, 
self._java_obj.getDefault(java_param)).get()
                     self._setDefault(**{param.name: value})
 
-    def _transfer_param_map_from_java(self, javaParamMap):
+    def _transfer_param_map_from_java(self, javaParamMap: "JavaObject") -> 
"ParamMap":
         """
         Transforms a Java ParamMap into a Python ParamMap.
         """
         sc = SparkContext._active_spark_context
+        assert sc is not None
+
         paramMap = dict()
         for pair in javaParamMap.toList():
             param = pair.param()
@@ -208,13 +242,13 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
         return paramMap
 
     @staticmethod
-    def _empty_java_param_map():
+    def _empty_java_param_map() -> "JavaObject":
         """
         Returns an empty Java ParamMap reference.
         """
         return _jvm().org.apache.spark.ml.param.ParamMap()
 
-    def _to_java(self):
+    def _to_java(self) -> "JavaObject":
         """
         Transfer this instance's Params to the wrapped Java object, and return 
the Java object.
         Used for ML persistence.
@@ -230,7 +264,7 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
         return self._java_obj
 
     @staticmethod
-    def _from_java(java_stage):
+    def _from_java(java_stage: "JavaObject") -> "JP":
         """
         Given a Java object, create and return a Python wrapper of it.
         Used for ML persistence.
@@ -238,7 +272,7 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
         Meta-algorithms such as Pipeline should override this method as a 
classmethod.
         """
 
-        def __get_class(clazz):
+        def __get_class(clazz: str) -> Type[JP]:
             """
             Loads Python class from its name.
             """
@@ -271,7 +305,7 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
             )
         return py_stage
 
-    def copy(self, extra=None):
+    def copy(self: "JP", extra: Optional["ParamMap"] = None) -> "JP":
         """
         Creates a copy of this instance with the same uid and some
         extra params. This implementation first calls Params.copy and
@@ -297,30 +331,32 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
             that._transfer_params_to_java()
         return that
 
-    def clear(self, param):
+    def clear(self, param: Param) -> None:
         """
         Clears a param from the param map if it has been explicitly set.
         """
+        assert self._java_obj is not None
+
         super(JavaParams, self).clear(param)
         java_param = self._java_obj.getParam(param.name)
         self._java_obj.clear(java_param)
 
 
 @inherit_doc
-class JavaEstimator(JavaParams, Estimator, metaclass=ABCMeta):
+class JavaEstimator(JavaParams, Estimator[JM], metaclass=ABCMeta):
     """
     Base class for :py:class:`Estimator`s that wrap Java/Scala
     implementations.
     """
 
     @abstractmethod
-    def _create_model(self, java_model):
+    def _create_model(self, java_model: "JavaObject") -> JM:
         """
         Creates a model from the input Java model reference.
         """
         raise NotImplementedError()
 
-    def _fit_java(self, dataset):
+    def _fit_java(self, dataset: DataFrame) -> "JavaObject":
         """
         Fits a Java model to the input dataset.
 
@@ -334,10 +370,12 @@ class JavaEstimator(JavaParams, Estimator, 
metaclass=ABCMeta):
         py4j.java_gateway.JavaObject
             fitted Java model
         """
+        assert self._java_obj is not None
+
         self._transfer_params_to_java()
         return self._java_obj.fit(dataset._jdf)
 
-    def _fit(self, dataset):
+    def _fit(self, dataset: DataFrame) -> JM:
         java_model = self._fit_java(dataset)
         model = self._create_model(java_model)
         return self._copyValues(model)
@@ -351,7 +389,9 @@ class JavaTransformer(JavaParams, Transformer, 
metaclass=ABCMeta):
     available as _java_obj.
     """
 
-    def _transform(self, dataset):
+    def _transform(self, dataset: DataFrame) -> DataFrame:
+        assert self._java_obj is not None
+
         self._transfer_params_to_java()
         return DataFrame(self._java_obj.transform(dataset._jdf), 
dataset.sql_ctx)
 
@@ -364,7 +404,7 @@ class JavaModel(JavaTransformer, Model, metaclass=ABCMeta):
     param mix-ins, because this sets the UID from the Java model.
     """
 
-    def __init__(self, java_model=None):
+    def __init__(self, java_model: Optional["JavaObject"] = None):
         """
         Initialize this instance with a Java model object.
         Subclasses should call this constructor, initialize params,
@@ -388,12 +428,12 @@ class JavaModel(JavaTransformer, Model, 
metaclass=ABCMeta):
 
             self._resetUid(java_model.uid())
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return self._call_java("toString")
 
 
 @inherit_doc
-class JavaPredictor(Predictor, JavaEstimator, _PredictorParams, 
metaclass=ABCMeta):
+class JavaPredictor(Predictor, JavaEstimator[JM], _PredictorParams, 
Generic[JM], metaclass=ABCMeta):
     """
     (Private) Java Estimator for prediction tasks (regression and 
classification).
     """
@@ -402,21 +442,21 @@ class JavaPredictor(Predictor, JavaEstimator, 
_PredictorParams, metaclass=ABCMet
 
 
 @inherit_doc
-class JavaPredictionModel(PredictionModel, JavaModel, _PredictorParams):
+class JavaPredictionModel(PredictionModel[T], JavaModel, _PredictorParams):
     """
     (Private) Java Model for prediction tasks (regression and classification).
     """
 
-    @property
+    @property  # type: ignore[misc]
     @since("2.1.0")
-    def numFeatures(self):
+    def numFeatures(self) -> int:
         """
         Returns the number of features the model was trained on. If unknown, 
returns -1
         """
         return self._call_java("numFeatures")
 
     @since("3.0.0")
-    def predict(self, value):
+    def predict(self, value: T) -> float:
         """
         Predict label for the given features.
         """
diff --git a/python/pyspark/ml/wrapper.pyi b/python/pyspark/ml/wrapper.pyi
deleted file mode 100644
index 7b3bfb4..0000000
--- a/python/pyspark/ml/wrapper.pyi
+++ /dev/null
@@ -1,51 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import abc
-from typing import Any, Optional, Generic
-from pyspark.ml._typing import P, T, JM, ParamMap
-
-from pyspark.ml import Estimator, Predictor, PredictionModel, Transformer, 
Model
-from pyspark.ml.base import _PredictorParams
-from pyspark.ml.param import Param, Params
-from pyspark.sql.dataframe import DataFrame
-
-class JavaWrapper:
-    def __init__(self, java_obj: Optional[Any] = ...) -> None: ...
-    def __del__(self) -> None: ...
-    def _call_java(self, name: str, *args: Any) -> Any: ...
-
-class JavaParams(JavaWrapper, Params, metaclass=abc.ABCMeta):
-    def copy(self: P, extra: Optional[ParamMap] = ...) -> P: ...
-    def clear(self, param: Param) -> None: ...
-
-class JavaEstimator(Generic[JM], JavaParams, Estimator[JM], 
metaclass=abc.ABCMeta):
-    def _fit(self, dataset: DataFrame) -> JM: ...
-
-class JavaTransformer(JavaParams, Transformer, metaclass=abc.ABCMeta):
-    def _transform(self, dataset: DataFrame) -> DataFrame: ...
-
-class JavaModel(JavaTransformer, Model, metaclass=abc.ABCMeta):
-    def __init__(self, java_model: Optional[Any] = ...) -> None: ...
-
-class JavaPredictor(Predictor[JM], JavaEstimator, _PredictorParams, 
metaclass=abc.ABCMeta): ...
-
-class JavaPredictionModel(PredictionModel[T], JavaModel, _PredictorParams):
-    @property
-    def numFeatures(self) -> int: ...
-    def predict(self, value: T) -> float: ...

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to