This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8559e94  [SPARK-37397][PYTHON] Inline annotations for pyspark.ml.base
8559e94 is described below

commit 8559e94389fb9b2a3f453f568b3e11c79d2b4e2c
Author: zero323 <[email protected]>
AuthorDate: Tue Feb 1 23:22:36 2022 -0800

    [SPARK-37397][PYTHON] Inline annotations for pyspark.ml.base
    
    ### What changes were proposed in this pull request?
    
    Migration of type annotations for `pyspark.ml.base` from stub file to inline hints.
    
    ### Why are the changes needed?
    
    As part of the ongoing type hints migration.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests + new typing data tests (`python/pyspark/ml/tests/typing/*.yml`).
    
    Closes #35289 from zero323/SPARK-37397.
    
    Authored-by: zero323 <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/ml/base.py                          | 105 ++++++++++++++-------
 python/pyspark/ml/base.pyi                         | 103 --------------------
 python/pyspark/ml/classification.pyi               |   2 +
 python/pyspark/ml/param/__init__.py                |   9 +-
 python/pyspark/ml/pipeline.pyi                     |   3 +
 python/pyspark/ml/tests/typing/test_feature.yml    |  13 ++-
 python/pyspark/ml/tests/typing/test_regression.yml |  15 +++
 python/pyspark/ml/tuning.pyi                       |   5 +
 python/pyspark/ml/wrapper.pyi                      |  10 +-
 9 files changed, 122 insertions(+), 143 deletions(-)

diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index d984209..4f4ddef 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -20,7 +20,25 @@ from abc import ABCMeta, abstractmethod
 import copy
 import threading
 
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+    TYPE_CHECKING,
+)
+
 from pyspark import since
+from pyspark.ml.param import P
 from pyspark.ml.common import inherit_doc
 from pyspark.ml.param.shared import (
     HasInputCol,
@@ -30,11 +48,18 @@ from pyspark.ml.param.shared import (
     HasPredictionCol,
     Params,
 )
+from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.functions import udf
-from pyspark.sql.types import StructField, StructType
+from pyspark.sql.types import DataType, StructField, StructType
+
+if TYPE_CHECKING:
+    from pyspark.ml._typing import ParamMap
 
+T = TypeVar("T")
+M = TypeVar("M", bound="Transformer")
 
-class _FitMultipleIterator:
+
+class _FitMultipleIterator(Generic[M]):
     """
     Used by default implementation of Estimator.fitMultiple to produce models 
in a thread safe
     iterator. This class handles the simple case of fitMultiple where each 
param map should be
@@ -55,17 +80,17 @@ class _FitMultipleIterator:
     See :py:meth:`Estimator.fitMultiple` for more info.
     """
 
-    def __init__(self, fitSingleModel, numModels):
+    def __init__(self, fitSingleModel: Callable[[int], M], numModels: int):
         """ """
         self.fitSingleModel = fitSingleModel
         self.numModel = numModels
         self.counter = 0
         self.lock = threading.Lock()
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Tuple[int, M]]:
         return self
 
-    def __next__(self):
+    def __next__(self) -> Tuple[int, M]:
         with self.lock:
             index = self.counter
             if index >= self.numModel:
@@ -73,13 +98,13 @@ class _FitMultipleIterator:
             self.counter += 1
         return index, self.fitSingleModel(index)
 
-    def next(self):
+    def next(self) -> Tuple[int, M]:
         """For python2 compatibility."""
         return self.__next__()
 
 
 @inherit_doc
-class Estimator(Params, metaclass=ABCMeta):
+class Estimator(Generic[M], Params, metaclass=ABCMeta):
     """
     Abstract class for estimators that fit models to data.
 
@@ -89,7 +114,7 @@ class Estimator(Params, metaclass=ABCMeta):
     pass
 
     @abstractmethod
-    def _fit(self, dataset):
+    def _fit(self, dataset: DataFrame) -> M:
         """
         Fits a model to the input dataset. This is called by the default 
implementation of fit.
 
@@ -106,7 +131,9 @@ class Estimator(Params, metaclass=ABCMeta):
         """
         raise NotImplementedError()
 
-    def fitMultiple(self, dataset, paramMaps):
+    def fitMultiple(
+        self, dataset: DataFrame, paramMaps: Sequence["ParamMap"]
+    ) -> Iterable[Tuple[int, M]]:
         """
         Fits a model to the input dataset for each param map in `paramMaps`.
 
