This is an automated email from the ASF dual-hosted git repository.
cloud-fan pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 12538d4c98a3 [SPARK-56917][TEST][CONNECT] Expand Connect-specific tests for DataFrame column resolution
12538d4c98a3 is described below
commit 12538d4c98a34c0c9151acb2245bb7ee92011fa7
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sun May 31 20:41:37 2026 +0800
[SPARK-56917][TEST][CONNECT] Expand Connect-specific tests for DataFrame column resolution
### What changes were proposed in this pull request?
This PR widens test coverage of how a tagged DataFrame column reference
(`df.col` / `df["col"]`, which carries the source DataFrame's plan id) resolves
after a range of operators, pinning the behavior across Spark Classic and both
modes of `spark.sql.analyzer.strictDataFrameColumnResolution` on Spark Connect.
The new tests are added to the shared `ColumnTestsMixin` in
`python/pyspark/sql/tests/test_column.py`, so each one runs in three
environments:
- `ColumnTests` - Classic (default `strictDataFrameColumnResolution=true`),
- `ColumnParityTests` - Connect strict,
- `ColumnParityTestsWithNonStrictDFColResolution` - Connect lenient
(extends the strict suite).
The base mixin asserts the behavior shared by all three. The few cases
where Connect diverges are overridden in the Connect parity suites
(`python/pyspark/sql/tests/connect/test_parity_column.py`) rather than
duplicated, so the shared cases stay mode-agnostic. This follows the base-suite
+ per-mode-override structure suggested in review.
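The base-suite + per-mode-override layout can be sketched with plain `unittest`. This is a minimal, hypothetical illustration: `FakeEngine`, `ResolutionTestsMixin`, `StrictTests` and `LenientTests` are invented stand-ins for a Spark session and the real suites, and the toy `resolve_shadowed` rule only mimics the shadowing divergence described here.

```python
import unittest


class FakeEngine:
    """Hypothetical stand-in for a Spark session (not a PySpark API)."""

    def __init__(self, lenient):
        self.lenient = lenient

    def resolve_shadowed(self):
        # Toy rule: only lenient mode falls back to a name-based lookup.
        if self.lenient:
            return 1
        raise LookupError("CANNOT_RESOLVE_DATAFRAME_COLUMN")


class ResolutionTestsMixin:
    """Shared expectations; concrete suites provide self.engine."""

    def test_shadowed(self):
        # Base expectation, shared by every mode that does not override it.
        with self.assertRaises(LookupError):
            self.engine.resolve_shadowed()


class StrictTests(ResolutionTestsMixin, unittest.TestCase):
    def setUp(self):
        self.engine = FakeEngine(lenient=False)


class LenientTests(StrictTests):
    """Extends the strict suite, as the lenient parity suite does."""

    def setUp(self):
        self.engine = FakeEngine(lenient=True)

    def test_shadowed(self):
        # Override only the diverging case instead of duplicating the suite.
        self.assertEqual(self.engine.resolve_shadowed(), 1)
```

Because the mixin is not itself a `TestCase`, test loaders collect only the two concrete suites, and the lenient suite inherits every case it does not override.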
**Per-operator resolution tests** (added to `ColumnTestsMixin`):
- Pass-through (`filter`, `sort`, `distinct`) - all modes resolve.
- Attribute-id propagation (`groupBy().count()`, `pivot`, `intersect`,
temp-view roundtrip) - all modes resolve; the source attribute id flows through.
- Removal (`withColumnRenamed`, `drop`) - all modes raise (column gone by
id and by name).
- Shadowing (`withColumn` chain, `select` + alias, `agg` + alias) - Classic
and Connect strict raise; **Connect lenient** resolves the shadowed name via
name-based fallback (overridden in the lenient suite).
- Union - Classic resolves the left-side reference (Union keeps the left
child's attribute ids); **Connect** raises `CANNOT_RESOLVE_DATAFRAME_COLUMN` in
both modes, because Union is treated as a leaf during plan-id resolution
(overridden in the strict suite, inherited by lenient).
- Self-join - the alias form is ambiguous and raises in all modes; the
documented `withColumnRenamed` form resolves in all modes (the renamed side is
filtered out during disambiguation).
- Cross-DataFrame illegal reference (`df1.select(df2.col)`) - raises in all
modes; the throw is not gated by the strict/lenient switch.
- `df["*"]` star expansion and sort-missing-attribute recovery - resolve in
all modes.
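To make the Connect-side rules above more concrete, here is a toy model of plan-id resolution in plain Python. It is a sketch under stated assumptions, not Spark's `ColumnResolutionHelper`: `Node`, `find_tagged`, `resolve` and `outcome` are hypothetical, plans are reduced to `(name, attribute id)` outputs, and Classic's behavior (such as resolving the left side of a Union) is not modeled.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class Node:
    """Hypothetical plan node whose output is a list of (name, attribute id)."""

    kind: str
    output: List[Tuple[str, int]]
    children: List["Node"] = field(default_factory=list)
    plan_id: Optional[int] = None


class ResolutionError(Exception):
    pass


def find_tagged(node: Node, plan_id: int) -> Optional[Node]:
    """Depth-first search for the node tagged with plan_id.

    Union is treated as a leaf: its children are never searched.
    """
    if node.plan_id == plan_id:
        return node
    if node.kind == "Union":
        return None
    for child in node.children:
        hit = find_tagged(child, plan_id)
        if hit is not None:
            return hit
    return None


def resolve(root: Node, plan_id: int, name: str, lenient: bool) -> int:
    """Attribute id that the tagged reference resolves to in root's output."""
    tagged = find_tagged(root, plan_id)
    if tagged is None:
        # Plan id not found: raise before any name-based fallback, in both modes.
        raise ResolutionError("CANNOT_RESOLVE_DATAFRAME_COLUMN")
    wanted = {attr for col, attr in tagged.output if col == name}
    for _, attr in root.output:
        if attr in wanted:
            return attr  # the source attribute id flowed through
    current = dict(root.output)
    if lenient and name in current:
        return current[name]  # lenient: fall back to the current name
    raise ResolutionError(f"cannot resolve {name!r}")


def outcome(root: Node, plan_id: int, name: str, lenient: bool):
    """The resolved attribute id, or the string "raise"."""
    try:
        return resolve(root, plan_id, name, lenient)
    except ResolutionError:
        return "raise"


# df = SELECT 1 AS c (plan id 1, attribute id 10); df2 = SELECT 2 AS c.
df = Node("Project", [("c", 10)], plan_id=1)
df2 = Node("Project", [("c", 20)], plan_id=2)
filtered = Node("Filter", [("c", 10)], [df])  # pass-through
shadowed = Node("Project", [("c", 11)], [df])  # withColumn("c", ...)
renamed = Node("Project", [("c2", 12)], [df])  # withColumnRenamed("c", "c2")
union = Node("Union", [("c", 10)], [df, df2])  # keeps the left child's ids
intersect = Node("Intersect", [("c", 10)], [df, df2])  # keeps left ids too
```

Under these assumptions, a plan id that cannot be found (below a Union, or from an unrelated DataFrame) raises in both modes, while a found plan id whose attribute has left the output raises in strict mode and falls back to a name lookup in lenient mode; that fallback still fails after a rename or drop because no column named `c` remains.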
**Mixed-surface layered programs** (3): each chains 4-5 transformations -
semi-joins (DataFrame-API EXISTS/IN), window functions, cube aggregation,
NTILE, UDFs, struct-field access - and then references the outermost layer's
columns in both `filter` and `select`. These exercise plan-id propagation
across interacting analyzer rules, which single-operator tests miss. All modes
resolve.
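The expected results of two of these programs can be cross-checked by hand. The sketch below recomputes them in plain Python from the same rows used in `test_layered_semijoin_groupby_window` and `test_layered_struct_semijoin_cube_ntile`. It only mirrors the intended semantics (it does not run Spark), and it covers just the asserted values: the ranks, and the cube groups that survive the non-null filter.

```python
from collections import defaultdict
from itertools import product

# (id, user_id, category, amount, quantity, is_active), as in the first test.
EVENTS = [
    (1, 1, "Books", 100.0, 2, True),
    (2, 1, "Books", 50.0, 3, True),
    (3, 2, "Electronics", 200.0, 1, True),
    (4, 2, "Electronics", 300.0, 2, True),
    (5, 3, "Home", 80.0, 4, True),
    (6, 4, "Books", 60.0, 1, False),
]
USERS = [(1, 25), (2, 30), (3, 22), (4, 18)]  # (id, age)

# (category, status, detail_name, amount) left after the second test's
# quantity > 1 filter and semi-join (every event survives both).
BASE = [
    ("Books", "A", "alpha", 100),
    ("Electronics", "B", "beta", 200),
    ("Books", "A", "alpha", 50),
    ("Electronics", "B", "beta", 300),
    ("Home", "C", "gamma", 80),
]


def layered_ranks():
    # Layer 1: is_active filter + left-semi join on users older than 20.
    adult_ids = {uid for uid, age in USERS if age > 20}
    active = [e for e in EVENTS if e[5] and e[1] in adult_ids]
    # Layer 2: per-category sums, then the HAVING-style amount_sum > 50.
    total_amt = defaultdict(float)
    amount_sum = defaultdict(float)
    for _, _, cat, amount, qty, _ in active:
        total_amt[cat] += amount * qty * 0.1
        amount_sum[cat] += amount
    totals = {cat: t for cat, t in total_amt.items() if amount_sum[cat] > 50}
    # Layer 3: rank() over total_amt descending; this data has no ties,
    # so the rank is simply the 1-based position.
    ordered = sorted(totals, key=totals.get, reverse=True)
    ranks = {cat: i + 1 for i, cat in enumerate(ordered)}
    # Layer 4 and the outermost filter: rank_num <= 5, then rank_num <= 3.
    return sorted((cat, r) for cat, r in ranks.items() if r <= 3)


def cube_groups_kept():
    # cube(category, status, detail_name) aggregates over every subset of
    # the keys; None marks a rolled-up key.
    totals = defaultdict(int)
    for row in BASE:
        for mask in product([True, False], repeat=3):
            key = tuple(v if keep else None for v, keep in zip(row[:3], mask))
            totals[key] += row[3]
    # where(category IS NOT NULL AND status IS NOT NULL)
    return {k: v for k, v in totals.items() if k[0] is not None and k[1] is not None}
```

Here `layered_ranks()` yields Books rank 2, Electronics rank 1 and Home rank 3, and `cube_groups_kept()` keeps six groups (a full group and a `detail_name` subtotal per category) with totals 150, 500 and 80.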
### Why are the changes needed?
apache/spark#55531 added the `strictDataFrameColumnResolution` config and a
single shadowing test. This PR enumerates the Connect-vs-Classic resolution
behavior across shadowing variants, attribute-id-propagating operators, set
operations, self-joins, and multi-operator layered pipelines. A future
tightening of Connect's column resolution will then surface as a clear test
failure rather than a silent regression.
### Does this PR introduce _any_ user-facing change?
No. Test-only change.
### How was this patch tested?
New tests, run locally across all three suites:
```
python/run-tests --testnames "pyspark.sql.tests.test_column ColumnTests"
python/run-tests --testnames "pyspark.sql.tests.connect.test_parity_column ColumnParityTests"
python/run-tests --testnames "pyspark.sql.tests.connect.test_parity_column ColumnParityTestsWithNonStrictDFColResolution"
```
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code (Anthropic), claude-opus-4-7
Closes #55947 from zhengruifeng/SC-229895-connect-col-tests.
Lead-authored-by: Ruifeng Zheng <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 37fcee4fbb5817af4c7abcff3aa6491e60df228b)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/tests/connect/test_parity_column.py | 36 ++
python/pyspark/sql/tests/test_column.py | 386 ++++++++++++++++++++-
2 files changed, 421 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/tests/connect/test_parity_column.py b/python/pyspark/sql/tests/connect/test_parity_column.py
index 3903bb57a375..a2b00d7955ee 100644
--- a/python/pyspark/sql/tests/connect/test_parity_column.py
+++ b/python/pyspark/sql/tests/connect/test_parity_column.py
@@ -17,6 +17,8 @@
import unittest
+from pyspark.errors import AnalysisException
+from pyspark.sql import functions as sf
from pyspark.sql.tests.test_column import ColumnTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
@@ -38,6 +40,16 @@ class ColumnParityTests(ColumnTestsMixin, ReusedConnectTestCase):
def test_validate_column_types(self):
super().test_validate_column_types()
+ def test_resolve_after_union(self):
+ # Connect diverges from Classic here: Union is treated as a leaf when
+ # walking the plan tree for plan-id resolution, so the left-side plan
+ # id is never found and CANNOT_RESOLVE_DATAFRAME_COLUMN is thrown
+ # before any name-based fallback - in both strict and lenient modes.
+ df1 = self.spark.sql("SELECT 1 AS c")
+ df2 = self.spark.sql("SELECT 2 AS c")
+ with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"):
+ df1.union(df2).select(df1.c).collect()
+
def test_df_col_resolution_mode(self):
self.assertEqual(
self.spark.conf.get("spark.sql.analyzer.strictDataFrameColumnResolution"),
@@ -68,6 +80,30 @@ class ColumnParityTestsWithNonStrictDFColResolution(ColumnParityTests):
"false",
)
+ # The shadowing trio diverges in lenient mode: where Classic and Connect
+ # strict raise, lenient resolves the tagged reference by name against the
+ # current (shadowed) output.
+
+ def test_resolve_after_chained_withcolumn_shadow(self):
+ df = self.spark.sql("SELECT 1 AS c")
+ rows = (
+ df.withColumn("c", sf.col("c").cast("string"))
+ .withColumn("c", sf.col("c").cast("int"))
+ .select(df.c)
+ .collect()
+ )
+ self.assertEqual([r.c for r in rows], [1])
+
+ def test_resolve_after_select_alias_shadow(self):
+ df = self.spark.sql("SELECT 1 AS c")
+ rows = df.select(df.c.cast("string").alias("c")).select(df.c).collect()
+ self.assertEqual([r.c for r in rows], ["1"])
+
+ def test_resolve_after_agg_alias_shadow(self):
+ df = self.spark.sql("SELECT 1 AS c")
+ rows = df.groupBy().agg(sf.sum("c").alias("c")).select(df.c).collect()
+ self.assertEqual([r.c for r in rows], [1])
+
if __name__ == "__main__":
from pyspark.testing import main
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
index 74a7746b154d..6a99c7de1a52 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -20,10 +20,12 @@ from enum import Enum
from itertools import chain
import datetime
import unittest
+import uuid
from pyspark.sql import Column, Row
from pyspark.sql import functions as sf
-from pyspark.sql.types import StructType, StructField, IntegerType, LongType
+from pyspark.sql.window import Window
+from pyspark.sql.types import StructType, StructField, IntegerType, LongType, StringType
from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import have_pandas, pandas_requirement_message
@@ -605,6 +607,388 @@ class ColumnTestsMixin:
self.assertEqual(df4.columns, ["colA", "colB", "colC", "colC", "colD", "colE"])
self.assertEqual(df4.count(), 1)
+ # --- Mixed-surface layered DataFrame programs ---------------------------
+ #
+ # These tests chain multiple DataFrame transformations - semi-joins
+ # (for SQL EXISTS/IN), window functions, cube aggregations, UDFs and
+ # struct field access - into 4-5 layer pipelines, then reference the
+ # final layered DataFrame's columns via ``layered.col`` in both filter
+ # and select at the outermost surface. The goal is to catch regressions
+ # in plan-id propagation across analyzer rules that single-operator
+ # tests miss when rules interact.
+
+ def test_layered_semijoin_groupby_window(self):
+ # 4-layer DataFrame pipeline: filter -> semi-join -> groupBy/agg
+ # -> window functions. ``layered.col`` references appear in both
+ # filter and select at the outermost surface.
+ events_data = [
+ (1, 1, "Books", 100.0, 2, True),
+ (2, 1, "Books", 50.0, 3, True),
+ (3, 2, "Electronics", 200.0, 1, True),
+ (4, 2, "Electronics", 300.0, 2, True),
+ (5, 3, "Home", 80.0, 4, True),
+ (6, 4, "Books", 60.0, 1, False),
+ ]
+ users_data = [(1, 25), (2, 30), (3, 22), (4, 18)]
+ events_cols = ["id", "user_id", "category", "amount", "quantity", "is_active"]
+ users_cols = ["id", "age"]
+
+ events = self.spark.createDataFrame(events_data, events_cols)
+ users = self.spark.createDataFrame(users_data, users_cols)
+ # Layer 1: filter + semi-join (DataFrame-API equivalent of
+ # WHERE is_active AND EXISTS (user with age > 20)).
+ active = events.where(events.is_active).join(
+ users.where(users.age > 20),
+ events.user_id == users.id,
+ "left_semi",
+ )
+ # Layer 2: groupBy + agg, then post-agg filter (HAVING equivalent).
+ agg = active.groupBy("category").agg(
+ sf.sum(active.amount * active.quantity * sf.lit(0.1)).alias("total_amt"),
+ sf.sum(active.amount).alias("amount_sum"),
+ )
+ totals = agg.where(agg.amount_sum > 50).select("category", "total_amt")
+ # Layer 3: window functions on top of the aggregate.
+ running = Window.orderBy("total_amt").rowsBetween(-1, 1)
+ ranking = Window.orderBy(totals.total_amt.desc())
+ windowed = totals.select(
+ "category",
+ "total_amt",
+ sf.avg(totals.total_amt).over(running).alias("running_avg"),
+ sf.rank().over(ranking).alias("rank_num"),
+ )
+ # Layer 4: outer filter.
+ layered = windowed.where(windowed.rank_num <= 5)
+
+ rows = (
+ layered.filter(layered.rank_num <= 3)
+ .select(
+ layered.category,
+ layered.total_amt,
+ layered.running_avg,
+ layered.rank_num,
+ )
+ .collect()
+ )
+ result = sorted((r.category, r.rank_num) for r in rows)
+ self.assertEqual(result, [("Books", 2), ("Electronics", 1), ("Home", 3)])
+
+ def test_layered_struct_semijoin_cube_ntile(self):
+ # 5-layer DataFrame pipeline: filter -> semi-join -> struct field
+ # access -> cube aggregation -> window NTILE. ``layered.col``
+ # references appear in both filter and select at the outermost
+ # surface.
+ events_schema = StructType(
+ [
+ StructField("id", IntegerType()),
+ StructField("category", StringType()),
+ StructField("status", StringType()),
+ StructField("amount", IntegerType()),
+ StructField("quantity", IntegerType()),
+ StructField(
+ "detail",
+ StructType(
+ [
+ StructField("name", StringType()),
+ StructField("nested", StructType([StructField("x", IntegerType())])),
+ ]
+ ),
+ ),
+ ]
+ )
+ events_data = [
+ (1, "Books", "A", 100, 5, ("alpha", (1,))),
+ (2, "Electronics", "B", 200, 3, ("beta", (2,))),
+ (3, "Books", "A", 50, 7, ("alpha", (1,))),
+ (4, "Electronics", "B", 300, 4, ("beta", (2,))),
+ (5, "Home", "C", 80, 2, ("gamma", (3,))),
+ ]
+ categories_data = [("Books", 1), ("Electronics", 2), ("Home", 3), ("Toys", 5)]
+ categories_cols = ["name", "priority"]
+
+ events = self.spark.createDataFrame(events_data, events_schema)
+ categories = self.spark.createDataFrame(categories_data, categories_cols)
+ # Layer 1: filter + semi-join (DataFrame-API equivalent of
+ # WHERE quantity > 1 AND category IN (SELECT ...)).
+ filtered = events.where(events.quantity > 1).join(
+ categories.where(categories.priority <= 3),
+ events.category == categories.name,
+ "left_semi",
+ )
+ # Layer 2: project with struct field access (struct subfields use
+ # bracket access since ``detail.name`` would hit ``Column.name``).
+ base = filtered.select(
+ filtered.id,
+ filtered.category,
+ filtered.status,
+ filtered.amount,
+ filtered.detail["name"].alias("detail_name"),
+ filtered.detail["nested"]["x"].alias("nx"),
+ )
+ # Layer 3: cube aggregation (mixed grouping levels - similar
+ # surface area to SQL GROUPING SETS without an exact equivalent
+ # in the DataFrame API).
+ agg = base.cube("category", "status", "detail_name").agg(
+ sf.sum(base.amount).alias("total"), sf.count(sf.lit(1)).alias("cnt")
+ )
+ grouped = agg.where(agg.category.isNotNull() & agg.status.isNotNull())
+ # Layer 4: NTILE window.
+ tiled = grouped.withColumn("tile", sf.ntile(2).over(Window.orderBy(grouped.total.desc())))
+ # Layer 5: outer filter.
+ layered = tiled.where(tiled.tile <= 2)
+
+ rows = (
+ layered.filter(layered.tile >= 1)
+ .select(
+ layered.category,
+ layered.status,
+ layered.detail_name,
+ layered.total,
+ layered.cnt,
+ layered.tile,
+ )
+ .collect()
+ )
+ # The where filter keeps the cube groups where category and status are
+ # both set: one (category, status, detail_name) group per distinct
+ # combination plus one (category, status, NULL) subtotal per pair.
+ self.assertEqual(len(rows), 6)
+ self.assertEqual({r.category for r in rows}, {"Books", "Electronics", "Home"})
+ self.assertEqual({r.total for r in rows}, {80, 150, 500})
+ self.assertEqual({r.tile for r in rows}, {1, 2})
+
+ def test_layered_window_window_udf(self):
+ # 4-layer DataFrame pipeline: filter -> running-total window ->
+ # per-partition max window -> UDF wrap. ``layered.col`` references
+ # appear in both filter and select at the outermost surface.
+ data = [
+ (1, "A", 100),
+ (2, "A", 200),
+ (3, "B", 150),
+ (4, "B", 250),
+ (5, "C", 50),
+ ]
+ cols = ["id", "category", "amount"]
+
+ df = self.spark.createDataFrame(data, cols)
+ # Layer 1: plain filter (stands in for a SQL WHERE EXISTS check on amount > 0).
+ filtered = df.where(df.amount > 0)
+ # Layer 2: running total window.
+ run_w = Window.partitionBy("category").orderBy("id")
+ with_run = filtered.withColumn("run_amt", sf.sum(filtered.amount).over(run_w))
+ # Layer 3: per-category max window (replaces correlated subquery
+ # for cat_max).
+ cat_w = Window.partitionBy("category")
+ with_max = with_run.withColumn("cat_max", sf.max(with_run.amount).over(cat_w))
+ # Layer 4: UDF.
+ double = sf.udf(lambda x: x * 2 if x is not None else None, IntegerType())
+ layered = with_max.withColumn("doubled_amt", double(with_max.amount))
+
+ rows = (
+ layered.filter(layered.amount > 0)
+ .select(
+ layered.id,
+ layered.category,
+ layered.amount,
+ layered.run_amt,
+ layered.cat_max,
+ layered.doubled_amt,
+ )
+ .collect()
+ )
+ result = sorted(
+ (r.id, r.category, r.amount, r.run_amt, r.cat_max, r.doubled_amt) for r in rows
+ )
+ self.assertEqual(
+ result,
+ [
+ (1, "A", 100, 100, 200, 200),
+ (2, "A", 200, 300, 200, 400),
+ (3, "B", 150, 150, 250, 300),
+ (4, "B", 250, 400, 250, 500),
+ (5, "C", 50, 50, 50, 100),
+ ],
+ )
+
+ # --- Tagged DataFrame column resolution --------------------------------
+ #
+ # ``df.col`` / ``df["col"]`` carries the source DataFrame's plan id. These
+ # tests pin how that tagged reference resolves after assorted operators.
+ # The behavior is shared across Spark Classic and Spark Connect (both
+ # ``spark.sql.analyzer.strictDataFrameColumnResolution`` modes) except for
+ # a few diverging cases, which are overridden in the Connect parity suites
+ # (``ColumnParityTests`` / ``...WithNonStrictDFColResolution``):
+ #
+ # * the shadowing trio - Classic and Connect strict raise, Connect
+ # lenient resolves the shadowed name via name-based fallback;
+ # * union - Classic resolves via attribute-id propagation, Connect
+ # raises in both modes.
+
+ def test_resolve_after_chained_withcolumn_shadow(self):
+ # Two consecutive withColumn calls each shadow `c` with a new
+ # attribute of the same name, so the original `c` leaves the
+ # projection and the tagged `df.c` cannot resolve.
+ # Connect lenient diverges: name-based fallback resolves the
+ # shadowed name (overridden in the lenient parity suite).
+ df = self.spark.sql("SELECT 1 AS c")
+ with self.assertRaises(AnalysisException):
+ df.withColumn("c", sf.col("c").cast("string")).withColumn(
+ "c", sf.col("c").cast("int")
+ ).select(df.c).collect()
+
+ def test_resolve_after_select_alias_shadow(self):
+ # Same shadowing shape as withColumn but via select + alias.
+ # Connect lenient diverges: name-based fallback resolves the
+ # shadowed name (overridden in the lenient parity suite).
+ df = self.spark.sql("SELECT 1 AS c")
+ with self.assertRaises(AnalysisException):
+ df.select(df.c.cast("string").alias("c")).select(df.c).collect()
+
+ def test_resolve_after_withcolumnrenamed(self):
+ # withColumnRenamed drops the original `c` attribute and projects it
+ # as `c2`; the tagged `df.c` matches neither the original attribute
+ # nor a current column named `c`, so all modes raise.
+ df = self.spark.sql("SELECT 1 AS c")
+ with self.assertRaises(AnalysisException):
+ df.withColumnRenamed("c", "c2").select(df.c).collect()
+
+ def test_resolve_after_drop(self):
+ # drop("c") removes the column entirely; the tagged `df.c` cannot
+ # resolve under any mode.
+ df = self.spark.sql("SELECT 1 AS c, 2 AS d")
+ with self.assertRaises(AnalysisException):
+ df.drop("c").select(df.c).collect()
+
+ def test_resolve_through_filter(self):
+ # filter is a pass-through operator: the child Project's attributes
+ # flow through unchanged, so the tagged reference resolves.
+ df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c")
+ rows = df.filter(df.c > 0).select(df.c).collect()
+ self.assertEqual(sorted(r.c for r in rows), [1, 2])
+
+ def test_resolve_through_sort(self):
+ # sort is also a pass-through operator.
+ df = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 1 AS c")
+ rows = df.sort(df.c).select(df.c).collect()
+ self.assertEqual([r.c for r in rows], [1, 2])
+
+ def test_resolve_through_distinct(self):
+ # distinct preserves attribute identity for column resolution.
+ df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c")
+ rows = df.distinct().select(df.c).collect()
+ self.assertEqual([r.c for r in rows], [1])
+
+ def test_resolve_after_groupby_count(self):
+ # groupBy("c").count() preserves the grouping key's attribute id, so
+ # the tagged reference resolves.
+ df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c")
+ rows = df.groupBy("c").count().select(df.c).collect()
+ self.assertEqual(sorted(r.c for r in rows), [1, 2])
+
+ def test_resolve_after_agg_alias_shadow(self):
+ # An aggregate output aliased `c` collides by name with the source
+ # `c`, but the tagged `df.c` still references the aggregated-away
+ # source attribute, so it cannot resolve.
+ # Connect lenient diverges: name-based fallback resolves the
+ # aliased name (overridden in the lenient parity suite).
+ df = self.spark.sql("SELECT 1 AS c")
+ with self.assertRaises(AnalysisException):
+ df.groupBy().agg(sf.sum("c").alias("c")).select(df.c).collect()
+
+ def test_resolve_after_pivot(self):
+ # pivot preserves the grouping key's attribute id, so the tagged
+ # reference resolves.
+ df = self.spark.sql(
+ "SELECT 1 AS c, 'a' AS k, 10 AS v UNION ALL SELECT 2 AS c, 'b' AS k, 20 AS v"
+ )
+ rows = df.groupBy("c").pivot("k").sum("v").select(df.c).collect()
+ self.assertEqual(sorted(r.c for r in rows), [1, 2])
+
+ def test_resolve_after_union(self):
+ # Union's output keeps the left child's attribute ids
+ # (Union.mergeChildOutputs), so Classic resolves the tagged
+ # left-side reference directly against that output and succeeds.
+ # Connect resolves by walking the plan tree for the plan id but
+ # treats Union as a leaf (ColumnResolutionHelper), so the id below
+ # the Union is never found and it raises in both modes (overridden
+ # in ColumnParityTests).
+ df1 = self.spark.sql("SELECT 1 AS c")
+ df2 = self.spark.sql("SELECT 2 AS c")
+ rows = df1.union(df2).select(df1.c).collect()
+ self.assertEqual(sorted(r.c for r in rows), [1, 2])
+
+ def test_resolve_after_intersect(self):
+ # Intersect's output also keeps the left child's attribute ids
+ # (Intersect.mergeChildOutputs). Unlike Union, it is not treated as
+ # a leaf during plan-id resolution, so Connect's tree walk descends
+ # into the left child, finds the tagged node and resolves it; all
+ # modes succeed.
+ df1 = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c")
+ df2 = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c")
+ rows = df1.intersect(df2).select(df1.c).collect()
+ self.assertEqual([r.c for r in rows], [2])
+
+ def test_resolve_self_join_alias(self):
+ # Both self-join sides originate from the same plan-id-tagged
+ # ancestor, yielding two equal-depth candidates with the same
+ # attribute id. Disambiguation cannot tiebreak and all modes raise
+ # an ambiguous-reference error.
+ df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c")
+ a, b = df.alias("a"), df.alias("b")
+ with self.assertRaises(AnalysisException):
+ a.join(b, a.c == b.c).select(df.c).collect()
+
+ def test_resolve_after_subquery_view(self):
+ # Registering the DataFrame as a temp view and reading it back via
+ # table() produces a new plan; the tagged reference still resolves in
+ # all modes.
+ view = f"v_{uuid.uuid4().hex}"
+ df = self.spark.sql("SELECT 1 AS c")
+ df.createOrReplaceTempView(view)
+ try:
+ rows = self.spark.table(view).select(df.c).collect()
+ self.assertEqual([r.c for r in rows], [1])
+ finally:
+ self.spark.sql(f"DROP VIEW IF EXISTS {view}")
+
+ def test_resolve_cross_dataframe_illegal_reference(self):
+ # Referencing a column from a DataFrame whose plan id is not an
+ # ancestor of the target plan (`df1.select(df2.id)`) fails in all
+ # modes; the strict / lenient switch does not gate this throw.
+ df1 = self.spark.range(3)
+ df2 = self.spark.range(5)
+ with self.assertRaises(AnalysisException):
+ df1.select(df2.id).collect()
+
+ def test_resolve_df_star(self):
+ # `df["*"]` is an UnresolvedDataFrameStar carrying df's plan id; the
+ # analyzer expands it to the matched node's output in all modes.
+ df = self.spark.sql(
+ "SELECT 'Books' AS c, 100 AS v UNION ALL SELECT 'Electronics' AS c, 200 AS v"
+ )
+ rows = df.select(df["*"]).collect()
+ self.assertEqual(sorted((r.c, r.v) for r in rows), [("Books", 100), ("Electronics", 200)])
+
+ def test_resolve_self_join_withcolumnrenamed(self):
+ # Documented ColumnResolutionHelper case: df1 = range(10) + col `a`;
+ # df2 = df1 renamed `a` -> `b`; df1.join(df2, df1.a == df2.b). The
+ # node with df1's plan id is found on both Join sides; the right
+ # candidate is filtered out because its `a` is not in the renaming
+ # Project's output, so disambiguation succeeds in all modes.
+ df1 = self.spark.range(10).withColumn("a", sf.col("id"))
+ df2 = df1.withColumnRenamed("a", "b")
+ rows = df1.join(df2, df1.a == df2.b).select(df1.a, df2.b).collect()
+ self.assertEqual(len(rows), 10)
+
+ def test_resolve_sort_missing_attr_recovery(self):
+ # Documented ColumnResolutionHelper case: df.select(df.v).sort(df.id)
+ # where df.id is not in the select's output. The analyzer descends
+ # through the Project, resolves df.id via plan id at the source, and
+ # adds it back to the upstream projection. Works in all modes.
+ df = self.spark.range(10).withColumn("v", sf.col("id") + 1)
+ rows = df.select(df.v).sort(df.id).collect()
+ self.assertEqual(len(rows), 10)
+
class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase):
pass