This is an automated email from the ASF dual-hosted git repository.
zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 711d0b4 [SPARK-37416][PYTHON][ML] Inline type hints for
pyspark.ml.wrapper
711d0b4 is described below
commit 711d0b4aac45f5bb69f91a774e94063a769d31d0
Author: zero323 <[email protected]>
AuthorDate: Mon Feb 7 00:02:33 2022 +0100
[SPARK-37416][PYTHON][ML] Inline type hints for pyspark.ml.wrapper
### What changes were proposed in this pull request?
This PR migrates type annotations for `pyspark.ml.wrapper` from a stub file
to inline type hints.
### Why are the changes needed?
Part of the ongoing migration of type hints.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
Closes #35399 from zero323/SPARK-37416.
Authored-by: zero323 <[email protected]>
Signed-off-by: zero323 <[email protected]>
---
python/pyspark/ml/classification.pyi | 16 +++++-
python/pyspark/ml/clustering.pyi | 6 ++
python/pyspark/ml/feature.pyi | 23 +++++++-
python/pyspark/ml/fpm.pyi | 3 +
python/pyspark/ml/recommendation.pyi | 3 +
python/pyspark/ml/regression.pyi | 14 ++++-
python/pyspark/ml/wrapper.py | 108 ++++++++++++++++++++++++-----------
python/pyspark/ml/wrapper.pyi | 51 -----------------
8 files changed, 133 insertions(+), 91 deletions(-)
diff --git a/python/pyspark/ml/classification.pyi
b/python/pyspark/ml/classification.pyi
index 4170a8c..89089a4 100644
--- a/python/pyspark/ml/classification.pyi
+++ b/python/pyspark/ml/classification.pyi
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, List, Optional, Type
+from typing import Any, Generic, List, Optional, Type
from pyspark.ml._typing import JM, M, P, T, ParamMap
import abc
@@ -69,6 +69,8 @@ from pyspark.ml.param import Param
from pyspark.ml.regression import DecisionTreeRegressionModel
from pyspark.sql.dataframe import DataFrame
+from py4j.java_gateway import JavaObject # type: ignore[import]
+
class _ClassifierParams(HasRawPredictionCol, _PredictorParams): ...
class Classifier(Predictor, _ClassifierParams, metaclass=abc.ABCMeta):
@@ -96,7 +98,7 @@ class ProbabilisticClassificationModel(
@abstractmethod
def predictProbability(self, value: Vector) -> Vector: ...
-class _JavaClassifier(Classifier, JavaPredictor[JM], metaclass=abc.ABCMeta):
+class _JavaClassifier(Classifier, JavaPredictor[JM], Generic[JM],
metaclass=abc.ABCMeta):
def setRawPredictionCol(self: P, value: str) -> P: ...
class _JavaClassificationModel(ClassificationModel, JavaPredictionModel[T]):
@@ -105,7 +107,7 @@ class _JavaClassificationModel(ClassificationModel,
JavaPredictionModel[T]):
def predictRaw(self, value: Vector) -> Vector: ...
class _JavaProbabilisticClassifier(
- ProbabilisticClassifier, _JavaClassifier[JM], metaclass=abc.ABCMeta
+ ProbabilisticClassifier, _JavaClassifier[JM], Generic[JM],
metaclass=abc.ABCMeta
): ...
class _JavaProbabilisticClassificationModel(
@@ -231,6 +233,7 @@ class LinearSVC(
def setWeightCol(self, value: str) -> LinearSVC: ...
def setAggregationDepth(self, value: int) -> LinearSVC: ...
def setMaxBlockSizeInMB(self, value: float) -> LinearSVC: ...
+ def _create_model(self, java_model: JavaObject) -> LinearSVCModel: ...
class LinearSVCModel(
_JavaClassificationModel[Vector],
@@ -350,6 +353,7 @@ class LogisticRegression(
def setWeightCol(self, value: str) -> LogisticRegression: ...
def setAggregationDepth(self, value: int) -> LogisticRegression: ...
def setMaxBlockSizeInMB(self, value: float) -> LogisticRegression: ...
+ def _create_model(self, java_model: JavaObject) ->
LogisticRegressionModel: ...
class LogisticRegressionModel(
_JavaProbabilisticClassificationModel[Vector],
@@ -444,6 +448,7 @@ class DecisionTreeClassifier(
def setCheckpointInterval(self, value: int) -> DecisionTreeClassifier: ...
def setSeed(self, value: int) -> DecisionTreeClassifier: ...
def setWeightCol(self, value: str) -> DecisionTreeClassifier: ...
+ def _create_model(self, java_model: JavaObject) ->
DecisionTreeClassificationModel: ...
class DecisionTreeClassificationModel(
_DecisionTreeModel,
@@ -529,6 +534,7 @@ class RandomForestClassifier(
def setCheckpointInterval(self, value: int) -> RandomForestClassifier: ...
def setWeightCol(self, value: str) -> RandomForestClassifier: ...
def setMinWeightFractionPerNode(self, value: float) ->
RandomForestClassifier: ...
+ def _create_model(self, java_model: JavaObject) ->
RandomForestClassificationModel: ...
class RandomForestClassificationModel(
_TreeEnsembleModel,
@@ -633,6 +639,7 @@ class GBTClassifier(
def setStepSize(self, value: float) -> GBTClassifier: ...
def setWeightCol(self, value: str) -> GBTClassifier: ...
def setMinWeightFractionPerNode(self, value: float) -> GBTClassifier: ...
+ def _create_model(self, java_model: JavaObject) -> GBTClassificationModel:
...
class GBTClassificationModel(
_TreeEnsembleModel,
@@ -691,6 +698,7 @@ class NaiveBayes(
def setSmoothing(self, value: float) -> NaiveBayes: ...
def setModelType(self, value: str) -> NaiveBayes: ...
def setWeightCol(self, value: str) -> NaiveBayes: ...
+ def _create_model(self, java_model: JavaObject) -> NaiveBayesModel: ...
class NaiveBayesModel(
_JavaProbabilisticClassificationModel[Vector],
@@ -769,6 +777,7 @@ class MultilayerPerceptronClassifier(
def setTol(self, value: float) -> MultilayerPerceptronClassifier: ...
def setStepSize(self, value: float) -> MultilayerPerceptronClassifier: ...
def setSolver(self, value: str) -> MultilayerPerceptronClassifier: ...
+ def _create_model(self, java_model: JavaObject) ->
MultilayerPerceptronClassificationModel: ...
class MultilayerPerceptronClassificationModel(
_JavaProbabilisticClassificationModel[Vector],
@@ -921,6 +930,7 @@ class FMClassifier(
def setSeed(self, value: int) -> FMClassifier: ...
def setFitIntercept(self, value: bool) -> FMClassifier: ...
def setRegParam(self, value: float) -> FMClassifier: ...
+ def _create_model(self, java_model: JavaObject) -> FMClassificationModel:
...
class FMClassificationModel(
_JavaProbabilisticClassificationModel[Vector],
diff --git a/python/pyspark/ml/clustering.pyi b/python/pyspark/ml/clustering.pyi
index 81074fc..e0ee3d6 100644
--- a/python/pyspark/ml/clustering.pyi
+++ b/python/pyspark/ml/clustering.pyi
@@ -45,6 +45,8 @@ from pyspark.sql.dataframe import DataFrame
from numpy import ndarray
+from py4j.java_gateway import JavaObject # type: ignore[import]
+
class ClusteringSummary(JavaWrapper):
@property
def predictionCol(self) -> str: ...
@@ -137,6 +139,7 @@ class GaussianMixture(
def setSeed(self, value: int) -> GaussianMixture: ...
def setTol(self, value: float) -> GaussianMixture: ...
def setAggregationDepth(self, value: int) -> GaussianMixture: ...
+ def _create_model(self, java_model: JavaObject) -> GaussianMixtureModel:
...
class GaussianMixtureSummary(ClusteringSummary):
@property
@@ -219,6 +222,7 @@ class KMeans(JavaEstimator[KMeansModel], _KMeansParams,
JavaMLWritable, JavaMLRe
def setSeed(self, value: int) -> KMeans: ...
def setTol(self, value: float) -> KMeans: ...
def setWeightCol(self, value: str) -> KMeans: ...
+ def _create_model(self, java_model: JavaObject) -> KMeansModel: ...
class _BisectingKMeansParams(
HasMaxIter,
@@ -287,6 +291,7 @@ class BisectingKMeans(
def setPredictionCol(self, value: str) -> BisectingKMeans: ...
def setSeed(self, value: int) -> BisectingKMeans: ...
def setWeightCol(self, value: str) -> BisectingKMeans: ...
+ def _create_model(self, java_model: JavaObject) -> BisectingKMeansModel:
...
class BisectingKMeansSummary(ClusteringSummary):
@property
@@ -386,6 +391,7 @@ class LDA(JavaEstimator[LDAModel], _LDAParams,
JavaMLReadable[LDA], JavaMLWritab
def setKeepLastCheckpoint(self, value: bool) -> LDA: ...
def setMaxIter(self, value: int) -> LDA: ...
def setFeaturesCol(self, value: str) -> LDA: ...
+ def _create_model(self, java_model: JavaObject) -> LDAModel: ...
class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol):
k: Param[int]
diff --git a/python/pyspark/ml/feature.pyi b/python/pyspark/ml/feature.pyi
index 6efc304..ecfd26e 100644
--- a/python/pyspark/ml/feature.pyi
+++ b/python/pyspark/ml/feature.pyi
@@ -42,6 +42,8 @@ from pyspark.ml.linalg import Vector, DenseVector, DenseMatrix
from pyspark.sql.dataframe import DataFrame
from pyspark.ml.param import Param
+from py4j.java_gateway import JavaObject # type: ignore[import]
+
class Binarizer(
JavaTransformer,
HasThreshold,
@@ -103,6 +105,7 @@ class _LSH(Generic[JM], JavaEstimator[JM], _LSHParams,
JavaMLReadable, JavaMLWri
def setNumHashTables(self: P, value: int) -> P: ...
def setInputCol(self: P, value: str) -> P: ...
def setOutputCol(self: P, value: str) -> P: ...
+ def _create_model(self, java_model: JavaObject) -> JM: ...
class _LSHModel(JavaModel, _LSHParams):
def setInputCol(self: P, value: str) -> P: ...
@@ -268,6 +271,7 @@ class CountVectorizer(
def setBinary(self, value: bool) -> CountVectorizer: ...
def setInputCol(self, value: str) -> CountVectorizer: ...
def setOutputCol(self, value: str) -> CountVectorizer: ...
+ def _create_model(self, java_model: JavaObject) -> CountVectorizerModel:
...
class CountVectorizerModel(JavaModel, JavaMLReadable[CountVectorizerModel],
JavaMLWritable):
def setInputCol(self, value: str) -> CountVectorizerModel: ...
@@ -412,6 +416,7 @@ class IDF(JavaEstimator[IDFModel], _IDFParams,
JavaMLReadable[IDF], JavaMLWritab
def setMinDocFreq(self, value: int) -> IDF: ...
def setInputCol(self, value: str) -> IDF: ...
def setOutputCol(self, value: str) -> IDF: ...
+ def _create_model(self, java_model: JavaObject) -> IDFModel: ...
class IDFModel(JavaModel, _IDFParams, JavaMLReadable[IDFModel],
JavaMLWritable):
def setInputCol(self, value: str) -> IDFModel: ...
@@ -477,6 +482,7 @@ class Imputer(JavaEstimator[ImputerModel], _ImputerParams,
JavaMLReadable[Impute
def setInputCol(self, value: str) -> Imputer: ...
def setOutputCol(self, value: str) -> Imputer: ...
def setRelativeError(self, value: float) -> Imputer: ...
+ def _create_model(self, java_model: JavaObject) -> ImputerModel: ...
class ImputerModel(JavaModel, _ImputerParams, JavaMLReadable[ImputerModel],
JavaMLWritable):
def setInputCols(self, value: List[str]) -> ImputerModel: ...
@@ -518,6 +524,7 @@ class MaxAbsScaler(
) -> MaxAbsScaler: ...
def setInputCol(self, value: str) -> MaxAbsScaler: ...
def setOutputCol(self, value: str) -> MaxAbsScaler: ...
+ def _create_model(self, java_model: JavaObject) -> MaxAbsScalerModel: ...
class MaxAbsScalerModel(
JavaModel, _MaxAbsScalerParams, JavaMLReadable[MaxAbsScalerModel],
JavaMLWritable
@@ -588,6 +595,7 @@ class MinMaxScaler(
def setMax(self, value: float) -> MinMaxScaler: ...
def setInputCol(self, value: str) -> MinMaxScaler: ...
def setOutputCol(self, value: str) -> MinMaxScaler: ...
+ def _create_model(self, java_model: JavaObject) -> MinMaxScalerModel: ...
class MinMaxScalerModel(
JavaModel, _MinMaxScalerParams, JavaMLReadable[MinMaxScalerModel],
JavaMLWritable
@@ -687,6 +695,7 @@ class OneHotEncoder(
def setHandleInvalid(self, value: str) -> OneHotEncoder: ...
def setInputCol(self, value: str) -> OneHotEncoder: ...
def setOutputCol(self, value: str) -> OneHotEncoder: ...
+ def _create_model(self, java_model: JavaObject) -> OneHotEncoderModel: ...
class OneHotEncoderModel(
JavaModel, _OneHotEncoderParams, JavaMLReadable[OneHotEncoderModel],
JavaMLWritable
@@ -783,6 +792,7 @@ class QuantileDiscretizer(
def setOutputCol(self, value: str) -> QuantileDiscretizer: ...
def setOutputCols(self, value: List[str]) -> QuantileDiscretizer: ...
def setHandleInvalid(self, value: str) -> QuantileDiscretizer: ...
+ def _create_model(self, java_model: JavaObject) -> Bucketizer: ...
class _RobustScalerParams(HasInputCol, HasOutputCol, HasRelativeError):
lower: Param[float]
@@ -827,6 +837,7 @@ class RobustScaler(
def setInputCol(self, value: str) -> RobustScaler: ...
def setOutputCol(self, value: str) -> RobustScaler: ...
def setRelativeError(self, value: float) -> RobustScaler: ...
+ def _create_model(self, java_model: JavaObject) -> RobustScalerModel: ...
class RobustScalerModel(
JavaModel, _RobustScalerParams, JavaMLReadable[RobustScalerModel],
JavaMLWritable
@@ -920,6 +931,7 @@ class StandardScaler(
def setWithStd(self, value: bool) -> StandardScaler: ...
def setInputCol(self, value: str) -> StandardScaler: ...
def setOutputCol(self, value: str) -> StandardScaler: ...
+ def _create_model(self, java_model: JavaObject) -> StandardScalerModel: ...
class StandardScalerModel(
JavaModel,
@@ -990,6 +1002,7 @@ class StringIndexer(
def setOutputCol(self, value: str) -> StringIndexer: ...
def setOutputCols(self, value: List[str]) -> StringIndexer: ...
def setHandleInvalid(self, value: str) -> StringIndexer: ...
+ def _create_model(self, java_model: JavaObject) -> StringIndexerModel: ...
class StringIndexerModel(
JavaModel, _StringIndexerParams, JavaMLReadable[StringIndexerModel],
JavaMLWritable
@@ -1186,6 +1199,7 @@ class VectorIndexer(
def setInputCol(self, value: str) -> VectorIndexer: ...
def setOutputCol(self, value: str) -> VectorIndexer: ...
def setHandleInvalid(self, value: str) -> VectorIndexer: ...
+ def _create_model(self, java_model: JavaObject) -> VectorIndexerModel: ...
class VectorIndexerModel(
JavaModel, _VectorIndexerParams, JavaMLReadable[VectorIndexerModel],
JavaMLWritable
@@ -1286,6 +1300,7 @@ class Word2Vec(
def setOutputCol(self, value: str) -> Word2Vec: ...
def setSeed(self, value: int) -> Word2Vec: ...
def setStepSize(self, value: float) -> Word2Vec: ...
+ def _create_model(self, java_model: JavaObject) -> Word2VecModel: ...
class Word2VecModel(JavaModel, _Word2VecParams, JavaMLReadable[Word2VecModel],
JavaMLWritable):
def getVectors(self) -> DataFrame: ...
@@ -1322,6 +1337,7 @@ class PCA(JavaEstimator[PCAModel], _PCAParams,
JavaMLReadable[PCA], JavaMLWritab
def setK(self, value: int) -> PCA: ...
def setInputCol(self, value: str) -> PCA: ...
def setOutputCol(self, value: str) -> PCA: ...
+ def _create_model(self, java_model: JavaObject) -> PCAModel: ...
class PCAModel(JavaModel, _PCAParams, JavaMLReadable[PCAModel],
JavaMLWritable):
def setInputCol(self, value: str) -> PCAModel: ...
@@ -1373,6 +1389,7 @@ class RFormula(
def setFeaturesCol(self, value: str) -> RFormula: ...
def setLabelCol(self, value: str) -> RFormula: ...
def setHandleInvalid(self, value: str) -> RFormula: ...
+ def _create_model(self, java_model: JavaObject) -> RFormulaModel: ...
class RFormulaModel(JavaModel, _RFormulaParams, JavaMLReadable[RFormulaModel],
JavaMLWritable): ...
@@ -1391,7 +1408,7 @@ class _SelectorParams(HasFeaturesCol, HasOutputCol,
HasLabelCol):
def getFdr(self) -> float: ...
def getFwe(self) -> float: ...
-class _Selector(JavaEstimator[JM], _SelectorParams, JavaMLReadable,
JavaMLWritable):
+class _Selector(JavaEstimator[JM], _SelectorParams, JavaMLReadable,
JavaMLWritable, Generic[JM]):
def setSelectorType(self: P, value: str) -> P: ...
def setNumTopFeatures(self: P, value: int) -> P: ...
def setPercentile(self: P, value: float) -> P: ...
@@ -1401,6 +1418,7 @@ class _Selector(JavaEstimator[JM], _SelectorParams,
JavaMLReadable, JavaMLWritab
def setFeaturesCol(self: P, value: str) -> P: ...
def setOutputCol(self: P, value: str) -> P: ...
def setLabelCol(self: P, value: str) -> P: ...
+ def _create_model(self, java_model: JavaObject) -> JM: ...
class _SelectorModel(JavaModel, _SelectorParams):
def setFeaturesCol(self: P, value: str) -> P: ...
@@ -1448,6 +1466,7 @@ class ChiSqSelector(
def setFeaturesCol(self, value: str) -> ChiSqSelector: ...
def setOutputCol(self, value: str) -> ChiSqSelector: ...
def setLabelCol(self, value: str) -> ChiSqSelector: ...
+ def _create_model(self, java_model: JavaObject) -> ChiSqSelectorModel: ...
class ChiSqSelectorModel(_SelectorModel, JavaMLReadable[ChiSqSelectorModel],
JavaMLWritable):
def setFeaturesCol(self, value: str) -> ChiSqSelectorModel: ...
@@ -1500,6 +1519,7 @@ class VarianceThresholdSelector(
def setVarianceThreshold(self, value: float) -> VarianceThresholdSelector:
...
def setFeaturesCol(self, value: str) -> VarianceThresholdSelector: ...
def setOutputCol(self, value: str) -> VarianceThresholdSelector: ...
+ def _create_model(self, java_model: JavaObject) ->
VarianceThresholdSelectorModel: ...
class VarianceThresholdSelectorModel(
JavaModel,
@@ -1552,6 +1572,7 @@ class UnivariateFeatureSelector(
def setFeaturesCol(self, value: str) -> UnivariateFeatureSelector: ...
def setOutputCol(self, value: str) -> UnivariateFeatureSelector: ...
def setLabelCol(self, value: str) -> UnivariateFeatureSelector: ...
+ def _create_model(self, java_model: JavaObject) ->
UnivariateFeatureSelectorModel: ...
class UnivariateFeatureSelectorModel(
JavaModel,
diff --git a/python/pyspark/ml/fpm.pyi b/python/pyspark/ml/fpm.pyi
index 609bc44..00d5c5f 100644
--- a/python/pyspark/ml/fpm.pyi
+++ b/python/pyspark/ml/fpm.pyi
@@ -25,6 +25,8 @@ from pyspark.sql.dataframe import DataFrame
from pyspark.ml.param import Param
+from py4j.java_gateway import JavaObject # type: ignore[import]
+
class _FPGrowthParams(HasPredictionCol):
itemsCol: Param[str]
minSupport: Param[float]
@@ -74,6 +76,7 @@ class FPGrowth(
def setNumPartitions(self, value: int) -> FPGrowth: ...
def setMinConfidence(self, value: float) -> FPGrowth: ...
def setPredictionCol(self, value: str) -> FPGrowth: ...
+ def _create_model(self, java_model: JavaObject) -> FPGrowthModel: ...
class PrefixSpan(JavaParams):
minSupport: Param[float]
diff --git a/python/pyspark/ml/recommendation.pyi
b/python/pyspark/ml/recommendation.pyi
index 6ce178b..f7faaca 100644
--- a/python/pyspark/ml/recommendation.pyi
+++ b/python/pyspark/ml/recommendation.pyi
@@ -36,6 +36,8 @@ from pyspark.ml.util import JavaMLWritable, JavaMLReadable
from pyspark.sql.dataframe import DataFrame
+from py4j.java_gateway import JavaObject # type: ignore[import]
+
class _ALSModelParams(HasPredictionCol, HasBlockSize):
userCol: Param[str]
itemCol: Param[str]
@@ -127,6 +129,7 @@ class ALS(JavaEstimator[ALSModel], _ALSParams,
JavaMLWritable, JavaMLReadable[AL
def setCheckpointInterval(self, value: int) -> ALS: ...
def setSeed(self, value: int) -> ALS: ...
def setBlockSize(self, value: int) -> ALS: ...
+ def _create_model(self, java_model: JavaObject) -> ALSModel: ...
class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable,
JavaMLReadable[ALSModel]):
def setUserCol(self, value: str) -> ALSModel: ...
diff --git a/python/pyspark/ml/regression.pyi b/python/pyspark/ml/regression.pyi
index 3b553b1..750e4c7 100644
--- a/python/pyspark/ml/regression.pyi
+++ b/python/pyspark/ml/regression.pyi
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, List, Optional
+from typing import Any, Generic, List, Optional
from pyspark.ml._typing import JM, M, T
import abc
@@ -67,9 +67,11 @@ from pyspark.ml.linalg import Matrix, Vector
from pyspark.ml.param import Param
from pyspark.sql.dataframe import DataFrame
+from py4j.java_gateway import JavaObject # type: ignore[import]
+
class Regressor(Predictor[M], _PredictorParams, metaclass=abc.ABCMeta): ...
class RegressionModel(PredictionModel[T], _PredictorParams,
metaclass=abc.ABCMeta): ...
-class _JavaRegressor(Regressor, JavaPredictor[JM], metaclass=abc.ABCMeta): ...
+class _JavaRegressor(Regressor, JavaPredictor[JM], Generic[JM],
metaclass=abc.ABCMeta): ...
class _JavaRegressionModel(RegressionModel, JavaPredictionModel[T],
metaclass=abc.ABCMeta): ...
class _LinearRegressionParams(
@@ -146,6 +148,7 @@ class LinearRegression(
def setAggregationDepth(self, value: int) -> LinearRegression: ...
def setLoss(self, value: str) -> LinearRegression: ...
def setMaxBlockSizeInMB(self, value: float) -> LinearRegression: ...
+ def _create_model(self, java_model: JavaObject) -> LinearRegressionModel:
...
class LinearRegressionModel(
_JavaRegressionModel[Vector],
@@ -241,6 +244,7 @@ class IsotonicRegression(
def setPredictionCol(self, value: str) -> IsotonicRegression: ...
def setLabelCol(self, value: str) -> IsotonicRegression: ...
def setWeightCol(self, value: str) -> IsotonicRegression: ...
+ def _create_model(self, java_model: JavaObject) ->
IsotonicRegressionModel: ...
class IsotonicRegressionModel(
JavaModel,
@@ -320,6 +324,7 @@ class DecisionTreeRegressor(
def setSeed(self, value: int) -> DecisionTreeRegressor: ...
def setWeightCol(self, value: str) -> DecisionTreeRegressor: ...
def setVarianceCol(self, value: str) -> DecisionTreeRegressor: ...
+ def _create_model(self, java_model: JavaObject) ->
DecisionTreeRegressionModel: ...
class DecisionTreeRegressionModel(
_JavaRegressionModel[Vector],
@@ -402,6 +407,7 @@ class RandomForestRegressor(
def setSeed(self, value: int) -> RandomForestRegressor: ...
def setWeightCol(self, value: str) -> RandomForestRegressor: ...
def setMinWeightFractionPerNode(self, value: float) ->
RandomForestRegressor: ...
+ def _create_model(self, java_model: JavaObject) ->
RandomForestRegressionModel: ...
class RandomForestRegressionModel(
_JavaRegressionModel[Vector],
@@ -496,6 +502,7 @@ class GBTRegressor(
def setStepSize(self, value: float) -> GBTRegressor: ...
def setWeightCol(self, value: str) -> GBTRegressor: ...
def setMinWeightFractionPerNode(self, value: float) -> GBTRegressor: ...
+ def _create_model(self, java_model: JavaObject) -> GBTRegressionModel: ...
class GBTRegressionModel(
_JavaRegressionModel[Vector],
@@ -570,6 +577,7 @@ class AFTSurvivalRegression(
def setFitIntercept(self, value: bool) -> AFTSurvivalRegression: ...
def setAggregationDepth(self, value: int) -> AFTSurvivalRegression: ...
def setMaxBlockSizeInMB(self, value: float) -> AFTSurvivalRegression: ...
+ def _create_model(self, java_model: JavaObject) ->
AFTSurvivalRegressionModel: ...
class AFTSurvivalRegressionModel(
_JavaRegressionModel[Vector],
@@ -672,6 +680,7 @@ class GeneralizedLinearRegression(
def setWeightCol(self, value: str) -> GeneralizedLinearRegression: ...
def setSolver(self, value: str) -> GeneralizedLinearRegression: ...
def setAggregationDepth(self, value: int) -> GeneralizedLinearRegression:
...
+ def _create_model(self, java_model: JavaObject) ->
GeneralizedLinearRegressionModel: ...
class GeneralizedLinearRegressionModel(
_JavaRegressionModel[Vector],
@@ -802,6 +811,7 @@ class FMRegressor(
def setSeed(self, value: int) -> FMRegressor: ...
def setFitIntercept(self, value: bool) -> FMRegressor: ...
def setRegParam(self, value: float) -> FMRegressor: ...
+ def _create_model(self, java_model: JavaObject) -> FMRegressionModel: ...
class FMRegressionModel(
_JavaRegressionModel,
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index c35df2e..7f03f64 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -17,49 +17,68 @@
from abc import ABCMeta, abstractmethod
+from typing import Any, Generic, Optional, List, Type, TypeVar, TYPE_CHECKING
+
from pyspark import since
from pyspark import SparkContext
from pyspark.sql import DataFrame
from pyspark.ml import Estimator, Predictor, PredictionModel, Transformer,
Model
from pyspark.ml.base import _PredictorParams
-from pyspark.ml.param import Params
-from pyspark.ml.util import _jvm
+from pyspark.ml.param import Param, Params
+from pyspark.ml.util import _jvm # type: ignore[attr-defined]
from pyspark.ml.common import inherit_doc, _java2py, _py2java
+if TYPE_CHECKING:
+ from pyspark.ml._typing import ParamMap
+ from py4j.java_gateway import JavaObject, JavaClass
+
+
+T = TypeVar("T")
+JW = TypeVar("JW", bound="JavaWrapper")
+JM = TypeVar("JM", bound="JavaTransformer")
+JP = TypeVar("JP", bound="JavaParams")
+
+
class JavaWrapper:
"""
Wrapper class for a Java companion object
"""
- def __init__(self, java_obj=None):
+ def __init__(self, java_obj: Optional["JavaObject"] = None):
super(JavaWrapper, self).__init__()
self._java_obj = java_obj
- def __del__(self):
+ def __del__(self) -> None:
if SparkContext._active_spark_context and self._java_obj is not None:
- SparkContext._active_spark_context._gateway.detach(self._java_obj)
+ SparkContext._active_spark_context._gateway.detach( # type:
ignore[union-attr]
+ self._java_obj
+ )
@classmethod
- def _create_from_java_class(cls, java_class, *args):
+ def _create_from_java_class(cls: Type[JW], java_class: str, *args: Any) ->
JW:
"""
Construct this object from given Java classname and arguments
"""
java_obj = JavaWrapper._new_java_obj(java_class, *args)
return cls(java_obj)
- def _call_java(self, name, *args):
+ def _call_java(self, name: str, *args: Any) -> Any:
m = getattr(self._java_obj, name)
sc = SparkContext._active_spark_context
+ assert sc is not None
+
java_args = [_py2java(sc, arg) for arg in args]
return _java2py(sc, m(*java_args))
@staticmethod
- def _new_java_obj(java_class, *args):
+ def _new_java_obj(java_class: str, *args: Any) -> "JavaObject":
"""
Returns a new Java object.
"""
sc = SparkContext._active_spark_context
+ assert sc is not None
+
java_obj = _jvm()
for name in java_class.split("."):
java_obj = getattr(java_obj, name)
@@ -67,7 +86,7 @@ class JavaWrapper:
return java_obj(*java_args)
@staticmethod
- def _new_java_array(pylist, java_class):
+ def _new_java_array(pylist: List[Any], java_class: "JavaClass") ->
"JavaObject":
"""
Create a Java array of given java_class type. Useful for
calling a method with a Scala Array from Python with Py4J.
@@ -97,6 +116,9 @@ class JavaWrapper:
Java Array of converted pylist.
"""
sc = SparkContext._active_spark_context
+ assert sc is not None
+ assert sc._gateway is not None
+
java_array = None
if len(pylist) > 0 and isinstance(pylist[0], list):
# If pylist is a 2D array, then a 2D java array will be created.
@@ -125,20 +147,24 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
#: The param values in the Java object should be
#: synced with the Python wrapper in fit/transform/evaluate/copy.
- def _make_java_param_pair(self, param, value):
+ def _make_java_param_pair(self, param: Param[T], value: T) -> "JavaObject":
"""
Makes a Java param pair.
"""
sc = SparkContext._active_spark_context
+ assert sc is not None and self._java_obj is not None
+
param = self._resolveParam(param)
java_param = self._java_obj.getParam(param.name)
java_value = _py2java(sc, value)
return java_param.w(java_value)
- def _transfer_params_to_java(self):
+ def _transfer_params_to_java(self) -> None:
"""
Transforms the embedded params to the companion Java object.
"""
+ assert self._java_obj is not None
+
pair_defaults = []
for param in self.params:
if self.isSet(param):
@@ -149,10 +175,12 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
pair_defaults.append(pair)
if len(pair_defaults) > 0:
sc = SparkContext._active_spark_context
+ assert sc is not None and sc._jvm is not None
+
pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults)
self._java_obj.setDefault(pair_defaults_seq)
- def _transfer_param_map_to_java(self, pyParamMap):
+ def _transfer_param_map_to_java(self, pyParamMap: "ParamMap") ->
"JavaObject":
"""
Transforms a Python ParamMap into a Java ParamMap.
"""
@@ -163,26 +191,30 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
paramMap.put([pair])
return paramMap
- def _create_params_from_java(self):
+ def _create_params_from_java(self) -> None:
"""
SPARK-10931: Temporary fix to create params that are defined in the
Java obj but not here
"""
+ assert self._java_obj is not None
+
java_params = list(self._java_obj.params())
from pyspark.ml.param import Param
for java_param in java_params:
java_param_name = java_param.name()
if not hasattr(self, java_param_name):
- param = Param(self, java_param_name, java_param.doc())
+ param: Param[Any] = Param(self, java_param_name,
java_param.doc())
setattr(param, "created_from_java_param", True)
setattr(self, java_param_name, param)
self._params = None # need to reset so self.params will
discover new params
- def _transfer_params_from_java(self):
+ def _transfer_params_from_java(self) -> None:
"""
Transforms the embedded params from the companion Java object.
"""
sc = SparkContext._active_spark_context
+ assert sc is not None and self._java_obj is not None
+
for param in self.params:
if self._java_obj.hasParam(param.name):
java_param = self._java_obj.getParam(param.name)
@@ -195,11 +227,13 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
value = _java2py(sc,
self._java_obj.getDefault(java_param)).get()
self._setDefault(**{param.name: value})
- def _transfer_param_map_from_java(self, javaParamMap):
+ def _transfer_param_map_from_java(self, javaParamMap: "JavaObject") ->
"ParamMap":
"""
Transforms a Java ParamMap into a Python ParamMap.
"""
sc = SparkContext._active_spark_context
+ assert sc is not None
+
paramMap = dict()
for pair in javaParamMap.toList():
param = pair.param()
@@ -208,13 +242,13 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
return paramMap
@staticmethod
- def _empty_java_param_map():
+ def _empty_java_param_map() -> "JavaObject":
"""
Returns an empty Java ParamMap reference.
"""
return _jvm().org.apache.spark.ml.param.ParamMap()
- def _to_java(self):
+ def _to_java(self) -> "JavaObject":
"""
Transfer this instance's Params to the wrapped Java object, and return
the Java object.
Used for ML persistence.
@@ -230,7 +264,7 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
return self._java_obj
@staticmethod
- def _from_java(java_stage):
+ def _from_java(java_stage: "JavaObject") -> "JP":
"""
Given a Java object, create and return a Python wrapper of it.
Used for ML persistence.
@@ -238,7 +272,7 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
Meta-algorithms such as Pipeline should override this method as a
classmethod.
"""
- def __get_class(clazz):
+ def __get_class(clazz: str) -> Type[JP]:
"""
Loads Python class from its name.
"""
@@ -271,7 +305,7 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
)
return py_stage
- def copy(self, extra=None):
+ def copy(self: "JP", extra: Optional["ParamMap"] = None) -> "JP":
"""
Creates a copy of this instance with the same uid and some
extra params. This implementation first calls Params.copy and
@@ -297,30 +331,32 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
that._transfer_params_to_java()
return that
- def clear(self, param):
+ def clear(self, param: Param) -> None:
"""
Clears a param from the param map if it has been explicitly set.
"""
+ assert self._java_obj is not None
+
super(JavaParams, self).clear(param)
java_param = self._java_obj.getParam(param.name)
self._java_obj.clear(java_param)
@inherit_doc
-class JavaEstimator(JavaParams, Estimator, metaclass=ABCMeta):
+class JavaEstimator(JavaParams, Estimator[JM], metaclass=ABCMeta):
"""
Base class for :py:class:`Estimator`s that wrap Java/Scala
implementations.
"""
@abstractmethod
- def _create_model(self, java_model):
+ def _create_model(self, java_model: "JavaObject") -> JM:
"""
Creates a model from the input Java model reference.
"""
raise NotImplementedError()
- def _fit_java(self, dataset):
+ def _fit_java(self, dataset: DataFrame) -> "JavaObject":
"""
Fits a Java model to the input dataset.
@@ -334,10 +370,12 @@ class JavaEstimator(JavaParams, Estimator,
metaclass=ABCMeta):
py4j.java_gateway.JavaObject
fitted Java model
"""
+ assert self._java_obj is not None
+
self._transfer_params_to_java()
return self._java_obj.fit(dataset._jdf)
- def _fit(self, dataset):
+ def _fit(self, dataset: DataFrame) -> JM:
java_model = self._fit_java(dataset)
model = self._create_model(java_model)
return self._copyValues(model)
@@ -351,7 +389,9 @@ class JavaTransformer(JavaParams, Transformer,
metaclass=ABCMeta):
available as _java_obj.
"""
- def _transform(self, dataset):
+ def _transform(self, dataset: DataFrame) -> DataFrame:
+ assert self._java_obj is not None
+
self._transfer_params_to_java()
return DataFrame(self._java_obj.transform(dataset._jdf),
dataset.sql_ctx)
@@ -364,7 +404,7 @@ class JavaModel(JavaTransformer, Model, metaclass=ABCMeta):
param mix-ins, because this sets the UID from the Java model.
"""
- def __init__(self, java_model=None):
+ def __init__(self, java_model: Optional["JavaObject"] = None):
"""
Initialize this instance with a Java model object.
Subclasses should call this constructor, initialize params,
@@ -388,12 +428,12 @@ class JavaModel(JavaTransformer, Model,
metaclass=ABCMeta):
self._resetUid(java_model.uid())
- def __repr__(self):
+ def __repr__(self) -> str:
return self._call_java("toString")
@inherit_doc
-class JavaPredictor(Predictor, JavaEstimator, _PredictorParams,
metaclass=ABCMeta):
+class JavaPredictor(Predictor, JavaEstimator[JM], _PredictorParams,
Generic[JM], metaclass=ABCMeta):
"""
(Private) Java Estimator for prediction tasks (regression and
classification).
"""
@@ -402,21 +442,21 @@ class JavaPredictor(Predictor, JavaEstimator,
_PredictorParams, metaclass=ABCMet
@inherit_doc
-class JavaPredictionModel(PredictionModel, JavaModel, _PredictorParams):
+class JavaPredictionModel(PredictionModel[T], JavaModel, _PredictorParams):
"""
(Private) Java Model for prediction tasks (regression and classification).
"""
- @property
+ @property # type: ignore[misc]
@since("2.1.0")
- def numFeatures(self):
+ def numFeatures(self) -> int:
"""
Returns the number of features the model was trained on. If unknown,
returns -1
"""
return self._call_java("numFeatures")
@since("3.0.0")
- def predict(self, value):
+ def predict(self, value: T) -> float:
"""
Predict label for the given features.
"""
diff --git a/python/pyspark/ml/wrapper.pyi b/python/pyspark/ml/wrapper.pyi
deleted file mode 100644
index 7b3bfb4..0000000
--- a/python/pyspark/ml/wrapper.pyi
+++ /dev/null
@@ -1,51 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import abc
-from typing import Any, Optional, Generic
-from pyspark.ml._typing import P, T, JM, ParamMap
-
-from pyspark.ml import Estimator, Predictor, PredictionModel, Transformer,
Model
-from pyspark.ml.base import _PredictorParams
-from pyspark.ml.param import Param, Params
-from pyspark.sql.dataframe import DataFrame
-
-class JavaWrapper:
- def __init__(self, java_obj: Optional[Any] = ...) -> None: ...
- def __del__(self) -> None: ...
- def _call_java(self, name: str, *args: Any) -> Any: ...
-
-class JavaParams(JavaWrapper, Params, metaclass=abc.ABCMeta):
- def copy(self: P, extra: Optional[ParamMap] = ...) -> P: ...
- def clear(self, param: Param) -> None: ...
-
-class JavaEstimator(Generic[JM], JavaParams, Estimator[JM],
metaclass=abc.ABCMeta):
- def _fit(self, dataset: DataFrame) -> JM: ...
-
-class JavaTransformer(JavaParams, Transformer, metaclass=abc.ABCMeta):
- def _transform(self, dataset: DataFrame) -> DataFrame: ...
-
-class JavaModel(JavaTransformer, Model, metaclass=abc.ABCMeta):
- def __init__(self, java_model: Optional[Any] = ...) -> None: ...
-
-class JavaPredictor(Predictor[JM], JavaEstimator, _PredictorParams,
metaclass=abc.ABCMeta): ...
-
-class JavaPredictionModel(PredictionModel[T], JavaModel, _PredictorParams):
- @property
- def numFeatures(self) -> int: ...
- def predict(self, value: T) -> float: ...
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]