This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8559e94 [SPARK-37397][PYTHON] Inline annotations for pyspark.ml.base
8559e94 is described below
commit 8559e94389fb9b2a3f453f568b3e11c79d2b4e2c
Author: zero323 <[email protected]>
AuthorDate: Tue Feb 1 23:22:36 2022 -0800
[SPARK-37397][PYTHON] Inline annotations for pyspark.ml.base
### What changes were proposed in this pull request?
Migration of type annotations for `pyspark.ml.base` from the stub file to inline type hints.
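
As a rough sketch of the difference (illustrative only, simplified from the hunks
below, not code taken verbatim from this patch): a stub keeps the checker-visible
signature in a separate .pyi file, while inline hints attach the same signature to
the implementation:

    # stub style (base.pyi): signature only, body elided
    from abc import ABCMeta
    from typing import Optional, TYPE_CHECKING
    from pyspark.ml.param.shared import Params
    from pyspark.sql.dataframe import DataFrame
    if TYPE_CHECKING:
        from pyspark.ml._typing import ParamMap  # checker-only alias

    class Transformer(Params, metaclass=ABCMeta):
        def transform(self, dataset: DataFrame, params: Optional["ParamMap"] = ...) -> DataFrame: ...

    # inline style (base.py): the same signature on the real method
    class TransformerInline(Params, metaclass=ABCMeta):
        def transform(self, dataset: DataFrame, params: Optional["ParamMap"] = None) -> DataFrame:
            raise NotImplementedError()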
### Why are the changes needed?
As part of the ongoing type hints migration.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests plus new data-driven type tests (`test_feature.yml`, `test_regression.yml`).
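
Concretely, the inline hints let mypy resolve the `fit` overloads and the
self-typed setters; a hypothetical usage sketch (`df` is a caller-supplied
DataFrame; variable names are illustrative, not from this patch):

    from typing import List
    from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
    from pyspark.sql import DataFrame

    def sketch(df: DataFrame) -> None:
        lr = LogisticRegression()
        # single-map overload of fit: returns the concrete model type M
        model: LogisticRegressionModel = lr.fit(df)
        # list-of-param-maps overload of fit: returns List[M]
        models: List[LogisticRegressionModel] = lr.fit(df, [{}, {}])
        # self-typed setters preserve the subclass through chaining
        chained: LogisticRegression = lr.setFeaturesCol("features").setLabelCol("label")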
Closes #35289 from zero323/SPARK-37397.
Authored-by: zero323 <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/ml/base.py                          | 105 ++++++++++++++-------
python/pyspark/ml/base.pyi                         | 103 --------------------
python/pyspark/ml/classification.pyi               |   2 +
python/pyspark/ml/param/__init__.py                |   9 +-
python/pyspark/ml/pipeline.pyi                     |   3 +
python/pyspark/ml/tests/typing/test_feature.yml    |  13 ++-
python/pyspark/ml/tests/typing/test_regression.yml |  15 +++
python/pyspark/ml/tuning.pyi                       |   5 +
python/pyspark/ml/wrapper.pyi                      |  10 +-
9 files changed, 122 insertions(+), 143 deletions(-)
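
A recurring pattern in the hunks below is a TypeVar bound to Params used as a
self type, so that setters declared on a base class return the concrete subclass
and method chaining keeps its type. A minimal, self-contained sketch of the
pattern (simplified; SimpleParams and NGramLike are hypothetical stand-ins, not
the actual pyspark classes):

    from typing import Any, TypeVar

    P = TypeVar("P", bound="SimpleParams")

    class SimpleParams:
        def _set(self: P, **kwargs: Any) -> P:
            # store the params and return self, so the subclass type survives
            self.__dict__.update(kwargs)
            return self

    class NGramLike(SimpleParams):
        def setInputCol(self: P, value: str) -> P:
            return self._set(inputCol=value)

    # NGramLike().setInputCol("tokens") is typed as NGramLike, not SimpleParams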
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index d984209..4f4ddef 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -20,7 +20,25 @@ from abc import ABCMeta, abstractmethod
import copy
import threading
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+ TYPE_CHECKING,
+)
+
from pyspark import since
+from pyspark.ml.param import P
from pyspark.ml.common import inherit_doc
from pyspark.ml.param.shared import (
HasInputCol,
@@ -30,11 +48,18 @@ from pyspark.ml.param.shared import (
HasPredictionCol,
Params,
)
+from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import udf
-from pyspark.sql.types import StructField, StructType
+from pyspark.sql.types import DataType, StructField, StructType
+
+if TYPE_CHECKING:
+ from pyspark.ml._typing import ParamMap
+T = TypeVar("T")
+M = TypeVar("M", bound="Transformer")
-class _FitMultipleIterator:
+
+class _FitMultipleIterator(Generic[M]):
"""
Used by default implementation of Estimator.fitMultiple to produce models in a thread safe
iterator. This class handles the simple case of fitMultiple where each param map should be
@@ -55,17 +80,17 @@ class _FitMultipleIterator:
See :py:meth:`Estimator.fitMultiple` for more info.
"""
- def __init__(self, fitSingleModel, numModels):
+ def __init__(self, fitSingleModel: Callable[[int], M], numModels: int):
""" """
self.fitSingleModel = fitSingleModel
self.numModel = numModels
self.counter = 0
self.lock = threading.Lock()
- def __iter__(self):
+ def __iter__(self) -> Iterator[Tuple[int, M]]:
return self
- def __next__(self):
+ def __next__(self) -> Tuple[int, M]:
with self.lock:
index = self.counter
if index >= self.numModel:
@@ -73,13 +98,13 @@ class _FitMultipleIterator:
self.counter += 1
return index, self.fitSingleModel(index)
- def next(self):
+ def next(self) -> Tuple[int, M]:
"""For python2 compatibility."""
return self.__next__()
@inherit_doc
-class Estimator(Params, metaclass=ABCMeta):
+class Estimator(Generic[M], Params, metaclass=ABCMeta):
"""
Abstract class for estimators that fit models to data.
@@ -89,7 +114,7 @@ class Estimator(Params, metaclass=ABCMeta):
pass
@abstractmethod
- def _fit(self, dataset):
+ def _fit(self, dataset: DataFrame) -> M:
"""
Fits a model to the input dataset. This is called by the default
implementation of fit.
@@ -106,7 +131,9 @@ class Estimator(Params, metaclass=ABCMeta):
"""
raise NotImplementedError()
- def fitMultiple(self, dataset, paramMaps):
+ def fitMultiple(
+ self, dataset: DataFrame, paramMaps: Sequence["ParamMap"]
+ ) -> Iterable[Tuple[int, M]]:
"""
Fits a model to the input dataset for each param map in `paramMaps`.
@@ -128,12 +155,26 @@ class Estimator(Params, metaclass=ABCMeta):
"""
estimator = self.copy()
- def fitSingleModel(index):
+ def fitSingleModel(index: int) -> M:
return estimator.fit(dataset, paramMaps[index])
return _FitMultipleIterator(fitSingleModel, len(paramMaps))
- def fit(self, dataset, params=None):
+ @overload
+ def fit(self, dataset: DataFrame, params: Optional["ParamMap"] = ...) -> M:
+ ...
+
+ @overload
+ def fit(
+ self, dataset: DataFrame, params: Union[List["ParamMap"], Tuple["ParamMap"]]
+ ) -> List[M]:
+ ...
+
+ def fit(
+ self,
+ dataset: DataFrame,
+ params: Optional[Union["ParamMap", List["ParamMap"], Tuple["ParamMap"]]] = None,
+ ) -> Union[M, List[M]]:
"""
Fits a model to the input dataset with optional parameters.
@@ -156,10 +197,10 @@ class Estimator(Params, metaclass=ABCMeta):
if params is None:
params = dict()
if isinstance(params, (list, tuple)):
- models = [None] * len(params)
+ models: List[Optional[M]] = [None] * len(params)
for index, model in self.fitMultiple(dataset, params):
models[index] = model
- return models
+ return cast(List[M], models)
elif isinstance(params, dict):
if params:
return self.copy(params)._fit(dataset)
@@ -183,7 +224,7 @@ class Transformer(Params, metaclass=ABCMeta):
pass
@abstractmethod
- def _transform(self, dataset):
+ def _transform(self, dataset: DataFrame) -> DataFrame:
"""
Transforms the input dataset.
@@ -199,7 +240,7 @@ class Transformer(Params, metaclass=ABCMeta):
"""
raise NotImplementedError()
- def transform(self, dataset, params=None):
+ def transform(self, dataset: DataFrame, params: Optional["ParamMap"] = None) -> DataFrame:
"""
Transforms the input dataset with optional parameters.
@@ -248,20 +289,20 @@ class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
.. versionadded:: 2.3.0
"""
- def setInputCol(self, value):
+ def setInputCol(self: P, value: str) -> P:
"""
Sets the value of :py:attr:`inputCol`.
"""
return self._set(inputCol=value)
- def setOutputCol(self, value):
+ def setOutputCol(self: P, value: str) -> P:
"""
Sets the value of :py:attr:`outputCol`.
"""
return self._set(outputCol=value)
@abstractmethod
- def createTransformFunc(self):
+ def createTransformFunc(self) -> Callable[..., Any]:
"""
Creates the transform function using the given param map. The input param map already takes
account of the embedded param map. So the param values should be determined
@@ -270,20 +311,20 @@ class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
raise NotImplementedError()
@abstractmethod
- def outputDataType(self):
+ def outputDataType(self) -> DataType:
"""
Returns the data type of the output column.
"""
raise NotImplementedError()
@abstractmethod
- def validateInputType(self, inputType):
+ def validateInputType(self, inputType: DataType) -> None:
"""
Validates the input type. Throw an exception if it is invalid.
"""
raise NotImplementedError()
- def transformSchema(self, schema):
+ def transformSchema(self, schema: StructType) -> StructType:
inputType = schema[self.getInputCol()].dataType
self.validateInputType(inputType)
if self.getOutputCol() in schema.names:
@@ -292,7 +333,7 @@ class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
outputFields.append(StructField(self.getOutputCol(), self.outputDataType(), nullable=False))
return StructType(outputFields)
- def _transform(self, dataset):
+ def _transform(self, dataset: DataFrame) -> DataFrame:
self.transformSchema(dataset.schema)
transformUDF = udf(self.createTransformFunc(), self.outputDataType())
transformedDataset = dataset.withColumn(
@@ -313,27 +354,27 @@ class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
@inherit_doc
-class Predictor(Estimator, _PredictorParams, metaclass=ABCMeta):
+class Predictor(Estimator[M], _PredictorParams, metaclass=ABCMeta):
"""
Estimator for prediction tasks (regression and classification).
"""
@since("3.0.0")
- def setLabelCol(self, value):
+ def setLabelCol(self: P, value: str) -> P:
"""
Sets the value of :py:attr:`labelCol`.
"""
return self._set(labelCol=value)
@since("3.0.0")
- def setFeaturesCol(self, value):
+ def setFeaturesCol(self: P, value: str) -> P:
"""
Sets the value of :py:attr:`featuresCol`.
"""
return self._set(featuresCol=value)
@since("3.0.0")
- def setPredictionCol(self, value):
+ def setPredictionCol(self: P, value: str) -> P:
"""
Sets the value of :py:attr:`predictionCol`.
"""
@@ -341,29 +382,29 @@ class Predictor(Estimator, _PredictorParams, metaclass=ABCMeta):
@inherit_doc
-class PredictionModel(Model, _PredictorParams, metaclass=ABCMeta):
+class PredictionModel(Generic[T], Transformer, _PredictorParams, metaclass=ABCMeta):
"""
Model for prediction tasks (regression and classification).
"""
@since("3.0.0")
- def setFeaturesCol(self, value):
+ def setFeaturesCol(self: P, value: str) -> P:
"""
Sets the value of :py:attr:`featuresCol`.
"""
return self._set(featuresCol=value)
@since("3.0.0")
- def setPredictionCol(self, value):
+ def setPredictionCol(self: P, value: str) -> P:
"""
Sets the value of :py:attr:`predictionCol`.
"""
return self._set(predictionCol=value)
- @property
+ @property # type: ignore[misc]
@abstractmethod
@since("2.1.0")
- def numFeatures(self):
+ def numFeatures(self) -> int:
"""
Returns the number of features the model was trained on. If unknown, returns -1
"""
@@ -371,7 +412,7 @@ class PredictionModel(Model, _PredictorParams, metaclass=ABCMeta):
@abstractmethod
@since("3.0.0")
- def predict(self, value):
+ def predict(self, value: T) -> float:
"""
Predict label for the given features.
"""
diff --git a/python/pyspark/ml/base.pyi b/python/pyspark/ml/base.pyi
deleted file mode 100644
index 37ae6de..0000000
--- a/python/pyspark/ml/base.pyi
+++ /dev/null
@@ -1,103 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import (
- Callable,
- Generic,
- Iterable,
- List,
- Optional,
- Sequence,
- Tuple,
- Union,
-)
-from pyspark.ml._typing import M, P, T, ParamMap
-
-import _thread
-
-import abc
-from abc import abstractmethod
-from pyspark import since as since # noqa: F401
-from pyspark.ml.common import inherit_doc as inherit_doc # noqa: F401
-from pyspark.ml.param.shared import (
- HasFeaturesCol as HasFeaturesCol,
- HasInputCol as HasInputCol,
- HasLabelCol as HasLabelCol,
- HasOutputCol as HasOutputCol,
- HasPredictionCol as HasPredictionCol,
- Params as Params,
-)
-from pyspark.sql.functions import udf as udf # noqa: F401
-from pyspark.sql.types import ( # noqa: F401
- DataType,
- StructField as StructField,
- StructType as StructType,
-)
-
-from pyspark.sql.dataframe import DataFrame
-
-class _FitMultipleIterator:
- fitSingleModel: Callable[[int], Transformer]
- numModel: int
- counter: int = ...
- lock: _thread.LockType
- def __init__(self, fitSingleModel: Callable[[int], Transformer], numModels: int) -> None: ...
- def __iter__(self) -> _FitMultipleIterator: ...
- def __next__(self) -> Tuple[int, Transformer]: ...
- def next(self) -> Tuple[int, Transformer]: ...
-
-class Estimator(Generic[M], Params, metaclass=abc.ABCMeta):
- @overload
- def fit(self, dataset: DataFrame, params: Optional[ParamMap] = ...) -> M: ...
- @overload
- def fit(
- self, dataset: DataFrame, params: Union[List[ParamMap], Tuple[ParamMap]]
- ) -> List[M]: ...
- def fitMultiple(
- self, dataset: DataFrame, params: Sequence[ParamMap]
- ) -> Iterable[Tuple[int, M]]: ...
-
-class Transformer(Params, metaclass=abc.ABCMeta):
- def transform(self, dataset: DataFrame, params: Optional[ParamMap] = ...) -> DataFrame: ...
-
-class Model(Transformer, metaclass=abc.ABCMeta): ...
-
-class UnaryTransformer(HasInputCol, HasOutputCol, Transformer, metaclass=abc.ABCMeta):
- def createTransformFunc(self) -> Callable: ...
- def outputDataType(self) -> DataType: ...
- def validateInputType(self, inputType: DataType) -> None: ...
- def transformSchema(self, schema: StructType) -> StructType: ...
- def setInputCol(self: M, value: str) -> M: ...
- def setOutputCol(self: M, value: str) -> M: ...
-
-class _PredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol): ...
-
-class Predictor(Estimator[M], _PredictorParams, metaclass=abc.ABCMeta):
- def setLabelCol(self: P, value: str) -> P: ...
- def setFeaturesCol(self: P, value: str) -> P: ...
- def setPredictionCol(self: P, value: str) -> P: ...
-
-class PredictionModel(Generic[T], Model, _PredictorParams, metaclass=abc.ABCMeta):
- def setFeaturesCol(self: M, value: str) -> M: ...
- def setPredictionCol(self: M, value: str) -> M: ...
- @property
- @abc.abstractmethod
- def numFeatures(self) -> int: ...
- @abstractmethod
- def predict(self, value: T) -> float: ...
diff --git a/python/pyspark/ml/classification.pyi b/python/pyspark/ml/classification.pyi
index bb4fb05..4170a8c 100644
--- a/python/pyspark/ml/classification.pyi
+++ b/python/pyspark/ml/classification.pyi
@@ -820,6 +820,7 @@ class OneVsRest(
weightCol: Optional[str] = ...,
parallelism: int = ...,
) -> OneVsRest: ...
+ def _fit(self, dataset: DataFrame) -> OneVsRestModel: ...
def setClassifier(self, value: Estimator[M]) -> OneVsRest: ...
def setLabelCol(self, value: str) -> OneVsRest: ...
def setFeaturesCol(self, value: str) -> OneVsRest: ...
@@ -832,6 +833,7 @@ class OneVsRest(
class OneVsRestModel(Model, _OneVsRestParams, MLReadable[OneVsRestModel], MLWritable):
models: List[Transformer]
def __init__(self, models: List[Transformer]) -> None: ...
+ def _transform(self, dataset: DataFrame) -> DataFrame: ...
def setFeaturesCol(self, value: str) -> OneVsRestModel: ...
def setPredictionCol(self, value: str) -> OneVsRestModel: ...
def setRawPredictionCol(self, value: str) -> OneVsRestModel: ...
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 092f79f..fd5ed63 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -43,6 +43,7 @@ if TYPE_CHECKING:
__all__ = ["Param", "Params", "TypeConverters"]
T = TypeVar("T")
+P = TypeVar("P", bound="Params")
class Param(Generic[T]):
@@ -409,7 +410,7 @@ class Params(Identifiable, metaclass=ABCMeta):
paramMap.update(extra)
return paramMap
- def copy(self, extra: Optional["ParamMap"] = None) -> "Params":
+ def copy(self: P, extra: Optional["ParamMap"] = None) -> P:
"""
Creates a copy of this instance with the same uid and some
extra params. The default implementation creates a
@@ -492,7 +493,7 @@ class Params(Identifiable, metaclass=ABCMeta):
dummy.uid = "undefined"
return dummy
- def _set(self, **kwargs: Any) -> "Params":
+ def _set(self: P, **kwargs: Any) -> P:
"""
Sets user-supplied params.
"""
@@ -513,7 +514,7 @@ class Params(Identifiable, metaclass=ABCMeta):
if self.isSet(param):
del self._paramMap[param]
- def _setDefault(self, **kwargs: Any) -> "Params":
+ def _setDefault(self: P, **kwargs: Any) -> P:
"""
Sets default params.
"""
@@ -529,7 +530,7 @@ class Params(Identifiable, metaclass=ABCMeta):
self._defaultParamMap[p] = value
return self
- def _copyValues(self, to: "Params", extra: Optional["ParamMap"] = None) -> "Params":
+ def _copyValues(self, to: P, extra: Optional["ParamMap"] = None) -> P:
"""
Copies param values from this instance to another instance for
params shared by them.
diff --git a/python/pyspark/ml/pipeline.pyi b/python/pyspark/ml/pipeline.pyi
index f55b1e3..7b38900 100644
--- a/python/pyspark/ml/pipeline.pyi
+++ b/python/pyspark/ml/pipeline.pyi
@@ -33,10 +33,12 @@ from pyspark.ml.util import ( # noqa: F401
MLWritable as MLWritable,
MLWriter as MLWriter,
)
+from pyspark.sql.dataframe import DataFrame
class Pipeline(Estimator[PipelineModel], MLReadable[Pipeline], MLWritable):
stages: List[PipelineStage]
def __init__(self, *, stages: Optional[List[PipelineStage]] = ...) -> None: ...
+ def _fit(self, dataset: DataFrame) -> PipelineModel: ...
def setStages(self, stages: List[PipelineStage]) -> Pipeline: ...
def getStages(self) -> List[PipelineStage]: ...
def setParams(self, *, stages: Optional[List[PipelineStage]] = ...) -> Pipeline: ...
@@ -69,6 +71,7 @@ class PipelineModelReader(MLReader[PipelineModel]):
class PipelineModel(Model, MLReadable[PipelineModel], MLWritable):
stages: List[PipelineStage]
def __init__(self, stages: List[Transformer]) -> None: ...
+ def _transform(self, dataset: DataFrame) -> DataFrame: ...
def copy(self, extra: Optional[Dict[Param, Any]] = ...) -> PipelineModel: ...
def write(self) -> JavaMLWriter: ...
def save(self, path: str) -> None: ...
diff --git a/python/pyspark/ml/tests/typing/test_feature.yml b/python/pyspark/ml/tests/typing/test_feature.yml
index 3d6b090..0d1034a 100644
--- a/python/pyspark/ml/tests/typing/test_feature.yml
+++ b/python/pyspark/ml/tests/typing/test_feature.yml
@@ -15,6 +15,17 @@
# limitations under the License.
#
+
+- case: featureMethodChaining
+ main: |
+ from pyspark.ml.feature import NGram
+
+ reveal_type(NGram().setInputCol("foo").setOutputCol("bar"))
+
+ out: |
+ main:3: note: Revealed type is "pyspark.ml.feature.NGram"
+
+
- case: stringIndexerOverloads
main: |
from pyspark.ml.feature import StringIndexer
@@ -41,4 +52,4 @@
main:15: error: No overload variant of "StringIndexer" matches argument types "List[str]", "str" [call-overload]
main:15: note: Possible overload variants:
main:15: note: def StringIndexer(self, *, inputCol: Optional[str] = ..., outputCol: Optional[str] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer
- main:15: note: def StringIndexer(self, *, inputCols: Optional[List[str]] = ..., outputCols: Optional[List[str]] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer
\ No newline at end of file
+ main:15: note: def StringIndexer(self, *, inputCols: Optional[List[str]] = ..., outputCols: Optional[List[str]] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer
diff --git a/python/pyspark/ml/tests/typing/test_regression.yml b/python/pyspark/ml/tests/typing/test_regression.yml
index b045bec..4a54a56 100644
--- a/python/pyspark/ml/tests/typing/test_regression.yml
+++ b/python/pyspark/ml/tests/typing/test_regression.yml
@@ -15,6 +15,21 @@
# limitations under the License.
#
+- case: linearRegressionMethodChaining
+ main: |
+ from pyspark.ml.regression import LinearRegression, LinearRegressionModel
+
+ lr = LinearRegression()
+ reveal_type(lr.setFeaturesCol("foo").setLabelCol("bar"))
+
+ lrm = LinearRegressionModel.load("/foo")
+ reveal_type(lrm.setPredictionCol("baz"))
+
+ out: |
+ main:4: note: Revealed type is "pyspark.ml.regression.LinearRegression"
+ main:7: note: Revealed type is "pyspark.ml.regression.LinearRegressionModel"
+
+
- case: loadFMRegressor
main: |
from pyspark.ml.regression import FMRegressor, FMRegressionModel
diff --git a/python/pyspark/ml/tuning.pyi b/python/pyspark/ml/tuning.pyi
index 75da80b..2538059 100644
--- a/python/pyspark/ml/tuning.pyi
+++ b/python/pyspark/ml/tuning.pyi
@@ -25,6 +25,7 @@ from pyspark.ml.evaluation import Evaluator
from pyspark.ml.param import Param
from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
from pyspark.ml.util import MLReader, MLReadable, MLWriter, MLWritable
+from pyspark.sql import DataFrame
class ParamGridBuilder:
def __init__(self) -> None: ...
@@ -82,6 +83,7 @@ class CrossValidator(
collectSubModels: bool = ...,
foldCol: str = ...,
) -> CrossValidator: ...
+ def _fit(self, dataset: DataFrame) -> CrossValidatorModel: ...
def setEstimator(self, value: Estimator) -> CrossValidator: ...
def setEstimatorParamMaps(self, value: List[ParamMap]) -> CrossValidator: ...
def setEvaluator(self, value: Evaluator) -> CrossValidator: ...
@@ -107,6 +109,7 @@ class CrossValidatorModel(
avgMetrics: Optional[List[float]] = ...,
subModels: Optional[List[List[Model]]] = ...,
) -> None: ...
+ def _transform(self, dataset: DataFrame) -> DataFrame: ...
def copy(self, extra: Optional[ParamMap] = ...) -> CrossValidatorModel: ...
def write(self) -> MLWriter: ...
@classmethod
@@ -147,6 +150,7 @@ class TrainValidationSplit(
collectSubModels: bool = ...,
seed: Optional[int] = ...,
) -> TrainValidationSplit: ...
+ def _fit(self, dataset: DataFrame) -> TrainValidationSplitModel: ...
def setEstimator(self, value: Estimator) -> TrainValidationSplit: ...
def setEstimatorParamMaps(self, value: List[ParamMap]) -> TrainValidationSplit: ...
def setEvaluator(self, value: Evaluator) -> TrainValidationSplit: ...
@@ -174,6 +178,7 @@ class TrainValidationSplitModel(
validationMetrics: Optional[List[float]] = ...,
subModels: Optional[List[Model]] = ...,
) -> None: ...
+ def _transform(self, dataset: DataFrame) -> DataFrame: ...
def setEstimator(self, value: Estimator) -> TrainValidationSplitModel: ...
def setEstimatorParamMaps(self, value: List[ParamMap]) -> TrainValidationSplitModel: ...
def setEvaluator(self, value: Evaluator) -> TrainValidationSplitModel: ...
diff --git a/python/pyspark/ml/wrapper.pyi b/python/pyspark/ml/wrapper.pyi
index 7c3406a..a238436 100644
--- a/python/pyspark/ml/wrapper.pyi
+++ b/python/pyspark/ml/wrapper.pyi
@@ -17,12 +17,13 @@
# under the License.
import abc
-from typing import Any, Optional
+from typing import Any, Optional, Generic
from pyspark.ml._typing import P, T, JM, ParamMap
from pyspark.ml import Estimator, Predictor, PredictionModel, Transformer, Model
from pyspark.ml.base import _PredictorParams
from pyspark.ml.param import Param, Params
+from pyspark.sql.dataframe import DataFrame
class JavaWrapper:
def __init__(self, java_obj: Optional[Any] = ...) -> None: ...
@@ -32,8 +33,11 @@ class JavaParams(JavaWrapper, Params, metaclass=abc.ABCMeta):
def copy(self: P, extra: Optional[ParamMap] = ...) -> P: ...
def clear(self, param: Param) -> None: ...
-class JavaEstimator(JavaParams, Estimator[JM], metaclass=abc.ABCMeta): ...
-class JavaTransformer(JavaParams, Transformer, metaclass=abc.ABCMeta): ...
+class JavaEstimator(Generic[JM], JavaParams, Estimator[JM], metaclass=abc.ABCMeta):
+ def _fit(self, dataset: DataFrame) -> JM: ...
+
+class JavaTransformer(JavaParams, Transformer, metaclass=abc.ABCMeta):
+ def _transform(self, dataset: DataFrame) -> DataFrame: ...
class JavaModel(JavaTransformer, Model, metaclass=abc.ABCMeta):
def __init__(self, java_model: Optional[Any] = ...) -> None: ...
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]