This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0084a86  [SPARK-37415][PYTHON][ML] Inline type hints for pyspark.ml.util
0084a86 is described below

commit 0084a8677ad77143b109dff3d3e9be4035d00fd4
Author: zero323 <[email protected]>
AuthorDate: Sun Feb 6 11:09:41 2022 +0100

    [SPARK-37415][PYTHON][ML] Inline type hints for pyspark.ml.util
    
    ### What changes were proposed in this pull request?
    
    This PR migrates the `pyspark.ml.util` type annotations from the stub file to inline type hints.
    
    ### Why are the changes needed?
    
    Part of the ongoing migration of type hints.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #35367 from zero323/SPARK-37415.
    
    Authored-by: zero323 <[email protected]>
    Signed-off-by: zero323 <[email protected]>
---
 python/pyspark/ml/base.py     |   4 +-
 python/pyspark/ml/util.py     | 207 +++++++++++++++++++++++++-----------------
 python/pyspark/ml/util.pyi    | 136 ---------------------------
 python/pyspark/ml/wrapper.pyi |   1 +
 4 files changed, 128 insertions(+), 220 deletions(-)

diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index 4f4ddef..9e8252d 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -104,7 +104,7 @@ class _FitMultipleIterator(Generic[M]):
 
 
 @inherit_doc
-class Estimator(Generic[M], Params, metaclass=ABCMeta):
+class Estimator(Params, Generic[M], metaclass=ABCMeta):
     """
     Abstract class for estimators that fit models to data.
 
@@ -382,7 +382,7 @@ class Predictor(Estimator[M], _PredictorParams, 
metaclass=ABCMeta):
 
 
 @inherit_doc
-class PredictionModel(Generic[T], Transformer, _PredictorParams, 
metaclass=ABCMeta):
+class PredictionModel(Model, _PredictorParams, Generic[T], metaclass=ABCMeta):
     """
     Model for prediction tasks (regression and classification).
     """
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index ac60ded..1dacffc 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -20,13 +20,31 @@ import os
 import time
 import uuid
 
+from typing import Any, Dict, Generic, List, Optional, Sequence, Type, 
TypeVar, cast, TYPE_CHECKING
+
+
 from pyspark import SparkContext, since
 from pyspark.ml.common import inherit_doc
 from pyspark.sql import SparkSession
 from pyspark.util import VersionUtils
 
+if TYPE_CHECKING:
+    from py4j.java_gateway import JavaGateway, JavaObject
+    from pyspark.ml._typing import PipelineStage
+
+    from pyspark.ml.param import Param
+    from pyspark.ml.base import Params
+    from pyspark.ml.wrapper import JavaWrapper
 
-def _jvm():
+T = TypeVar("T")
+RW = TypeVar("RW", bound="BaseReadWrite")
+W = TypeVar("W", bound="MLWriter")
+JW = TypeVar("JW", bound="JavaMLWriter")
+RL = TypeVar("RL", bound="MLReadable")
+JR = TypeVar("JR", bound="JavaMLReader")
+
+
+def _jvm() -> "JavaGateway":
     """
     Returns the JVM view associated with SparkContext. Must be called
     after SparkContext is initialized.
@@ -43,15 +61,15 @@ class Identifiable:
     Object with a unique ID.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         #: A unique id for the object.
         self.uid = self._randomUID()
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return self.uid
 
     @classmethod
-    def _randomUID(cls):
+    def _randomUID(cls) -> str:
         """
         Generate a unique string id for the object. The default implementation
         concatenates the class name, "_", and 12 random hex chars.
@@ -68,10 +86,10 @@ class BaseReadWrite:
     .. versionadded:: 2.3.0
     """
 
-    def __init__(self):
-        self._sparkSession = None
+    def __init__(self) -> None:
+        self._sparkSession: Optional[SparkSession] = None
 
-    def session(self, sparkSession):
+    def session(self: RW, sparkSession: SparkSession) -> RW:
         """
         Sets the Spark Session to use for saving/loading.
         """
@@ -79,19 +97,21 @@ class BaseReadWrite:
         return self
 
     @property
-    def sparkSession(self):
+    def sparkSession(self) -> SparkSession:
         """
         Returns the user-specified Spark Session or the default.
         """
         if self._sparkSession is None:
             self._sparkSession = SparkSession._getActiveSessionOrCreate()