@@ -128,12 +155,26 @@ class Estimator(Params, metaclass=ABCMeta):
         """
         estimator = self.copy()
 
-        def fitSingleModel(index):
+        def fitSingleModel(index: int) -> M:
             return estimator.fit(dataset, paramMaps[index])
 
         return _FitMultipleIterator(fitSingleModel, len(paramMaps))
 
-    def fit(self, dataset, params=None):
+    @overload
+    def fit(self, dataset: DataFrame, params: Optional["ParamMap"] = ...) -> M:
+        ...
+
+    @overload
+    def fit(
+        self, dataset: DataFrame, params: Union[List["ParamMap"], Tuple["ParamMap", ...]]
+    ) -> List[M]:
+        ...
+
+    def fit(
+        self,
+        dataset: DataFrame,
+        params: Optional[Union["ParamMap", List["ParamMap"], Tuple["ParamMap", ...]]] = None,
+    ) -> Union[M, List[M]]:
         """
         Fits a model to the input dataset with optional parameters.
 
@@ -156,10 +197,10 @@ class Estimator(Params, metaclass=ABCMeta):
         if params is None:
             params = dict()
         if isinstance(params, (list, tuple)):
-            models = [None] * len(params)
+            models: List[Optional[M]] = [None] * len(params)
             for index, model in self.fitMultiple(dataset, params):
                 models[index] = model
-            return models
+            return cast(List[M], models)
         elif isinstance(params, dict):
             if params:
                 return self.copy(params)._fit(dataset)
@@ -183,7 +224,7 @@ class Transformer(Params, metaclass=ABCMeta):
     pass
 
     @abstractmethod
-    def _transform(self, dataset):
+    def _transform(self, dataset: DataFrame) -> DataFrame:
         """
         Transforms the input dataset.
 
@@ -199,7 +240,7 @@ class Transformer(Params, metaclass=ABCMeta):
         """
         raise NotImplementedError()
 
-    def transform(self, dataset, params=None):
+    def transform(self, dataset: DataFrame, params: Optional["ParamMap"] = 
None) -> DataFrame:
         """
         Transforms the input dataset with optional parameters.
 
@@ -248,20 +289,20 @@ class UnaryTransformer(HasInputCol, HasOutputCol, 
Transformer):
     .. versionadded:: 2.3.0
     """
 
-    def setInputCol(self, value):
+    def setInputCol(self: P, value: str) -> P:
         """
         Sets the value of :py:attr:`inputCol`.
         """
         return self._set(inputCol=value)
 
-    def setOutputCol(self, value):
+    def setOutputCol(self: P, value: str) -> P:
         """
         Sets the value of :py:attr:`outputCol`.
         """
         return self._set(outputCol=value)
 
     @abstractmethod
