This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 356eada314e [SPARK-42828][PYTHON][SQL] More explicit Python type annotations for GroupedData
356eada314e is described below

commit 356eada314e88a4c0a262c6aa28e76045880e38f
Author: Joe Wang <[email protected]>
AuthorDate: Mon Jul 3 15:37:52 2023 +0900

    [SPARK-42828][PYTHON][SQL] More explicit Python type annotations for GroupedData
    
    ### What changes were proposed in this pull request?
    
    Make the `Callable` type annotations on `dfapi` and `df_varargs_api` explicit: both decorators now declare that they accept and return a callable that returns a `DataFrame`.
    
    ### Why are the changes needed?
    
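    In PySpark 3.3.x, type checkers infer the return type of a call such as `df.groupBy(...).count()` as `Any`, whereas it should be `DataFrame`. A bare `Callable` annotation is treated as `Callable[..., Any]`, so the decorated methods lose their declared return type. This breaks type checking.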
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    This patch introduces no runtime changes, so it relies on the existing CI tests.
    
    Closes #40460 from j03wang/grouped-data-type.
    
    Authored-by: Joe Wang <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/group.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 9568a971229..1b64e7666fd 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
 __all__ = ["GroupedData"]
 
 
-def dfapi(f: Callable) -> Callable:
+def dfapi(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]:
     def _api(self: "GroupedData") -> DataFrame:
         name = f.__name__
         jdf = getattr(self._jgd, name)()
@@ -43,7 +43,7 @@ def dfapi(f: Callable) -> Callable:
     return _api
 
 
-def df_varargs_api(f: Callable) -> Callable:
+def df_varargs_api(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]:
     def _api(self: "GroupedData", *cols: str) -> DataFrame:
         name = f.__name__
         jdf = getattr(self._jgd, name)(_to_seq(self.session._sc, cols))
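
The effect of the annotation change in the diff above can be sketched without Spark. This is a minimal illustration, not the PySpark implementation: `DataFrame` here is a hypothetical stand-in class, `dfapi_before` and `dfapi_after` are made-up names for the old and new signatures, and the wrapper bodies are simplified to return an empty stand-in object. The comments about what a type checker reports describe the expected behaviour of a checker such as mypy and are not produced by running the script.

```python
from typing import Any, Callable, get_args, get_type_hints


class DataFrame:
    """Hypothetical stand-in for pyspark.sql.DataFrame (assumption, not the real class)."""


def dfapi_before(f: Callable) -> Callable:
    # Old signature: a bare Callable means Callable[..., Any], so a type
    # checker no longer knows what the decorated method returns.
    def _api(self: Any) -> DataFrame:
        return DataFrame()

    return _api


def dfapi_after(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]:
    # New signature: the DataFrame return type survives decoration.
    def _api(self: Any) -> DataFrame:
        return DataFrame()

    return _api


class GroupedBefore:
    @dfapi_before
    def count(self) -> DataFrame:
        """Placeholder body; the decorator replaces this method."""


class GroupedAfter:
    @dfapi_after
    def count(self) -> DataFrame:
        """Placeholder body; the decorator replaces this method."""


# Runtime behaviour is identical in both cases.
before_result = GroupedBefore().count()  # a checker is expected to see: Any
after_result = GroupedAfter().count()    # a checker is expected to see: DataFrame

# The difference is visible only in the declared annotations.
before_hint = get_type_hints(dfapi_before)["return"]
after_hint = get_type_hints(dfapi_after)["return"]
print(get_args(before_hint))  # no parameters recorded for the bare Callable
print(get_args(after_hint))   # parameter spec and return type are recorded
```

Both decorated methods return the same kind of object at runtime, which matches the commit's note that only the type annotations changed.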