+        assert self._sparkSession is not None
         return self._sparkSession
 
     @property
-    def sc(self):
+    def sc(self) -> SparkContext:
         """
         Returns the underlying `SparkContext`.
         """
+        assert self.sparkSession is not None
         return self.sparkSession.sparkContext
 
 
@@ -103,37 +123,41 @@ class MLWriter(BaseReadWrite):
     .. versionadded:: 2.0.0
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         super(MLWriter, self).__init__()
-        self.shouldOverwrite = False
-        self.optionMap = {}
+        self.shouldOverwrite: bool = False
+        self.optionMap: Dict[str, Any] = {}
 
-    def _handleOverwrite(self, path):
+    def _handleOverwrite(self, path: str) -> None:
         from pyspark.ml.wrapper import JavaWrapper
 
-        _java_obj = 
JavaWrapper._new_java_obj("org.apache.spark.ml.util.FileSystemOverwrite")
+        _java_obj = JavaWrapper._new_java_obj(  # type: ignore[attr-defined]
+            "org.apache.spark.ml.util.FileSystemOverwrite"
+        )
         wrapper = JavaWrapper(_java_obj)
-        wrapper._call_java("handleOverwrite", path, True, 
self.sparkSession._jsparkSession)
+        wrapper._call_java(  # type: ignore[attr-defined]
+            "handleOverwrite", path, True, self.sparkSession._jsparkSession
+        )
 
-    def save(self, path):
+    def save(self, path: str) -> None:
         """Save the ML instance to the input path."""
         if self.shouldOverwrite:
             self._handleOverwrite(path)
         self.saveImpl(path)
 