-    def createTransformFunc(self):
+    def createTransformFunc(self) -> Callable[..., Any]:
         """
         Creates the transform function using the given param map. The input 
param map already takes
         account of the embedded param map. So the param values should be 
determined
@@ -270,20 +311,20 @@ class UnaryTransformer(HasInputCol, HasOutputCol, 
Transformer):
         raise NotImplementedError()
 
     @abstractmethod
-    def outputDataType(self):
+    def outputDataType(self) -> DataType:
         """
         Returns the data type of the output column.
         """
         raise NotImplementedError()
 
     @abstractmethod
-    def validateInputType(self, inputType):
+    def validateInputType(self, inputType: DataType) -> None:
         """
         Validates the input type. Throw an exception if it is invalid.
         """
         raise NotImplementedError()
 
-    def transformSchema(self, schema):
+    def transformSchema(self, schema: StructType) -> StructType:
         inputType = schema[self.getInputCol()].dataType
         self.validateInputType(inputType)
         if self.getOutputCol() in schema.names:
@@ -292,7 +333,7 @@ class UnaryTransformer(HasInputCol, HasOutputCol, 
Transformer):
         outputFields.append(StructField(self.getOutputCol(), 
self.outputDataType(), nullable=False))
         return StructType(outputFields)
 
-    def _transform(self, dataset):
+    def _transform(self, dataset: DataFrame) -> DataFrame:
         self.transformSchema(dataset.schema)
         transformUDF = udf(self.createTransformFunc(), self.outputDataType())
         transformedDataset = dataset.withColumn(
@@ -313,27 +354,27 @@ class _PredictorParams(HasLabelCol, HasFeaturesCol, 
HasPredictionCol):
 
 
 @inherit_doc
-class Predictor(Estimator, _PredictorParams, metaclass=ABCMeta):
+class Predictor(Estimator[M], _PredictorParams, metaclass=ABCMeta):
     """
     Estimator for prediction tasks (regression and classification).
     """
 
     @since("3.0.0")
-    def setLabelCol(self, value):
+    def setLabelCol(self: P, value: str) -> P:
         """
         Sets the value of :py:attr:`labelCol`.
         """
         return self._set(labelCol=value)
 
     @since("3.0.0")
-    def setFeaturesCol(self, value):
+    def setFeaturesCol(self: P, value: str) -> P:
         """
         Sets the value of :py:attr:`featuresCol`.
         """
         return self._set(featuresCol=value)
 
     @since("3.0.0")
-    def setPredictionCol(self, value):
+    def setPredictionCol(self: P, value: str) -> P:
         """
         Sets the value of :py:attr:`predictionCol`.
         """
@@ -341,29 +382,29 @@ class Predictor(Estimator, _PredictorParams, 
metaclass=ABCMeta):
 
 
 @inherit_doc
-class PredictionModel(Model, _PredictorParams, metaclass=ABCMeta):
+class PredictionModel(Generic[T], Model, _PredictorParams, metaclass=ABCMeta):
     """
     Model for prediction tasks (regression and classification).
     """
 
     @since("3.0.0")
-    def setFeaturesCol(self, value):
+    def setFeaturesCol(self: P, value: str) -> P:
         """
         Sets the value of :py:attr:`featuresCol`.
         """
         return self._set(featuresCol=value)
 
     @since("3.0.0")
-    def setPredictionCol(self, value):
+    def setPredictionCol(self: P, value: str) -> P:
         """
         Sets the value of :py:attr:`predictionCol`.
         """
         return self._set(predictionCol=value)
 
-    @property
+    @property  # type: ignore[misc]
     @abstractmethod
     @since("2.1.0")
-    def numFeatures(self):
+    def numFeatures(self) -> int:
         """
         Returns the number of features the model was trained on. If unknown, 
