This is an automated email from the ASF dual-hosted git repository.
zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0084a86 [SPARK-37415][PYTHON][ML] Inline type hints for pyspark.ml.util
0084a86 is described below
commit 0084a8677ad77143b109dff3d3e9be4035d00fd4
Author: zero323 <[email protected]>
AuthorDate: Sun Feb 6 11:09:41 2022 +0100
[SPARK-37415][PYTHON][ML] Inline type hints for pyspark.ml.util
### What changes were proposed in this pull request?
This PR migrates `pyspark.ml.util` type annotations from a stub file to inline type hints.
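A minimal sketch of the pattern (simplified here for illustration; the actual hints are in the diff below): a signature that previously lived only in the `.pyi` stub, for example

```python
# python/pyspark/ml/util.pyi (stub) -- hints kept in a separate file
class MLWriter(BaseReadWrite):
    def save(self, path: str) -> None: ...
```

is now written directly on the implementation in `util.py`:

```python
# python/pyspark/ml/util.py (inline) -- the same hint on the real method
class MLWriter(BaseReadWrite):
    def save(self, path: str) -> None:
        """Save the ML instance to the input path."""
        if self.shouldOverwrite:
            self._handleOverwrite(path)
        self.saveImpl(path)
```

The stub file is then deleted, so type checkers read the annotations straight from the source module.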
### Why are the changes needed?
Part of the ongoing migration of type hints.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
Closes #35367 from zero323/SPARK-37415.
Authored-by: zero323 <[email protected]>
Signed-off-by: zero323 <[email protected]>
---
python/pyspark/ml/base.py | 4 +-
python/pyspark/ml/util.py | 207 +++++++++++++++++++++++++-----------------
python/pyspark/ml/util.pyi | 136 ---------------------------
python/pyspark/ml/wrapper.pyi | 1 +
4 files changed, 128 insertions(+), 220 deletions(-)
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index 4f4ddef..9e8252d 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -104,7 +104,7 @@ class _FitMultipleIterator(Generic[M]):
@inherit_doc
-class Estimator(Generic[M], Params, metaclass=ABCMeta):
+class Estimator(Params, Generic[M], metaclass=ABCMeta):
"""
Abstract class for estimators that fit models to data.
@@ -382,7 +382,7 @@ class Predictor(Estimator[M], _PredictorParams, metaclass=ABCMeta):
@inherit_doc
-class PredictionModel(Generic[T], Transformer, _PredictorParams, metaclass=ABCMeta):
+class PredictionModel(Model, _PredictorParams, Generic[T], metaclass=ABCMeta):
"""
Model for prediction tasks (regression and classification).
"""
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index ac60ded..1dacffc 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -20,13 +20,31 @@ import os
import time
import uuid
+from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, cast, TYPE_CHECKING
+
+
from pyspark import SparkContext, since
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession
from pyspark.util import VersionUtils
+if TYPE_CHECKING:
+ from py4j.java_gateway import JavaGateway, JavaObject
+ from pyspark.ml._typing import PipelineStage
+
+ from pyspark.ml.param import Param
+ from pyspark.ml.base import Params
+ from pyspark.ml.wrapper import JavaWrapper
-def _jvm():
+T = TypeVar("T")
+RW = TypeVar("RW", bound="BaseReadWrite")
+W = TypeVar("W", bound="MLWriter")
+JW = TypeVar("JW", bound="JavaMLWriter")
+RL = TypeVar("RL", bound="MLReadable")
+JR = TypeVar("JR", bound="JavaMLReader")
+
+
+def _jvm() -> "JavaGateway":
"""
Returns the JVM view associated with SparkContext. Must be called
after SparkContext is initialized.
@@ -43,15 +61,15 @@ class Identifiable:
Object with a unique ID.
"""
- def __init__(self):
+ def __init__(self) -> None:
#: A unique id for the object.
self.uid = self._randomUID()
- def __repr__(self):
+ def __repr__(self) -> str:
return self.uid
@classmethod
- def _randomUID(cls):
+ def _randomUID(cls) -> str:
"""
Generate a unique string id for the object. The default implementation
concatenates the class name, "_", and 12 random hex chars.
@@ -68,10 +86,10 @@ class BaseReadWrite:
.. versionadded:: 2.3.0
"""
- def __init__(self):
- self._sparkSession = None
+ def __init__(self) -> None:
+ self._sparkSession: Optional[SparkSession] = None
- def session(self, sparkSession):
+ def session(self: RW, sparkSession: SparkSession) -> RW:
"""
Sets the Spark Session to use for saving/loading.
"""
@@ -79,19 +97,21 @@ class BaseReadWrite:
return self
@property
- def sparkSession(self):
+ def sparkSession(self) -> SparkSession:
"""
Returns the user-specified Spark Session or the default.
"""
if self._sparkSession is None:
self._sparkSession = SparkSession._getActiveSessionOrCreate()
+ assert self._sparkSession is not None
return self._sparkSession
@property
- def sc(self):
+ def sc(self) -> SparkContext:
"""
Returns the underlying `SparkContext`.
"""
+ assert self.sparkSession is not None
return self.sparkSession.sparkContext
@@ -103,37 +123,41 @@ class MLWriter(BaseReadWrite):
.. versionadded:: 2.0.0
"""
- def __init__(self):
+ def __init__(self) -> None:
super(MLWriter, self).__init__()
- self.shouldOverwrite = False
- self.optionMap = {}
+ self.shouldOverwrite: bool = False
+ self.optionMap: Dict[str, Any] = {}
- def _handleOverwrite(self, path):
+ def _handleOverwrite(self, path: str) -> None:
from pyspark.ml.wrapper import JavaWrapper
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.util.FileSystemOverwrite")
+ _java_obj = JavaWrapper._new_java_obj( # type: ignore[attr-defined]
+ "org.apache.spark.ml.util.FileSystemOverwrite"
+ )
wrapper = JavaWrapper(_java_obj)
- wrapper._call_java("handleOverwrite", path, True, self.sparkSession._jsparkSession)
+ wrapper._call_java( # type: ignore[attr-defined]
+ "handleOverwrite", path, True, self.sparkSession._jsparkSession
+ )
- def save(self, path):
+ def save(self, path: str) -> None:
"""Save the ML instance to the input path."""
if self.shouldOverwrite:
self._handleOverwrite(path)
self.saveImpl(path)
- def saveImpl(self, path):
+ def saveImpl(self, path: str) -> None:
"""
save() handles overwriting and then calls this method. Subclasses should override this
method to implement the actual saving of the instance.
"""
raise NotImplementedError("MLWriter is not yet implemented for type:
%s" % type(self))
- def overwrite(self):
+ def overwrite(self) -> "MLWriter":
"""Overwrites if the output path already exists."""
self.shouldOverwrite = True
return self
- def option(self, key, value):
+ def option(self, key: str, value: Any) -> "MLWriter":
"""
Adds an option to the underlying MLWriter. See the documentation for the specific model's
writer for possible options. The option name (key) is case-insensitive.
@@ -150,7 +174,7 @@ class GeneralMLWriter(MLWriter):
.. versionadded:: 2.4.0
"""
- def format(self, source):
+ def format(self, source: str) -> "GeneralMLWriter":
"""
Specifies the format of ML export ("pmml", "internal", or the fully qualified class
name for export).
@@ -165,27 +189,29 @@ class JavaMLWriter(MLWriter):
(Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
"""
- def __init__(self, instance):
+ _jwrite: "JavaObject"
+
+ def __init__(self, instance: "JavaMLWritable"):
super(JavaMLWriter, self).__init__()
- _java_obj = instance._to_java()
+ _java_obj = instance._to_java() # type: ignore[attr-defined]
self._jwrite = _java_obj.write()
- def save(self, path):
+ def save(self, path: str) -> None:
"""Save the ML instance to the input path."""
if not isinstance(path, str):
raise TypeError("path should be a string, got type %s" % type(path))
self._jwrite.save(path)
- def overwrite(self):
+ def overwrite(self) -> "JavaMLWriter":
"""Overwrites if the output path already exists."""
self._jwrite.overwrite()
return self
- def option(self, key, value):
+ def option(self, key: str, value: str) -> "JavaMLWriter":
self._jwrite.option(key, value)
return self
- def session(self, sparkSession):
+ def session(self, sparkSession: SparkSession) -> "JavaMLWriter":
"""Sets the Spark Session to use for saving."""
self._jwrite.session(sparkSession._jsparkSession)
return self
@@ -197,10 +223,10 @@ class GeneralJavaMLWriter(JavaMLWriter):
(Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types
"""
- def __init__(self, instance):
+ def __init__(self, instance: "JavaMLWritable"):
super(GeneralJavaMLWriter, self).__init__(instance)
- def format(self, source):
+ def format(self, source: str) -> "GeneralJavaMLWriter":
"""
Specifies the format of ML export ("pmml", "internal", or the fully qualified class
name for export).
@@ -217,11 +243,11 @@ class MLWritable:
.. versionadded:: 2.0.0
"""
- def write(self):
+ def write(self) -> MLWriter:
"""Returns an MLWriter instance for this ML instance."""
raise NotImplementedError("MLWritable is not yet implemented for type:
%r" % type(self))
- def save(self, path):
+ def save(self, path: str) -> None:
"""Save this ML instance to the given path, a shortcut of
'write().save(path)'."""
self.write().save(path)
@@ -232,7 +258,7 @@ class JavaMLWritable(MLWritable):
(Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
"""
- def write(self):
+ def write(self) -> JavaMLWriter:
"""Returns an MLWriter instance for this ML instance."""
return JavaMLWriter(self)
@@ -243,39 +269,39 @@ class GeneralJavaMLWritable(JavaMLWritable):
(Private) Mixin for ML instances that provide
:py:class:`GeneralJavaMLWriter`.
"""
- def write(self):
+ def write(self) -> GeneralJavaMLWriter:
"""Returns an GeneralMLWriter instance for this ML instance."""
return GeneralJavaMLWriter(self)
@inherit_doc
-class MLReader(BaseReadWrite):
+class MLReader(BaseReadWrite, Generic[RL]):
"""
Utility class that can load ML instances.
.. versionadded:: 2.0.0
"""
- def __init__(self):
+ def __init__(self) -> None:
super(MLReader, self).__init__()
- def load(self, path):
+ def load(self, path: str) -> RL:
"""Load the ML instance from the input path."""
raise NotImplementedError("MLReader is not yet implemented for type:
%s" % type(self))
@inherit_doc
-class JavaMLReader(MLReader):
+class JavaMLReader(MLReader[RL]):
"""
(Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
"""
- def __init__(self, clazz):
+ def __init__(self, clazz: Type["JavaMLReadable[RL]"]) -> None:
super(JavaMLReader, self).__init__()
self._clazz = clazz
self._jread = self._load_java_obj(clazz).read()
- def load(self, path):
+ def load(self, path: str) -> RL:
"""Load the ML instance from the input path."""
if not isinstance(path, str):
raise TypeError("path should be a string, got type %s" % type(path))
@@ -284,15 +310,15 @@ class JavaMLReader(MLReader):
raise NotImplementedError(
"This Java ML type cannot be loaded into Python currently: %r"
% self._clazz
)
- return self._clazz._from_java(java_obj)
+ return self._clazz._from_java(java_obj) # type: ignore[attr-defined]
- def session(self, sparkSession):
+ def session(self: JR, sparkSession: SparkSession) -> JR:
"""Sets the Spark Session to use for loading."""
self._jread.session(sparkSession._jsparkSession)
return self
@classmethod
- def _java_loader_class(cls, clazz):
+ def _java_loader_class(cls, clazz: Type["JavaMLReadable[RL]"]) -> str:
"""
Returns the full class name of the Java ML instance. The default
implementation replaces "pyspark" by "org.apache.spark" in
@@ -305,7 +331,7 @@ class JavaMLReader(MLReader):
return java_package + "." + clazz.__name__
@classmethod
- def _load_java_obj(cls, clazz):
+ def _load_java_obj(cls, clazz: Type["JavaMLReadable[RL]"]) -> "JavaObject":
"""Load the peer Java object of the ML instance."""
java_class = cls._java_loader_class(clazz)
java_obj = _jvm()
@@ -315,7 +341,7 @@ class JavaMLReader(MLReader):
@inherit_doc
-class MLReadable:
+class MLReadable(Generic[RL]):
"""
Mixin for instances that provide :py:class:`MLReader`.
@@ -323,24 +349,24 @@ class MLReadable:
"""
@classmethod
- def read(cls):
+ def read(cls) -> MLReader[RL]:
"""Returns an MLReader instance for this class."""
raise NotImplementedError("MLReadable.read() not implemented for type:
%r" % cls)
@classmethod
- def load(cls, path):
+ def load(cls, path: str) -> RL:
"""Reads an ML instance from the input path, a shortcut of
`read().load(path)`."""
return cls.read().load(path)
@inherit_doc
-class JavaMLReadable(MLReadable):
+class JavaMLReadable(MLReadable[RL]):
"""
(Private) Mixin for instances that provide JavaMLReader.
"""
@classmethod
- def read(cls):
+ def read(cls) -> JavaMLReader[RL]:
"""Returns an MLReader instance for this class."""
return JavaMLReader(cls)
@@ -358,7 +384,7 @@ class DefaultParamsWritable(MLWritable):
.. versionadded:: 2.3.0
"""
- def write(self):
+ def write(self) -> MLWriter:
"""Returns a DefaultParamsWriter instance for this class."""
from pyspark.ml.param import Params
@@ -382,15 +408,15 @@ class DefaultParamsWriter(MLWriter):
.. versionadded:: 2.3.0
"""
- def __init__(self, instance):
+ def __init__(self, instance: "Params"):
super(DefaultParamsWriter, self).__init__()
self.instance = instance
- def saveImpl(self, path):
+ def saveImpl(self, path: str) -> None:
DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)
@staticmethod
- def extractJsonParams(instance, skipParams):
+ def extractJsonParams(instance: "Params", skipParams: Sequence[str]) -> Dict[str, Any]:
paramMap = instance.extractParamMap()
jsonParams = {
param.name: value for param, value in paramMap.items() if param.name not in skipParams
@@ -398,7 +424,13 @@ class DefaultParamsWriter(MLWriter):
return jsonParams
@staticmethod
- def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
+ def saveMetadata(
+ instance: "Params",
+ path: str,
+ sc: SparkContext,
+ extraMetadata: Optional[Dict[str, Any]] = None,
+ paramMap: Optional[Dict[str, "Param"]] = None,
+ ) -> None:
"""
Saves metadata + Params to: path + "/metadata"
@@ -424,7 +456,12 @@ class DefaultParamsWriter(MLWriter):
sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath)
@staticmethod
- def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
+ def _get_metadata_to_save(
+ instance: "Params",
+ sc: SparkContext,
+ extraMetadata: Optional[Dict[str, Any]] = None,
+ paramMap: Optional[Dict[str, "Param"]] = None,
+ ) -> str:
"""
Helper for :py:meth:`DefaultParamsWriter.saveMetadata` which extracts the JSON to save.
This is useful for ensemble models which need to save metadata for many sub-models.
@@ -460,11 +497,11 @@ class DefaultParamsWriter(MLWriter):
}
if extraMetadata is not None:
basicMetadata.update(extraMetadata)
- return json.dumps(basicMetadata, separators=[",", ":"])
+ return json.dumps(basicMetadata, separators=(",", ":"))
@inherit_doc
-class DefaultParamsReadable(MLReadable):
+class DefaultParamsReadable(MLReadable[RL]):
"""
Helper trait for making simple :py:class:`Params` types readable.
If a :py:class:`Params` class stores all data as :py:class:`Param` values,
@@ -477,13 +514,13 @@ class DefaultParamsReadable(MLReadable):
"""
@classmethod
- def read(cls):
+ def read(cls) -> "DefaultParamsReader[RL]":
"""Returns a DefaultParamsReader instance for this class."""
return DefaultParamsReader(cls)
@inherit_doc
-class DefaultParamsReader(MLReader):
+class DefaultParamsReader(MLReader[RL]):
"""
Specialization of :py:class:`MLReader` for :py:class:`Params` types
@@ -494,12 +531,12 @@ class DefaultParamsReader(MLReader):
.. versionadded:: 2.3.0
"""
- def __init__(self, cls):
+ def __init__(self, cls: Type[DefaultParamsReadable[RL]]):
super(DefaultParamsReader, self).__init__()
self.cls = cls
@staticmethod
- def __get_class(clazz):
+ def __get_class(clazz: str) -> Type[RL]:
"""
Loads Python class from its name.
"""
@@ -510,16 +547,16 @@ class DefaultParamsReader(MLReader):
m = getattr(m, comp)
return m
- def load(self, path):
+ def load(self, path: str) -> RL:
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
- py_type = DefaultParamsReader.__get_class(metadata["class"])
+ py_type: Type[RL] = DefaultParamsReader.__get_class(metadata["class"])
instance = py_type()
- instance._resetUid(metadata["uid"])
+ cast("Params", instance)._resetUid(metadata["uid"])
DefaultParamsReader.getAndSetParams(instance, metadata)
return instance
@staticmethod
- def loadMetadata(path, sc, expectedClassName=""):
+ def loadMetadata(path: str, sc: SparkContext, expectedClassName: str = "") -> Dict[str, Any]:
"""
Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`
@@ -536,7 +573,7 @@ class DefaultParamsReader(MLReader):
return loadedVals
@staticmethod
- def _parseMetaData(metadataStr, expectedClassName=""):
+ def _parseMetaData(metadataStr: str, expectedClassName: str = "") -> Dict[str, Any]:
"""
Parse metadata JSON string produced by :py:meth`DefaultParamsWriter._get_metadata_to_save`.
This is a helper function for :py:meth:`DefaultParamsReader.loadMetadata`.
@@ -558,16 +595,18 @@ class DefaultParamsReader(MLReader):
return metadata
@staticmethod
- def getAndSetParams(instance, metadata, skipParams=None):
+ def getAndSetParams(
+ instance: RL, metadata: Dict[str, Any], skipParams: Optional[List[str]] = None
+ ) -> None:
"""
Extract Params from metadata, and set them in the instance.
"""
# Set user-supplied param values
for paramName in metadata["paramMap"]:
- param = instance.getParam(paramName)
+ param = cast("Params", instance).getParam(paramName)
if skipParams is None or paramName not in skipParams:
paramValue = metadata["paramMap"][paramName]
- instance.set(param, paramValue)
+ cast("Params", instance).set(param, paramValue)
# Set default param values
majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata["sparkVersion"])
@@ -582,14 +621,14 @@ class DefaultParamsReader(MLReader):
for paramName in metadata["defaultParamMap"]:
paramValue = metadata["defaultParamMap"][paramName]
- instance._setDefault(**{paramName: paramValue})
+ cast("Params", instance)._setDefault(**{paramName: paramValue})
@staticmethod
- def isPythonParamsInstance(metadata):
+ def isPythonParamsInstance(metadata: Dict[str, Any]) -> bool:
return metadata["class"].startswith("pyspark.ml.")
@staticmethod
- def loadParamsInstance(path, sc):
+ def loadParamsInstance(path: str, sc: SparkContext) -> RL:
"""
Load a :py:class:`Params` instance from the given path, and return it.
This assumes the instance inherits from :py:class:`MLReadable`.
@@ -599,41 +638,41 @@ class DefaultParamsReader(MLReader):
pythonClassName = metadata["class"]
else:
pythonClassName = metadata["class"].replace("org.apache.spark", "pyspark")
- py_type = DefaultParamsReader.__get_class(pythonClassName)
+ py_type: Type[RL] = DefaultParamsReader.__get_class(pythonClassName)
instance = py_type.load(path)
return instance
@inherit_doc
-class HasTrainingSummary:
+class HasTrainingSummary(Generic[T]):
"""
Base class for models that provides Training summary.
.. versionadded:: 3.0.0
"""
- @property
+ @property # type: ignore[misc]
@since("2.1.0")
- def hasSummary(self):
+ def hasSummary(self) -> bool:
"""
Indicates whether a training summary exists for this model
instance.
"""
- return self._call_java("hasSummary")
+ return cast("JavaWrapper", self)._call_java("hasSummary")
- @property
+ @property # type: ignore[misc]
@since("2.1.0")
- def summary(self):
+ def summary(self) -> T:
"""
Gets summary of the model trained on the training set. An exception is thrown if
no summary exists.
"""
- return self._call_java("summary")
+ return cast("JavaWrapper", self)._call_java("summary")
class MetaAlgorithmReadWrite:
@staticmethod
- def isMetaEstimator(pyInstance):
+ def isMetaEstimator(pyInstance: Any) -> bool:
from pyspark.ml import Estimator, Pipeline
from pyspark.ml.tuning import _ValidatorParams
from pyspark.ml.classification import OneVsRest
@@ -645,13 +684,15 @@ class MetaAlgorithmReadWrite:
)
@staticmethod
- def getAllNestedStages(pyInstance):
+ def getAllNestedStages(pyInstance: Any) -> List["PipelineStage"]:
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.tuning import _ValidatorParams
from pyspark.ml.classification import OneVsRest, OneVsRestModel
# TODO: We need to handle `RFormulaModel.pipelineModel` here after Pyspark RFormulaModel
# support pipelineModel property.
+ pySubStages: List["PipelineStage"]
+
if isinstance(pyInstance, Pipeline):
pySubStages = pyInstance.getStages()
elif isinstance(pyInstance, PipelineModel):
@@ -661,7 +702,9 @@ class MetaAlgorithmReadWrite:
elif isinstance(pyInstance, OneVsRest):
pySubStages = [pyInstance.getClassifier()]
elif isinstance(pyInstance, OneVsRestModel):
- pySubStages = [pyInstance.getClassifier()] + pyInstance.models
+ pySubStages = [
+ pyInstance.getClassifier()
+ ] + pyInstance.models # type: ignore[assignment, operator]
else:
pySubStages = []
@@ -672,7 +715,7 @@ class MetaAlgorithmReadWrite:
return [pyInstance] + nestedStages
@staticmethod
- def getUidMap(instance):
+ def getUidMap(instance: Any) -> Dict[str, "PipelineStage"]:
nestedStages = MetaAlgorithmReadWrite.getAllNestedStages(instance)
uidMap = {stage.uid: stage for stage in nestedStages}
if len(nestedStages) != len(uidMap):
diff --git a/python/pyspark/ml/util.pyi b/python/pyspark/ml/util.pyi
deleted file mode 100644
index db28c09..0000000
--- a/python/pyspark/ml/util.pyi
+++ /dev/null
@@ -1,136 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union
-
-from pyspark import SparkContext as SparkContext, since as since # noqa: F401
-from pyspark.ml.common import inherit_doc as inherit_doc # noqa: F401
-from pyspark.sql import SparkSession as SparkSession
-from pyspark.util import VersionUtils as VersionUtils # noqa: F401
-
-S = TypeVar("S")
-R = TypeVar("R", bound=MLReadable)
-
-class Identifiable:
- uid: str
- def __init__(self) -> None: ...
-
-class BaseReadWrite:
- def __init__(self) -> None: ...
- def session(self, sparkSession: SparkSession) -> Union[MLWriter, MLReader]: ...
- @property
- def sparkSession(self) -> SparkSession: ...
- @property
- def sc(self) -> SparkContext: ...
-
-class MLWriter(BaseReadWrite):
- shouldOverwrite: bool = ...
- def __init__(self) -> None: ...
- def save(self, path: str) -> None: ...
- def saveImpl(self, path: str) -> None: ...
- def overwrite(self) -> MLWriter: ...
-
-class GeneralMLWriter(MLWriter):
- source: str
- def format(self, source: str) -> MLWriter: ...
-
-class JavaMLWriter(MLWriter):
- def __init__(self, instance: JavaMLWritable) -> None: ...
- def save(self, path: str) -> None: ...
- def overwrite(self) -> JavaMLWriter: ...
- def option(self, key: str, value: Any) -> JavaMLWriter: ...
- def session(self, sparkSession: SparkSession) -> JavaMLWriter: ...
-
-class GeneralJavaMLWriter(JavaMLWriter):
- def __init__(self, instance: MLWritable) -> None: ...
- def format(self, source: str) -> GeneralJavaMLWriter: ...
-
-class MLWritable:
- def write(self) -> MLWriter: ...
- def save(self, path: str) -> None: ...
-
-class JavaMLWritable(MLWritable):
- def write(self) -> JavaMLWriter: ...
-
-class GeneralJavaMLWritable(JavaMLWritable):
- def write(self) -> GeneralJavaMLWriter: ...
-
-class MLReader(BaseReadWrite, Generic[R]):
- def load(self, path: str) -> R: ...
-
-class JavaMLReader(MLReader[R]):
- def __init__(self, clazz: Type[JavaMLReadable]) -> None: ...
- def load(self, path: str) -> R: ...
- def session(self, sparkSession: SparkSession) -> JavaMLReader[R]: ...
-
-class MLReadable(Generic[R]):
- @classmethod
- def read(cls: Type[R]) -> MLReader[R]: ...
- @classmethod
- def load(cls: Type[R], path: str) -> R: ...
-
-class JavaMLReadable(MLReadable[R]):
- @classmethod
- def read(cls: Type[R]) -> JavaMLReader[R]: ...
-
-class DefaultParamsWritable(MLWritable):
- def write(self) -> MLWriter: ...
-
-class DefaultParamsWriter(MLWriter):
- instance: DefaultParamsWritable
- def __init__(self, instance: DefaultParamsWritable) -> None: ...
- def saveImpl(self, path: str) -> None: ...
- @staticmethod
- def saveMetadata(
- instance: DefaultParamsWritable,
- path: str,
- sc: SparkContext,
- extraMetadata: Optional[Dict[str, Any]] = ...,
- paramMap: Optional[Dict[str, Any]] = ...,
- ) -> None: ...
-
-class DefaultParamsReadable(MLReadable[R]):
- @classmethod
- def read(cls: Type[R]) -> MLReader[R]: ...
-
-class DefaultParamsReader(MLReader[R]):
- cls: Type[R]
- def __init__(self, cls: Type[MLReadable]) -> None: ...
- def load(self, path: str) -> R: ...
- @staticmethod
- def loadMetadata(
- path: str, sc: SparkContext, expectedClassName: str = ...
- ) -> Dict[str, Any]: ...
- @staticmethod
- def getAndSetParams(instance: R, metadata: Dict[str, Any]) -> None: ...
- @staticmethod
- def loadParamsInstance(path: str, sc: SparkContext) -> R: ...
-
-class HasTrainingSummary(Generic[S]):
- @property
- def hasSummary(self) -> bool: ...
- @property
- def summary(self) -> S: ...
-
-class MetaAlgorithmReadWrite:
- @staticmethod
- def isMetaEstimator(pyInstance: Any) -> bool: ...
- @staticmethod
- def getAllNestedStages(pyInstance: Any) -> list: ...
- @staticmethod
- def getUidMap(instance: Any) -> dict: ...
diff --git a/python/pyspark/ml/wrapper.pyi b/python/pyspark/ml/wrapper.pyi
index a238436..7b3bfb4 100644
--- a/python/pyspark/ml/wrapper.pyi
+++ b/python/pyspark/ml/wrapper.pyi
@@ -28,6 +28,7 @@ from pyspark.sql.dataframe import DataFrame
class JavaWrapper:
def __init__(self, java_obj: Optional[Any] = ...) -> None: ...
def __del__(self) -> None: ...
+ def _call_java(self, name: str, *args: Any) -> Any: ...
class JavaParams(JavaWrapper, Params, metaclass=abc.ABCMeta):
def copy(self: P, extra: Optional[ParamMap] = ...) -> P: ...
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]