-    def saveImpl(self, path):
+    def saveImpl(self, path: str) -> None:
         """
         save() handles overwriting and then calls this method.  Subclasses 
should override this
         method to implement the actual saving of the instance.
         """
         raise NotImplementedError("MLWriter is not yet implemented for type: 
%s" % type(self))
 
-    def overwrite(self):
+    def overwrite(self) -> "MLWriter":
         """Overwrites if the output path already exists."""
         self.shouldOverwrite = True
         return self
 
-    def option(self, key, value):
+    def option(self, key: str, value: Any) -> "MLWriter":
         """
         Adds an option to the underlying MLWriter. See the documentation for 
the specific model's
         writer for possible options. The option name (key) is case-insensitive.
@@ -150,7 +174,7 @@ class GeneralMLWriter(MLWriter):
     .. versionadded:: 2.4.0
     """
 
-    def format(self, source):
+    def format(self, source: str) -> "GeneralMLWriter":
         """
         Specifies the format of ML export ("pmml", "internal", or the fully 
qualified class
         name for export).
@@ -165,27 +189,29 @@ class JavaMLWriter(MLWriter):
     (Private) Specialization of :py:class:`MLWriter` for 
:py:class:`JavaParams` types
     """
 
-    def __init__(self, instance):
+    _jwrite: "JavaObject"
+
+    def __init__(self, instance: "JavaMLWritable"):
         super(JavaMLWriter, self).__init__()
-        _java_obj = instance._to_java()
+        _java_obj = instance._to_java()  # type: ignore[attr-defined]
         self._jwrite = _java_obj.write()
 
-    def save(self, path):
+    def save(self, path: str) -> None:
         """Save the ML instance to the input path."""
         if not isinstance(path, str):
             raise TypeError("path should be a string, got type %s" % 
type(path))
         self._jwrite.save(path)
 
-    def overwrite(self):
+    def overwrite(self) -> "JavaMLWriter":
         """Overwrites if the output path already exists."""
         self._jwrite.overwrite()
         return self
 
-    def option(self, key, value):
+    def option(self, key: str, value: str) -> "JavaMLWriter":
         self._jwrite.option(key, value)
         return self
 
-    def session(self, sparkSession):
+    def session(self, sparkSession: SparkSession) -> "JavaMLWriter":
         """Sets the Spark Session to use for saving."""
         self._jwrite.session(sparkSession._jsparkSession)
         return self
@@ -197,10 +223,10 @@ class GeneralJavaMLWriter(JavaMLWriter):
     (Private) Specialization of :py:class:`GeneralMLWriter` for 
:py:class:`JavaParams` types
     """
 
-    def __init__(self, instance):
+    def __init__(self, instance: "JavaMLWritable"):
         super(GeneralJavaMLWriter, self).__init__(instance)
 
-    def format(self, source):
+    def format(self, source: str) -> "GeneralJavaMLWriter":
         """
         Specifies the format of ML export ("pmml", "internal", or the fully 
qualified class
         name for export).
@@ -217,11 +243,11 @@ class MLWritable:
     .. versionadded:: 2.0.0
     """
 
-    def write(self):
+    def write(self) -> MLWriter:
         """Returns an MLWriter instance for this ML instance."""
         raise NotImplementedError("MLWritable is not yet implemented for type: 
%r" % type(self))
 
-    def save(self, path):
+    def save(self, path: str) -> None:
         """Save this ML instance to the given path, a shortcut of 
'write().save(path)'."""
         self.write().save(path)
 
@@ -232,7 +258,7 @@ class JavaMLWritable(MLWritable):
     (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
     """
 
-    def write(self):
+    def write(self) -> JavaMLWriter:
         """Returns an MLWriter instance for this ML instance."""
         return JavaMLWriter(self)
 
@@ -243,39 +269,39 @@ class GeneralJavaMLWritable(JavaMLWritable):
     (Private) Mixin for ML instances that provide 
:py:class:`GeneralJavaMLWriter`.
     """
 
-    def write(self):
+    def write(self) -> GeneralJavaMLWriter:
         """Returns an GeneralMLWriter instance for this ML instance."""
         return GeneralJavaMLWriter(self)
 
 
 @inherit_doc
-class MLReader(BaseReadWrite):
+class MLReader(BaseReadWrite, Generic[RL]):
     """
     Utility class that can load ML instances.
 
     .. versionadded:: 2.0.0
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         super(MLReader, self).__init__()
 
-    def load(self, path):
+    def load(self, path: str) -> RL:
         """Load the ML instance from the input path."""
         raise NotImplementedError("MLReader is not yet implemented for type: 
%s" % type(self))
 
 
 @inherit_doc
-class JavaMLReader(MLReader):
+class JavaMLReader(MLReader[RL]):
     """
     (Private) Specialization of :py:class:`MLReader` for 
:py:class:`JavaParams` types
     """
 
-    def __init__(self, clazz):
+    def __init__(self, clazz: Type["JavaMLReadable[RL]"]) -> None:
         super(JavaMLReader, self).__init__()
         self._clazz = clazz
         self._jread = self._load_java_obj(clazz).read()
 
-    def load(self, path):
+    def load(self, path: str) -> RL:
         """Load the ML instance from the input path."""
         if not isinstance(path, str):
             raise TypeError("path should be a string, got type %s" % 
type(path))
@@ -284,15 +310,15 @@ class JavaMLReader(MLReader):
             raise NotImplementedError(
                 "This Java ML type cannot be loaded into Python currently: %r" 
% self._clazz
             )
-        return self._clazz._from_java(java_obj)
+        return self._clazz._from_java(java_obj)  # type: ignore[attr-defined]
 
-    def session(self, sparkSession):
+    def session(self: JR, sparkSession: SparkSession) -> JR:
         """Sets the Spark Session to use for loading."""
         self._jread.session(sparkSession._jsparkSession)
         return self
 
     @classmethod
-    def _java_loader_class(cls, clazz):
+    def _java_loader_class(cls, clazz: Type["JavaMLReadable[RL]"]) -> str:
         """
         Returns the full class name of the Java ML instance. The default
         implementation replaces "pyspark" by "org.apache.spark" in
@@ -305,7 +331,7 @@ class JavaMLReader(MLReader):
         return java_package + "." + clazz.__name__
 
     @classmethod
-    def _load_java_obj(cls, clazz):
+    def _load_java_obj(cls, clazz: Type["JavaMLReadable[RL]"]) -> "JavaObject":
         """Load the peer Java object of the ML instance."""
         java_class = cls._java_loader_class(clazz)
         java_obj = _jvm()
@@ -315,7 +341,7 @@ class JavaMLReader(MLReader):
 
 
 @inherit_doc
-class MLReadable:
+class MLReadable(Generic[RL]):
     """
     Mixin for instances that provide :py:class:`MLReader`.
 
@@ -323,24 +349,24 @@ class MLReadable:
     """
 
     @classmethod
-    def read(cls):
+    def read(cls) -> MLReader[RL]:
         """Returns an MLReader instance for this class."""
         raise NotImplementedError("MLReadable.read() not implemented for type: 
%r" % cls)
 
     @classmethod
-    def load(cls, path):
+    def load(cls, path: str) -> RL:
         """Reads an ML instance from the input path, a shortcut of 
`read().load(path)`."""
         return cls.read().load(path)
 
 
 @inherit_doc
-class JavaMLReadable(MLReadable):
+class JavaMLReadable(MLReadable[RL]):
     """
     (Private) Mixin for instances that provide JavaMLReader.
     """
 
     @classmethod
-    def read(cls):
+    def read(cls) -> JavaMLReader[RL]:
         """Returns an MLReader instance for this class."""
         return JavaMLReader(cls)
 
@@ -358,7 +384,7 @@ class DefaultParamsWritable(MLWritable):
     .. versionadded:: 2.3.0
     """
 
-    def write(self):
+    def write(self) -> MLWriter:
         """Returns a DefaultParamsWriter instance for this class."""
         from pyspark.ml.param import Params
 
@@ -382,15 +408,15 @@ class DefaultParamsWriter(MLWriter):
     .. versionadded:: 2.3.0
     """
 
-    def __init__(self, instance):
+    def __init__(self, instance: "Params"):
         super(DefaultParamsWriter, self).__init__()
         self.instance = instance
 
-    def saveImpl(self, path):
+    def saveImpl(self, path: str) -> None:
         DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)
 
     @staticmethod
-    def extractJsonParams(instance, skipParams):
+    def extractJsonParams(instance: "Params", skipParams: Sequence[str]) -> 
Dict[str, Any]:
         paramMap = instance.extractParamMap()
         jsonParams = {
             param.name: value for param, value in paramMap.items() if 
param.name not in skipParams
@@ -398,7 +424,13 @@ class DefaultParamsWriter(MLWriter):
         return jsonParams
 
     @staticmethod
-    def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
+    def saveMetadata(
+        instance: "Params",
+        path: str,
+        sc: SparkContext,
+        extraMetadata: Optional[Dict[str, Any]] = None,
+        paramMap: Optional[Dict[str, "Param"]] = None,
+    ) -> None:
         """
         Saves metadata + Params to: path + "/metadata"
 
@@ -424,7 +456,12 @@ class DefaultParamsWriter(MLWriter):
         sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath)
 
     @staticmethod
-    def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
+    def _get_metadata_to_save(
+        instance: "Params",
+        sc: SparkContext,
+        extraMetadata: Optional[Dict[str, Any]] = None,
+        paramMap: Optional[Dict[str, "Param"]] = None,
+    ) -> str:
         """
         Helper for :py:meth:`DefaultParamsWriter.saveMetadata` which extracts 
the JSON to save.
         This is useful for ensemble models which need to save metadata for 
many sub-models.
@@ -460,11 +497,11 @@ class DefaultParamsWriter(MLWriter):
         }
         if extraMetadata is not None:
             basicMetadata.update(extraMetadata)
-        return json.dumps(basicMetadata, separators=[",", ":"])
+        return json.dumps(basicMetadata, separators=(",", ":"))
 
 
 @inherit_doc
-class DefaultParamsReadable(MLReadable):
+class DefaultParamsReadable(MLReadable[RL]):
     """
     Helper trait for making simple :py:class:`Params` types readable.
     If a :py:class:`Params` class stores all data as :py:class:`Param` values,
@@ -477,13 +514,13 @@ class DefaultParamsReadable(MLReadable):
     """
 
     @classmethod
-    def read(cls):
+    def read(cls) -> "DefaultParamsReader[RL]":
         """Returns a DefaultParamsReader instance for this class."""
         return DefaultParamsReader(cls)
 
 
 @inherit_doc
-class DefaultParamsReader(MLReader):
+class DefaultParamsReader(MLReader[RL]):
     """
     Specialization of :py:class:`MLReader` for :py:class:`Params` types
 
@@ -494,12 +531,12 @@ class DefaultParamsReader(MLReader):
     .. versionadded:: 2.3.0
     """
 
-    def __init__(self, cls):
+    def __init__(self, cls: Type[DefaultParamsReadable[RL]]):
         super(DefaultParamsReader, self).__init__()
         self.cls = cls
 
     @staticmethod
-    def __get_class(clazz):
+    def __get_class(clazz: str) -> Type[RL]:
         """
         Loads Python class from its name.
         """
@@ -510,16 +547,16 @@ class DefaultParamsReader(MLReader):
             m = getattr(m, comp)
         return m
 
-    def load(self, path):
+    def load(self, path: str) -> RL:
         metadata = DefaultParamsReader.loadMetadata(path, self.sc)
-        py_type = DefaultParamsReader.__get_class(metadata["class"])
+        py_type: Type[RL] = DefaultParamsReader.__get_class(metadata["class"])
         instance = py_type()
-        instance._resetUid(metadata["uid"])
+        cast("Params", instance)._resetUid(metadata["uid"])
         DefaultParamsReader.getAndSetParams(instance, metadata)
         return instance
 
     @staticmethod
-    def loadMetadata(path, sc, expectedClassName=""):
+    def loadMetadata(path: str, sc: SparkContext, expectedClassName: str = "") 
-> Dict[str, Any]:
         """
         Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`
 
@@ -536,7 +573,7 @@ class DefaultParamsReader(MLReader):
         return loadedVals
 
     @staticmethod
-    def _parseMetaData(metadataStr, expectedClassName=""):
+    def _parseMetaData(metadataStr: str, expectedClassName: str = "") -> 
Dict[str, Any]:
         """
         Parse metadata JSON string produced by 
:py:meth`DefaultParamsWriter._get_metadata_to_save`.
         This is a helper function for 
:py:meth:`DefaultParamsReader.loadMetadata`.
@@ -558,16 +595,18 @@ class DefaultParamsReader(MLReader):
         return metadata
 
     @staticmethod
-    def getAndSetParams(instance, metadata, skipParams=None):
+    def getAndSetParams(
+        instance: RL, metadata: Dict[str, Any], skipParams: 
Optional[List[str]] = None
+    ) -> None:
         """
         Extract Params from metadata, and set them in the instance.
         """
         # Set user-supplied param values
         for paramName in metadata["paramMap"]:
-            param = instance.getParam(paramName)
+            param = cast("Params", instance).getParam(paramName)
             if skipParams is None or paramName not in skipParams:
                 paramValue = metadata["paramMap"][paramName]
-                instance.set(param, paramValue)
+                cast("Params", instance).set(param, paramValue)
 
         # Set default param values
         majorAndMinorVersions = 
VersionUtils.majorMinorVersion(metadata["sparkVersion"])
@@ -582,14 +621,14 @@ class DefaultParamsReader(MLReader):
 
             for paramName in metadata["defaultParamMap"]:
                 paramValue = metadata["defaultParamMap"][paramName]
-                instance._setDefault(**{paramName: paramValue})
+                cast("Params", instance)._setDefault(**{paramName: paramValue})
 
     @staticmethod
-    def isPythonParamsInstance(metadata):
+    def isPythonParamsInstance(metadata: Dict[str, Any]) -> bool:
         return metadata["class"].startswith("pyspark.ml.")
 
     @staticmethod
-    def loadParamsInstance(path, sc):
+    def loadParamsInstance(path: str, sc: SparkContext) -> RL:
         """
         Load a :py:class:`Params` instance from the given path, and return it.
         This assumes the instance inherits from :py:class:`MLReadable`.
@@ -599,41 +638,41 @@ class DefaultParamsReader(MLReader):
             pythonClassName = metadata["class"]
         else:
             pythonClassName = metadata["class"].replace("org.apache.spark", 
"pyspark")
-        py_type = DefaultParamsReader.__get_class(pythonClassName)
+        py_type: Type[RL] = DefaultParamsReader.__get_class(pythonClassName)
         instance = py_type.load(path)
         return instance
 
 
 @inherit_doc
-class HasTrainingSummary:
+class HasTrainingSummary(Generic[T]):
     """
     Base class for models that provides Training summary.
 
     .. versionadded:: 3.0.0
     """
 
-    @property
+    @property  # type: ignore[misc]
     @since("2.1.0")
-    def hasSummary(self):
+    def hasSummary(self) -> bool:
         """
         Indicates whether a training summary exists for this model
         instance.
         """
-        return self._call_java("hasSummary")
+        return cast("JavaWrapper", self)._call_java("hasSummary")
 
-    @property
+    @property  # type: ignore[misc]
     @since("2.1.0")
-    def summary(self):
+    def summary(self) -> T:
         """
         Gets summary of the model trained on the training set. An exception is 
thrown if
         no summary exists.
         """
-        return self._call_java("summary")
+        return cast("JavaWrapper", self)._call_java("summary")
 
 
 class MetaAlgorithmReadWrite:
     @staticmethod
-    def isMetaEstimator(pyInstance):
+    def isMetaEstimator(pyInstance: Any) -> bool:
         from pyspark.ml import Estimator, Pipeline
         from pyspark.ml.tuning import _ValidatorParams
         from pyspark.ml.classification import OneVsRest
@@ -645,13 +684,15 @@ class MetaAlgorithmReadWrite:
         )
 
     @staticmethod
-    def getAllNestedStages(pyInstance):
+    def getAllNestedStages(pyInstance: Any) -> List["PipelineStage"]:
         from pyspark.ml import Pipeline, PipelineModel
         from pyspark.ml.tuning import _ValidatorParams
         from pyspark.ml.classification import OneVsRest, OneVsRestModel
 
         # TODO: We need to handle `RFormulaModel.pipelineModel` here after 
Pyspark RFormulaModel
         #  support pipelineModel property.
+        pySubStages: List["PipelineStage"]
+
         if isinstance(pyInstance, Pipeline):
             pySubStages = pyInstance.getStages()
         elif isinstance(pyInstance, PipelineModel):
@@ -661,7 +702,9 @@ class MetaAlgorithmReadWrite:
         elif isinstance(pyInstance, OneVsRest):
             pySubStages = [pyInstance.getClassifier()]
         elif isinstance(pyInstance, OneVsRestModel):
-            pySubStages = [pyInstance.getClassifier()] + pyInstance.models
+            pySubStages = [
+                pyInstance.getClassifier()
+            ] + pyInstance.models  # type: ignore[assignment, operator]
         else:
             pySubStages = []
 
@@ -672,7 +715,7 @@ class MetaAlgorithmReadWrite:
         return [pyInstance] + nestedStages
 
     @staticmethod
-    def getUidMap(instance):
+    def getUidMap(instance: Any) -> Dict[str, "PipelineStage"]:
         nestedStages = MetaAlgorithmReadWrite.getAllNestedStages(instance)
         uidMap = {stage.uid: stage for stage in nestedStages}
         if len(nestedStages) != len(uidMap):
diff --git a/python/pyspark/ml/util.pyi b/python/pyspark/ml/util.pyi
deleted file mode 100644
index db28c09..0000000
--- a/python/pyspark/ml/util.pyi
+++ /dev/null
@@ -1,136 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union
-
-from pyspark import SparkContext as SparkContext, since as since  # noqa: F401
-from pyspark.ml.common import inherit_doc as inherit_doc  # noqa: F401
-from pyspark.sql import SparkSession as SparkSession
-from pyspark.util import VersionUtils as VersionUtils  # noqa: F401
-
-S = TypeVar("S")
-R = TypeVar("R", bound=MLReadable)
-
-class Identifiable:
-    uid: str
-    def __init__(self) -> None: ...
-
-class BaseReadWrite:
-    def __init__(self) -> None: ...
-    def session(self, sparkSession: SparkSession) -> Union[MLWriter, 
MLReader]: ...
-    @property
-    def sparkSession(self) -> SparkSession: ...
-    @property
-    def sc(self) -> SparkContext: ...
-
-class MLWriter(BaseReadWrite):
-    shouldOverwrite: bool = ...
-    def __init__(self) -> None: ...
-    def save(self, path: str) -> None: ...
-    def saveImpl(self, path: str) -> None: ...
-    def overwrite(self) -> MLWriter: ...
-
-class GeneralMLWriter(MLWriter):
-    source: str
-    def format(self, source: str) -> MLWriter: ...
-
-class JavaMLWriter(MLWriter):
-    def __init__(self, instance: JavaMLWritable) -> None: ...
-    def save(self, path: str) -> None: ...
-    def overwrite(self) -> JavaMLWriter: ...
-    def option(self, key: str, value: Any) -> JavaMLWriter: ...
-    def session(self, sparkSession: SparkSession) -> JavaMLWriter: ...
-
-class GeneralJavaMLWriter(JavaMLWriter):
-    def __init__(self, instance: MLWritable) -> None: ...
-    def format(self, source: str) -> GeneralJavaMLWriter: ...
-
-class MLWritable:
-    def write(self) -> MLWriter: ...
-    def save(self, path: str) -> None: ...
-
-class JavaMLWritable(MLWritable):
-    def write(self) -> JavaMLWriter: ...
-
-class GeneralJavaMLWritable(JavaMLWritable):
-    def write(self) -> GeneralJavaMLWriter: ...
-
-class MLReader(BaseReadWrite, Generic[R]):
-    def load(self, path: str) -> R: ...
-
-class JavaMLReader(MLReader[R]):
-    def __init__(self, clazz: Type[JavaMLReadable]) -> None: ...
-    def load(self, path: str) -> R: ...
-    def session(self, sparkSession: SparkSession) -> JavaMLReader[R]: ...
-
-class MLReadable(Generic[R]):
-    @classmethod
-    def read(cls: Type[R]) -> MLReader[R]: ...
-    @classmethod
-    def load(cls: Type[R], path: str) -> R: ...
-
-class JavaMLReadable(MLReadable[R]):
-    @classmethod
-    def read(cls: Type[R]) -> JavaMLReader[R]: ...
-
-class DefaultParamsWritable(MLWritable):
-    def write(self) -> MLWriter: ...
-
-class DefaultParamsWriter(MLWriter):
-    instance: DefaultParamsWritable
-    def __init__(self, instance: DefaultParamsWritable) -> None: ...
-    def saveImpl(self, path: str) -> None: ...
-    @staticmethod
-    def saveMetadata(
-        instance: DefaultParamsWritable,
-        path: str,
-        sc: SparkContext,
-        extraMetadata: Optional[Dict[str, Any]] = ...,
-        paramMap: Optional[Dict[str, Any]] = ...,
-    ) -> None: ...
-
-class DefaultParamsReadable(MLReadable[R]):
-    @classmethod
-    def read(cls: Type[R]) -> MLReader[R]: ...
-
-class DefaultParamsReader(MLReader[R]):
-    cls: Type[R]
-    def __init__(self, cls: Type[MLReadable]) -> None: ...
-    def load(self, path: str) -> R: ...
-    @staticmethod
-    def loadMetadata(
-        path: str, sc: SparkContext, expectedClassName: str = ...
-    ) -> Dict[str, Any]: ...
-    @staticmethod
-    def getAndSetParams(instance: R, metadata: Dict[str, Any]) -> None: ...
-    @staticmethod
-    def loadParamsInstance(path: str, sc: SparkContext) -> R: ...
-
-class HasTrainingSummary(Generic[S]):
-    @property
-    def hasSummary(self) -> bool: ...
-    @property
-    def summary(self) -> S: ...
-
-class MetaAlgorithmReadWrite:
-    @staticmethod
-    def isMetaEstimator(pyInstance: Any) -> bool: ...
-    @staticmethod
-    def getAllNestedStages(pyInstance: Any) -> list: ...
-    @staticmethod
-    def getUidMap(instance: Any) -> dict: ...
diff --git a/python/pyspark/ml/wrapper.pyi b/python/pyspark/ml/wrapper.pyi
index a238436..7b3bfb4 100644
--- a/python/pyspark/ml/wrapper.pyi
+++ b/python/pyspark/ml/wrapper.pyi
@@ -28,6 +28,7 @@ from pyspark.sql.dataframe import DataFrame
 class JavaWrapper:
     def __init__(self, java_obj: Optional[Any] = ...) -> None: ...
     def __del__(self) -> None: ...
+    def _call_java(self, name: str, *args: Any) -> Any: ...
 
 class JavaParams(JavaWrapper, Params, metaclass=abc.ABCMeta):
     def copy(self: P, extra: Optional[ParamMap] = ...) -> P: ...

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to