returns -1
         """
@@ -371,7 +412,7 @@ class PredictionModel(Model, _PredictorParams, 
metaclass=ABCMeta):
 
     @abstractmethod
     @since("3.0.0")
-    def predict(self, value):
+    def predict(self, value: T) -> float:
         """
         Predict label for the given features.
         """
diff --git a/python/pyspark/ml/base.pyi b/python/pyspark/ml/base.pyi
deleted file mode 100644
index 37ae6de..0000000
--- a/python/pyspark/ml/base.pyi
+++ /dev/null
@@ -1,103 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import (
-    Callable,
-    Generic,
-    Iterable,
-    List,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-)
-from pyspark.ml._typing import M, P, T, ParamMap
-
-import _thread
-
-import abc
-from abc import abstractmethod
-from pyspark import since as since  # noqa: F401
-from pyspark.ml.common import inherit_doc as inherit_doc  # noqa: F401
-from pyspark.ml.param.shared import (
-    HasFeaturesCol as HasFeaturesCol,
-    HasInputCol as HasInputCol,
-    HasLabelCol as HasLabelCol,
-    HasOutputCol as HasOutputCol,
-    HasPredictionCol as HasPredictionCol,
-    Params as Params,
-)
-from pyspark.sql.functions import udf as udf  # noqa: F401
-from pyspark.sql.types import (  # noqa: F401
-    DataType,
-    StructField as StructField,
-    StructType as StructType,
-)
-
-from pyspark.sql.dataframe import DataFrame
-
-class _FitMultipleIterator:
-    fitSingleModel: Callable[[int], Transformer]
-    numModel: int
-    counter: int = ...
-    lock: _thread.LockType
-    def __init__(self, fitSingleModel: Callable[[int], Transformer], 
numModels: int) -> None: ...
-    def __iter__(self) -> _FitMultipleIterator: ...
-    def __next__(self) -> Tuple[int, Transformer]: ...
-    def next(self) -> Tuple[int, Transformer]: ...
-
-class Estimator(Generic[M], Params, metaclass=abc.ABCMeta):
-    @overload
-    def fit(self, dataset: DataFrame, params: Optional[ParamMap] = ...) -> M: 
...
-    @overload
-    def fit(
-        self, dataset: DataFrame, params: Union[List[ParamMap], 
Tuple[ParamMap]]
-    ) -> List[M]: ...
-    def fitMultiple(
-        self, dataset: DataFrame, params: Sequence[ParamMap]
-    ) -> Iterable[Tuple[int, M]]: ...
-
-class Transformer(Params, metaclass=abc.ABCMeta):
-    def transform(self, dataset: DataFrame, params: Optional[ParamMap] = ...) 
-> DataFrame: ...
-
-class Model(Transformer, metaclass=abc.ABCMeta): ...
-
-class UnaryTransformer(HasInputCol, HasOutputCol, Transformer, 
metaclass=abc.ABCMeta):
-    def createTransformFunc(self) -> Callable: ...
-    def outputDataType(self) -> DataType: ...
-    def validateInputType(self, inputType: DataType) -> None: ...
-    def transformSchema(self, schema: StructType) -> StructType: ...
-    def setInputCol(self: M, value: str) -> M: ...
-    def setOutputCol(self: M, value: str) -> M: ...
-
-class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol): ...
-
-class Predictor(Estimator[M], _PredictorParams, metaclass=abc.ABCMeta):
-    def setLabelCol(self: P, value: str) -> P: ...
-    def setFeaturesCol(self: P, value: str) -> P: ...
-    def setPredictionCol(self: P, value: str) -> P: ...
-
-class PredictionModel(Generic[T], Model, _PredictorParams, 
metaclass=abc.ABCMeta):
-    def setFeaturesCol(self: M, value: str) -> M: ...
-    def setPredictionCol(self: M, value: str) -> M: ...
-    @property
-    @abc.abstractmethod
-    def numFeatures(self) -> int: ...
-    @abstractmethod
-    def predict(self, value: T) -> float: ...
diff --git a/python/pyspark/ml/classification.pyi 
b/python/pyspark/ml/classification.pyi
index bb4fb05..4170a8c 100644
--- a/python/pyspark/ml/classification.pyi
+++ b/python/pyspark/ml/classification.pyi
@@ -820,6 +820,7 @@ class OneVsRest(
         weightCol: Optional[str] = ...,
         parallelism: int = ...,
     ) -> OneVsRest: ...
+    def _fit(self, dataset: DataFrame) -> OneVsRestModel: ...
     def setClassifier(self, value: Estimator[M]) -> OneVsRest: ...
     def setLabelCol(self, value: str) -> OneVsRest: ...
     def setFeaturesCol(self, value: str) -> OneVsRest: ...
@@ -832,6 +833,7 @@ class OneVsRest(
 class OneVsRestModel(Model, _OneVsRestParams, MLReadable[OneVsRestModel], 
MLWritable):
     models: List[Transformer]
     def __init__(self, models: List[Transformer]) -> None: ...
+    def _transform(self, dataset: DataFrame) -> DataFrame: ...
     def setFeaturesCol(self, value: str) -> OneVsRestModel: ...
     def setPredictionCol(self, value: str) -> OneVsRestModel: ...
     def setRawPredictionCol(self, value: str) -> OneVsRestModel: ...
diff --git a/python/pyspark/ml/param/__init__.py 
b/python/pyspark/ml/param/__init__.py
index 092f79f..fd5ed63 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -43,6 +43,7 @@ if TYPE_CHECKING:
 __all__ = ["Param", "Params", "TypeConverters"]
 
 T = TypeVar("T")
+P = TypeVar("P", bound="Params")
 
 
 class Param(Generic[T]):
@@ -409,7 +410,7 @@ class Params(Identifiable, metaclass=ABCMeta):
         paramMap.update(extra)
         return paramMap
 
-    def copy(self, extra: Optional["ParamMap"] = None) -> "Params":
+    def copy(self: P, extra: Optional["ParamMap"] = None) -> P:
         """
         Creates a copy of this instance with the same uid and some
         extra params. The default implementation creates a
@@ -492,7 +493,7 @@ class Params(Identifiable, metaclass=ABCMeta):
         dummy.uid = "undefined"
         return dummy
 
-    def _set(self, **kwargs: Any) -> "Params":
+    def _set(self: P, **kwargs: Any) -> P:
         """
         Sets user-supplied params.
         """
@@ -513,7 +514,7 @@ class Params(Identifiable, metaclass=ABCMeta):
         if self.isSet(param):
             del self._paramMap[param]
 
-    def _setDefault(self, **kwargs: Any) -> "Params":
+    def _setDefault(self: P, **kwargs: Any) -> P:
         """
         Sets default params.
         """
@@ -529,7 +530,7 @@ class Params(Identifiable, metaclass=ABCMeta):
             self._defaultParamMap[p] = value
         return self
 
