This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new faf094a4a21a [SPARK-55904][PYTHON][CONNECT] Utilize _check_same_session to narrow down types
faf094a4a21a is described below
commit faf094a4a21af79c4fd9dcbf3ca70aef2d50b4f3
Author: Tian Gao <[email protected]>
AuthorDate: Tue Mar 10 10:42:21 2026 +0800
[SPARK-55904][PYTHON][CONNECT] Utilize _check_same_session to narrow down types
### What changes were proposed in this pull request?
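* Make `_check_same_session` return the input argument when it is a Spark
Connect `DataFrame` on the same session, and raise `SessionNotSameException`
otherwise (now also when the argument is not a Spark Connect `DataFrame` at all)
* Use the returned value to narrow down types, so that several
`# type: ignore` comments can be removed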
### Why are the changes needed?
When we can narrow down types, we should do so, so that type hints work better
(fewer `# type: ignore` exceptions are needed).
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
mypy passed.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54707 from gaogaotiantian/check-same-session.
Authored-by: Tian Gao <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/dataframe.py | 66 ++++++++++++++-------------------
1 file changed, 28 insertions(+), 38 deletions(-)
diff --git a/python/pyspark/sql/connect/dataframe.py
b/python/pyspark/sql/connect/dataframe.py
index 4c40f9512ce3..f07a14403698 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -315,20 +315,22 @@ class DataFrame(ParentDataFrame):
return table[0][0].as_py()
def crossJoin(self, other: ParentDataFrame) -> ParentDataFrame:
- self._check_same_session(other)
+ other = self._check_same_session(other)
return DataFrame(
- plan.Join(
- left=self._plan, right=other._plan, on=None, how="cross" #
type: ignore[arg-type]
- ),
+ plan.Join(left=self._plan, right=other._plan, on=None,
how="cross"),
session=self._session,
)
- def _check_same_session(self, other: ParentDataFrame) -> None:
- if self._session.session_id != other._session.session_id: # type:
ignore[attr-defined]
+ def _check_same_session(self, other: ParentDataFrame) -> "DataFrame":
+ if (
+ not isinstance(other, DataFrame)
+ or self._session.session_id != other._session.session_id
+ ):
raise SessionNotSameException(
errorClass="SESSION_NOT_SAME",
messageParameters={},
)
+ return other
def coalesce(self, numPartitions: int) -> ParentDataFrame:
if not numPartitions > 0:
@@ -724,11 +726,11 @@ class DataFrame(ParentDataFrame):
on: Optional[Union[str, List[str], Column, List[Column]]] = None,
how: Optional[str] = None,
) -> ParentDataFrame:
- self._check_same_session(other)
+ other = self._check_same_session(other)
if how is not None and isinstance(how, str):
how = how.lower().replace("_", "")
return DataFrame(
- plan.Join(left=self._plan, right=other._plan, on=on, how=how), #
type: ignore[arg-type]
+ plan.Join(left=self._plan, right=other._plan, on=on, how=how),
session=self._session,
)
@@ -738,13 +740,11 @@ class DataFrame(ParentDataFrame):
on: Optional[Column] = None,
how: Optional[str] = None,
) -> ParentDataFrame:
- self._check_same_session(other)
+ other = self._check_same_session(other)
if how is not None and isinstance(how, str):
how = how.lower().replace("_", "")
return DataFrame(
- plan.LateralJoin(
- left=self._plan, right=cast(plan.LogicalPlan, other._plan),
on=on, how=how
- ),
+ plan.LateralJoin(left=self._plan, right=other._plan, on=on,
how=how),
session=self._session,
)
@@ -760,7 +760,7 @@ class DataFrame(ParentDataFrame):
allowExactMatches: bool = True,
direction: str = "backward",
) -> ParentDataFrame:
- self._check_same_session(other)
+ other = self._check_same_session(other)
if how is None:
how = "inner"
assert isinstance(how, str), "how should be a string"
@@ -777,7 +777,7 @@ class DataFrame(ParentDataFrame):
return DataFrame(
plan.AsOfJoin(
left=self._plan,
- right=other._plan, # type: ignore[arg-type]
+ right=other._plan,
left_as_of=_convert_col(self, leftAsOfColumn),
right_as_of=_convert_col(other, rightAsOfColumn),
on=on,
@@ -1159,15 +1159,13 @@ class DataFrame(ParentDataFrame):
return None
def union(self, other: ParentDataFrame) -> ParentDataFrame:
- self._check_same_session(other)
+ other = self._check_same_session(other)
return self.unionAll(other)
def unionAll(self, other: ParentDataFrame) -> ParentDataFrame:
- self._check_same_session(other)
+ other = self._check_same_session(other)
res = DataFrame(
- plan.SetOperation(
- self._plan, other._plan, "union", is_all=True # type:
ignore[arg-type]
- ),
+ plan.SetOperation(self._plan, other._plan, "union", is_all=True),
session=self._session,
)
res._cached_schema = self._merge_cached_schema(other)
@@ -1176,11 +1174,11 @@ class DataFrame(ParentDataFrame):
def unionByName(
self, other: ParentDataFrame, allowMissingColumns: bool = False
) -> ParentDataFrame:
- self._check_same_session(other)
+ other = self._check_same_session(other)
res = DataFrame(
plan.SetOperation(
self._plan,
- other._plan, # type: ignore[arg-type]
+ other._plan,
"union",
by_name=True,
allow_missing_columns=allowMissingColumns,
@@ -1191,22 +1189,18 @@ class DataFrame(ParentDataFrame):
return res
def subtract(self, other: ParentDataFrame) -> ParentDataFrame:
- self._check_same_session(other)
+ other = self._check_same_session(other)
res = DataFrame(
- plan.SetOperation(
- self._plan, other._plan, "except", is_all=False # type:
ignore[arg-type]
- ),
+ plan.SetOperation(self._plan, other._plan, "except", is_all=False),
session=self._session,
)
res._cached_schema = self._merge_cached_schema(other)
return res
def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame:
- self._check_same_session(other)
+ other = self._check_same_session(other)
res = DataFrame(
- plan.SetOperation(
- self._plan, other._plan, "except", is_all=True # type:
ignore[arg-type]
- ),
+ plan.SetOperation(self._plan, other._plan, "except", is_all=True),
session=self._session,
)
res._cached_schema = self._merge_cached_schema(other)
@@ -1218,22 +1212,18 @@ class DataFrame(ParentDataFrame):
)
def intersect(self, other: ParentDataFrame) -> ParentDataFrame:
- self._check_same_session(other)
+ other = self._check_same_session(other)
res = DataFrame(
- plan.SetOperation(
- self._plan, other._plan, "intersect", is_all=False # type:
ignore[arg-type]
- ),
+ plan.SetOperation(self._plan, other._plan, "intersect",
is_all=False),
session=self._session,
)
res._cached_schema = self._merge_cached_schema(other)
return res
def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame:
- self._check_same_session(other)
+ other = self._check_same_session(other)
res = DataFrame(
- plan.SetOperation(
- self._plan, other._plan, "intersect", is_all=True # type:
ignore[arg-type]
- ),
+ plan.SetOperation(self._plan, other._plan, "intersect",
is_all=True),
session=self._session,
)
res._cached_schema = self._merge_cached_schema(other)
@@ -2226,7 +2216,7 @@ class DataFrame(ParentDataFrame):
errorClass="NOT_DATAFRAME",
messageParameters={"arg_name": "other", "arg_type":
type(other).__name__},
)
- self._check_same_session(other)
+ other = self._check_same_session(other)
return self._session.client.same_semantics(
plan=self._plan.to_proto(self._session.client),
other=other._plan.to_proto(other._session.client),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]