-    def _copyValues(self, to: "Params", extra: Optional["ParamMap"] = None) -> 
"Params":
+    def _copyValues(self, to: P, extra: Optional["ParamMap"] = None) -> P:
         """
         Copies param values from this instance to another instance for
         params shared by them.
diff --git a/python/pyspark/ml/pipeline.pyi b/python/pyspark/ml/pipeline.pyi
index f55b1e3..7b38900 100644
--- a/python/pyspark/ml/pipeline.pyi
+++ b/python/pyspark/ml/pipeline.pyi
@@ -33,10 +33,12 @@ from pyspark.ml.util import (  # noqa: F401
     MLWritable as MLWritable,
     MLWriter as MLWriter,
 )
+from pyspark.sql.dataframe import DataFrame
 
 class Pipeline(Estimator[PipelineModel], MLReadable[Pipeline], MLWritable):
     stages: List[PipelineStage]
     def __init__(self, *, stages: Optional[List[PipelineStage]] = ...) -> 
None: ...
+    def _fit(self, dataset: DataFrame) -> PipelineModel: ...
     def setStages(self, stages: List[PipelineStage]) -> Pipeline: ...
     def getStages(self) -> List[PipelineStage]: ...
     def setParams(self, *, stages: Optional[List[PipelineStage]] = ...) -> 
Pipeline: ...
@@ -69,6 +71,7 @@ class PipelineModelReader(MLReader[PipelineModel]):
 class PipelineModel(Model, MLReadable[PipelineModel], MLWritable):
     stages: List[PipelineStage]
     def __init__(self, stages: List[Transformer]) -> None: ...
+    def _transform(self, dataset: DataFrame) -> DataFrame: ...
     def copy(self, extra: Optional[Dict[Param, Any]] = ...) -> PipelineModel: 
...
     def write(self) -> JavaMLWriter: ...
     def save(self, path: str) -> None: ...
diff --git a/python/pyspark/ml/tests/typing/test_feature.yml 
b/python/pyspark/ml/tests/typing/test_feature.yml
index 3d6b090..0d1034a 100644
--- a/python/pyspark/ml/tests/typing/test_feature.yml
+++ b/python/pyspark/ml/tests/typing/test_feature.yml
@@ -15,6 +15,17 @@
 # limitations under the License.
 #
 
+
+- case: featureMethodChaining
+  main: |
+    from pyspark.ml.feature import NGram
+
+    reveal_type(NGram().setInputCol("foo").setOutputCol("bar"))
+
+  out: |
+    main:3: note: Revealed type is "pyspark.ml.feature.NGram"
+
+
 - case: stringIndexerOverloads
   main: |
     from pyspark.ml.feature import StringIndexer
@@ -41,4 +52,4 @@
     main:15: error: No overload variant of "StringIndexer" matches argument 
types "List[str]", "str"  [call-overload]
     main:15: note: Possible overload variants:
     main:15: note:     def StringIndexer(self, *, inputCol: Optional[str] = 
..., outputCol: Optional[str] = ..., handleInvalid: str = ..., stringOrderType: 
str = ...) -> StringIndexer
-    main:15: note:     def StringIndexer(self, *, inputCols: 
Optional[List[str]] = ..., outputCols: Optional[List[str]] = ..., 
handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer
\ No newline at end of file
+    main:15: note:     def StringIndexer(self, *, inputCols: 
Optional[List[str]] = ..., outputCols: Optional[List[str]] = ..., 
handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer
diff --git a/python/pyspark/ml/tests/typing/test_regression.yml 
b/python/pyspark/ml/tests/typing/test_regression.yml
index b045bec..4a54a56 100644
--- a/python/pyspark/ml/tests/typing/test_regression.yml
+++ b/python/pyspark/ml/tests/typing/test_regression.yml
@@ -15,6 +15,21 @@
 # limitations under the License.
 #
 
+- case: linearRegressionMethodChaining
+  main: |
+    from pyspark.ml.regression import LinearRegression, LinearRegressionModel
+
+    lr = LinearRegression()
+    reveal_type(lr.setFeaturesCol("foo").setLabelCol("bar"))
+
+    lrm = LinearRegressionModel.load("/foo")
+    reveal_type(lrm.setPredictionCol("baz"))
+
+  out: |
+     main:4: note: Revealed type is "pyspark.ml.regression.LinearRegression"
+     main:7: note: Revealed type is 
"pyspark.ml.regression.LinearRegressionModel"
+
+
 - case: loadFMRegressor
   main: |
     from pyspark.ml.regression import FMRegressor, FMRegressionModel
diff --git a/python/pyspark/ml/tuning.pyi b/python/pyspark/ml/tuning.pyi
index 75da80b..2538059 100644
--- a/python/pyspark/ml/tuning.pyi
+++ b/python/pyspark/ml/tuning.pyi
@@ -25,6 +25,7 @@ from pyspark.ml.evaluation import Evaluator
 from pyspark.ml.param import Param
 from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, 
HasSeed
 from pyspark.ml.util import MLReader, MLReadable, MLWriter, MLWritable
+from pyspark.sql import DataFrame
 
 class ParamGridBuilder:
     def __init__(self) -> None: ...
@@ -82,6 +83,7 @@ class CrossValidator(
         collectSubModels: bool = ...,
         foldCol: str = ...,
     ) -> CrossValidator: ...
+    def _fit(self, dataset: DataFrame) -> CrossValidatorModel: ...
     def setEstimator(self, value: Estimator) -> CrossValidator: ...
     def setEstimatorParamMaps(self, value: List[ParamMap]) -> CrossValidator: 
...
     def setEvaluator(self, value: Evaluator) -> CrossValidator: ...
@@ -107,6 +109,7 @@ class CrossValidatorModel(
         avgMetrics: Optional[List[float]] = ...,
         subModels: Optional[List[List[Model]]] = ...,
     ) -> None: ...
+    def _transform(self, dataset: DataFrame) -> DataFrame: ...
     def copy(self, extra: Optional[ParamMap] = ...) -> CrossValidatorModel: ...
     def write(self) -> MLWriter: ...
     @classmethod
@@ -147,6 +150,7 @@ class TrainValidationSplit(
         collectSubModels: bool = ...,
         seed: Optional[int] = ...,
     ) -> TrainValidationSplit: ...
+    def _fit(self, dataset: DataFrame) -> TrainValidationSplitModel: ...
     def setEstimator(self, value: Estimator) -> TrainValidationSplit: ...
     def setEstimatorParamMaps(self, value: List[ParamMap]) -> 
TrainValidationSplit: ...
     def setEvaluator(self, value: Evaluator) -> TrainValidationSplit: ...
@@ -174,6 +178,7 @@ class TrainValidationSplitModel(
         validationMetrics: Optional[List[float]] = ...,
         subModels: Optional[List[Model]] = ...,
     ) -> None: ...
+    def _transform(self, dataset: DataFrame) -> DataFrame: ...
     def setEstimator(self, value: Estimator) -> TrainValidationSplitModel: ...
     def setEstimatorParamMaps(self, value: List[ParamMap]) -> 
TrainValidationSplitModel: ...
     def setEvaluator(self, value: Evaluator) -> TrainValidationSplitModel: ...
diff --git a/python/pyspark/ml/wrapper.pyi b/python/pyspark/ml/wrapper.pyi
index 7c3406a..a238436 100644
--- a/python/pyspark/ml/wrapper.pyi
+++ b/python/pyspark/ml/wrapper.pyi
@@ -17,12 +17,13 @@
 # under the License.
 
 import abc
-from typing import Any, Optional
+from typing import Any, Optional, Generic
 from pyspark.ml._typing import P, T, JM, ParamMap
 
 from pyspark.ml import Estimator, Predictor, PredictionModel, Transformer, 
Model
 from pyspark.ml.base import _PredictorParams
 from pyspark.ml.param import Param, Params
+from pyspark.sql.dataframe import DataFrame
 
 class JavaWrapper:
     def __init__(self, java_obj: Optional[Any] = ...) -> None: ...
@@ -32,8 +33,11 @@ class JavaParams(JavaWrapper, Params, metaclass=abc.ABCMeta):
     def copy(self: P, extra: Optional[ParamMap] = ...) -> P: ...
     def clear(self, param: Param) -> None: ...
 
-class JavaEstimator(JavaParams, Estimator[JM], metaclass=abc.ABCMeta): ...
-class JavaTransformer(JavaParams, Transformer, metaclass=abc.ABCMeta): ...
+class JavaEstimator(Generic[JM], JavaParams, Estimator[JM], 
metaclass=abc.ABCMeta):
+    def _fit(self, dataset: DataFrame) -> JM: ...
+
+class JavaTransformer(JavaParams, Transformer, metaclass=abc.ABCMeta):
+    def _transform(self, dataset: DataFrame) -> DataFrame: ...
 
 class JavaModel(JavaTransformer, Model, metaclass=abc.ABCMeta):
     def __init__(self, java_model: Optional[Any] = ...) -> None: ...

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to