This is an automated email from the ASF dual-hosted git repository.

cloud-fan pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new 12538d4c98a3 [SPARK-56917][TEST][CONNECT] Expand Connect-specific tests for DataFrame column resolution
12538d4c98a3 is described below

commit 12538d4c98a34c0c9151acb2245bb7ee92011fa7
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sun May 31 20:41:37 2026 +0800

    [SPARK-56917][TEST][CONNECT] Expand Connect-specific tests for DataFrame column resolution
    
    ### What changes were proposed in this pull request?
    
    This PR widens test coverage of how a tagged DataFrame column reference (`df.col` / `df["col"]`, which carries the source DataFrame's plan id) resolves after a range of operators, pinning the behavior across Spark Classic and both modes of `spark.sql.analyzer.strictDataFrameColumnResolution` on Spark Connect.
    
    The new tests are added to the shared `ColumnTestsMixin` in `python/pyspark/sql/tests/test_column.py`, so each one runs in three environments:
    
    - `ColumnTests` - Classic (default `strictDataFrameColumnResolution=true`),
    - `ColumnParityTests` - Connect strict,
    - `ColumnParityTestsWithNonStrictDFColResolution` - Connect lenient (extends the strict suite).
    
    The base mixin asserts the behavior shared by all three. The few cases where Connect diverges are overridden in the Connect parity suites (`python/pyspark/sql/tests/connect/test_parity_column.py`) rather than duplicated, so the shared cases stay mode-agnostic. This follows the base-suite + per-mode-override structure suggested in review.
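    
    The base-suite + per-mode-override layout can be sketched without Spark. The snippet below is a minimal illustration, not the PySpark test code: `ResolutionTestsMixin`, `resolve_shadowed`, the `mode` attribute and `run_all` are hypothetical stand-ins, and the "engine" is faked so the example runs with the standard library alone.

```python
import unittest


class ResolutionTestsMixin:
    """Shared expectations; `mode` is a hypothetical stand-in for the engine."""

    mode = "classic"

    def resolve_shadowed(self):
        # Fake engine call: only the lenient mode falls back to the current
        # column of the same name; every other mode raises.
        if self.mode == "connect-lenient":
            return "resolved-by-name"
        raise LookupError("cannot resolve tagged column")

    def test_shadow(self):
        # Default expectation, shared by Classic and Connect strict.
        with self.assertRaises(LookupError):
            self.resolve_shadowed()

    def test_passthrough(self):
        # Mode-agnostic case: written once, inherited by every suite.
        self.assertIn(self.mode, {"classic", "connect-strict", "connect-lenient"})


class ClassicTests(ResolutionTestsMixin, unittest.TestCase):
    pass


class ConnectStrictTests(ResolutionTestsMixin, unittest.TestCase):
    mode = "connect-strict"


class ConnectLenientTests(ConnectStrictTests):
    mode = "connect-lenient"

    def test_shadow(self):
        # Only the diverging case is overridden; nothing is duplicated.
        self.assertEqual(self.resolve_shadowed(), "resolved-by-name")


def run_all():
    """Run the three suites and return the aggregated unittest result."""
    loader = unittest.TestLoader()
    suite = unittest.TestSuite(
        loader.loadTestsFromTestCase(case)
        for case in (ClassicTests, ConnectStrictTests, ConnectLenientTests)
    )
    result = unittest.TestResult()
    suite.run(result)
    return result
```

    Because the lenient suite extends the strict one, an override placed in the strict suite is inherited by lenient unless lenient overrides it again.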
    
    **Per-operator resolution tests** (added to `ColumnTestsMixin`):
    
    - Pass-through (`filter`, `sort`, `distinct`) - all modes resolve.
    - Attribute-id propagation (`groupBy().count()`, `pivot`, `intersect`, temp-view roundtrip) - all modes resolve; the source attribute id flows through.
    - Removal (`withColumnRenamed`, `drop`) - all modes raise (column gone by id and by name).
    - Shadowing (`withColumn` chain, `select` + alias, `agg` + alias) - Classic and Connect strict raise; **Connect lenient** resolves the shadowed name via name-based fallback (overridden in the lenient suite).
    - Union - Classic resolves the left-side reference (Union keeps the left child's attribute ids); **Connect** raises `CANNOT_RESOLVE_DATAFRAME_COLUMN` in both modes, because Union is treated as a leaf during plan-id resolution (overridden in the strict suite, inherited by lenient).
    - Self-join - the alias form is ambiguous and raises in all modes; the documented `withColumnRenamed` form resolves in all modes (the renamed side is filtered out during disambiguation).
    - Cross-DataFrame illegal reference (`df1.select(df2.col)`) - raises in all modes; the throw is not gated by the strict/lenient switch.
    - `df["*"]` star expansion and sort-missing-attribute recovery - resolve in all modes.
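    
    The pass-through, shadowing, union and cross-DataFrame rows of the list above can be reproduced with a toy model. This is a simplified sketch of the behavior as described in this message, not Spark's analyzer: `Attr`, `Plan`, the `opaque` flag, `find`, `resolve_connect` and `resolve_classic` are all hypothetical names, and self-join disambiguation is deliberately left out.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class Attr:
    name: str
    attr_id: int


@dataclass(frozen=True)
class Plan:
    plan_id: int
    output: Tuple[Attr, ...]
    children: Tuple["Plan", ...] = ()
    opaque: bool = False  # toy flag: the plan-id walk does not descend (Union)


class CannotResolve(Exception):
    pass


def find(plan: Plan, plan_id: int) -> Optional[Plan]:
    """Walk the tree for the node tagged with `plan_id`."""
    if plan.plan_id == plan_id:
        return plan
    if plan.opaque:
        return None
    for child in plan.children:
        hit = find(child, plan_id)
        if hit is not None:
            return hit
    return None


def resolve_connect(top: Plan, plan_id: int, name: str, strict: bool) -> Attr:
    """Toy Connect-style resolution of a tagged reference against `top`."""
    source = find(top, plan_id)
    if source is None:
        # Raised before any name-based fallback, in both modes.
        raise CannotResolve("CANNOT_RESOLVE_DATAFRAME_COLUMN")
    tagged = next((a for a in source.output if a.name == name), None)
    if tagged is not None and tagged in top.output:
        return tagged
    if not strict:
        # Lenient mode: fall back to the current output by name.
        same_name = [a for a in top.output if a.name == name]
        if len(same_name) == 1:
            return same_name[0]
    raise CannotResolve(name)


def resolve_classic(top: Plan, tagged: Attr) -> Attr:
    """Toy Classic-style resolution: match by attribute id in the output."""
    if tagged in top.output:
        return tagged
    raise CannotResolve(tagged.name)


c1, c2 = Attr("c", 1), Attr("c", 2)
df = Plan(1, (c1,))
filtered = Plan(2, (c1,), (df,))        # pass-through keeps c#1
shadowed = Plan(3, (c2,), (df,))        # same name, new attribute id
other = Plan(4, (Attr("c", 9),))
union = Plan(5, (c1,), (df, other), opaque=True)  # keeps left attribute ids
```

    In this model `resolve_connect(shadowed, 1, "c", strict=False)` returns the shadowing attribute, while the same call with `strict=True` raises; `resolve_connect(union, 1, "c", ...)` raises in both modes even though `resolve_classic(union, c1)` succeeds.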
    
    **Mixed-surface layered programs** (3): each chains 4-5 transformations - semi-joins (DataFrame-API EXISTS/IN), window functions, cube aggregation, NTILE, UDFs, struct-field access - and then references the outermost layer's columns in both `filter` and `select`. These exercise plan-id propagation across interacting analyzer rules, which single-operator tests miss. All modes resolve.
    
    ### Why are the changes needed?
    
    apache/spark#55531 added the `strictDataFrameColumnResolution` config and a single shadowing test. This PR enumerates the Connect-vs-Classic resolution behavior across shadowing variants, attribute-id-propagating operators, set operations, self-joins, and multi-operator layered pipelines. A future tightening of Connect's column resolution will then surface as a clear test failure rather than a silent regression.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. Test-only change.
    
    ### How was this patch tested?
    
    New tests, run locally across all three suites:
    
    ```
    python/run-tests --testnames "pyspark.sql.tests.test_column ColumnTests"
    python/run-tests --testnames "pyspark.sql.tests.connect.test_parity_column ColumnParityTests"
    python/run-tests --testnames "pyspark.sql.tests.connect.test_parity_column ColumnParityTestsWithNonStrictDFColResolution"
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Code (Anthropic), claude-opus-4-7
    
    Closes #55947 from zhengruifeng/SC-229895-connect-col-tests.
    
    Lead-authored-by: Ruifeng Zheng <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit 37fcee4fbb5817af4c7abcff3aa6491e60df228b)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/tests/connect/test_parity_column.py        |  36 ++
 python/pyspark/sql/tests/test_column.py            | 386 ++++++++++++++++++++-
 2 files changed, 421 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/connect/test_parity_column.py b/python/pyspark/sql/tests/connect/test_parity_column.py
index 3903bb57a375..a2b00d7955ee 100644
--- a/python/pyspark/sql/tests/connect/test_parity_column.py
+++ b/python/pyspark/sql/tests/connect/test_parity_column.py
@@ -17,6 +17,8 @@
 
 import unittest
 
+from pyspark.errors import AnalysisException
+from pyspark.sql import functions as sf
 from pyspark.sql.tests.test_column import ColumnTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
@@ -38,6 +40,16 @@ class ColumnParityTests(ColumnTestsMixin, ReusedConnectTestCase):
     def test_validate_column_types(self):
         super().test_validate_column_types()
 
+    def test_resolve_after_union(self):
+        # Connect diverges from Classic here: Union is treated as a leaf when
+        # walking the plan tree for plan-id resolution, so the left-side plan
+        # id is never found and CANNOT_RESOLVE_DATAFRAME_COLUMN is thrown
+        # before any name-based fallback - in both strict and lenient modes.
+        df1 = self.spark.sql("SELECT 1 AS c")
+        df2 = self.spark.sql("SELECT 2 AS c")
+        with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"):
+            df1.union(df2).select(df1.c).collect()
+
     def test_df_col_resolution_mode(self):
         self.assertEqual(
             self.spark.conf.get("spark.sql.analyzer.strictDataFrameColumnResolution"),
@@ -68,6 +80,30 @@ class ColumnParityTestsWithNonStrictDFColResolution(ColumnParityTests):
             "false",
         )
 
+    # The shadowing trio diverges in lenient mode: where Classic and Connect
+    # strict raise, lenient resolves the tagged reference by name against the
+    # current (shadowed) output.
+
+    def test_resolve_after_chained_withcolumn_shadow(self):
+        df = self.spark.sql("SELECT 1 AS c")
+        rows = (
+            df.withColumn("c", sf.col("c").cast("string"))
+            .withColumn("c", sf.col("c").cast("int"))
+            .select(df.c)
+            .collect()
+        )
+        self.assertEqual([r.c for r in rows], [1])
+
+    def test_resolve_after_select_alias_shadow(self):
+        df = self.spark.sql("SELECT 1 AS c")
+        rows = df.select(df.c.cast("string").alias("c")).select(df.c).collect()
+        self.assertEqual([r.c for r in rows], ["1"])
+
+    def test_resolve_after_agg_alias_shadow(self):
+        df = self.spark.sql("SELECT 1 AS c")
+        rows = df.groupBy().agg(sf.sum("c").alias("c")).select(df.c).collect()
+        self.assertEqual([r.c for r in rows], [1])
+
 
 if __name__ == "__main__":
     from pyspark.testing import main
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
index 74a7746b154d..6a99c7de1a52 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -20,10 +20,12 @@ from enum import Enum
 from itertools import chain
 import datetime
 import unittest
+import uuid
 
 from pyspark.sql import Column, Row
 from pyspark.sql import functions as sf
-from pyspark.sql.types import StructType, StructField, IntegerType, LongType
+from pyspark.sql.window import Window
+from pyspark.sql.types import StructType, StructField, IntegerType, LongType, StringType
 from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 from pyspark.testing.utils import have_pandas, pandas_requirement_message
@@ -605,6 +607,388 @@ class ColumnTestsMixin:
         self.assertEqual(df4.columns, ["colA", "colB", "colC", "colC", "colD", "colE"])
         self.assertEqual(df4.count(), 1)
 
+    # --- Mixed-surface layered DataFrame programs ---------------------------
+    #
+    # These tests chain multiple DataFrame transformations - semi-joins
+    # (for SQL EXISTS/IN), window functions, cube aggregations, UDFs and
+    # struct field access - into 4-5 layer pipelines, then reference the
+    # final layered DataFrame's columns via ``layered.col`` in both filter
+    # and select at the outermost surface. The goal is to catch regressions
+    # in plan-id propagation across analyzer rules that single-operator
+    # tests miss when rules interact.
+
+    def test_layered_semijoin_groupby_window(self):
+        # 4-layer DataFrame pipeline: filter -> semi-join -> groupBy/agg
+        # -> window functions. ``layered.col`` references appear in both
+        # filter and select at the outermost surface.
+        events_data = [
+            (1, 1, "Books", 100.0, 2, True),
+            (2, 1, "Books", 50.0, 3, True),
+            (3, 2, "Electronics", 200.0, 1, True),
+            (4, 2, "Electronics", 300.0, 2, True),
+            (5, 3, "Home", 80.0, 4, True),
+            (6, 4, "Books", 60.0, 1, False),
+        ]
+        users_data = [(1, 25), (2, 30), (3, 22), (4, 18)]
+        events_cols = ["id", "user_id", "category", "amount", "quantity", "is_active"]
+        users_cols = ["id", "age"]
+
+        events = self.spark.createDataFrame(events_data, events_cols)
+        users = self.spark.createDataFrame(users_data, users_cols)
+        # Layer 1: filter + semi-join (DataFrame-API equivalent of
+        # WHERE is_active AND EXISTS (user with age > 20)).
+        active = events.where(events.is_active).join(
+            users.where(users.age > 20),
+            events.user_id == users.id,
+            "left_semi",
+        )
+        # Layer 2: groupBy + agg, then post-agg filter (HAVING equivalent).
+        agg = active.groupBy("category").agg(
+            sf.sum(active.amount * active.quantity * sf.lit(0.1)).alias("total_amt"),
+            sf.sum(active.amount).alias("amount_sum"),
+        )
+        totals = agg.where(agg.amount_sum > 50).select("category", "total_amt")
+        # Layer 3: window functions on top of the aggregate.
+        running = Window.orderBy("total_amt").rowsBetween(-1, 1)
+        ranking = Window.orderBy(totals.total_amt.desc())
+        windowed = totals.select(
+            "category",
+            "total_amt",
+            sf.avg(totals.total_amt).over(running).alias("running_avg"),
+            sf.rank().over(ranking).alias("rank_num"),
+        )
+        # Layer 4: outer filter.
+        layered = windowed.where(windowed.rank_num <= 5)
+
+        rows = (
+            layered.filter(layered.rank_num <= 3)
+            .select(
+                layered.category,
+                layered.total_amt,
+                layered.running_avg,
+                layered.rank_num,
+            )
+            .collect()
+        )
+        result = sorted((r.category, r.rank_num) for r in rows)
+        self.assertEqual(result, [("Books", 2), ("Electronics", 1), ("Home", 3)])
+
+    def test_layered_struct_semijoin_cube_ntile(self):
+        # 5-layer DataFrame pipeline: filter -> semi-join -> struct field
+        # access -> cube aggregation -> window NTILE. ``layered.col``
+        # references appear in both filter and select at the outermost
+        # surface.
+        events_schema = StructType(
+            [
+                StructField("id", IntegerType()),
+                StructField("category", StringType()),
+                StructField("status", StringType()),
+                StructField("amount", IntegerType()),
+                StructField("quantity", IntegerType()),
+                StructField(
+                    "detail",
+                    StructType(
+                        [
+                            StructField("name", StringType()),
+                            StructField("nested", StructType([StructField("x", IntegerType())])),
+                        ]
+                    ),
+                ),
+            ]
+        )
+        events_data = [
+            (1, "Books", "A", 100, 5, ("alpha", (1,))),
+            (2, "Electronics", "B", 200, 3, ("beta", (2,))),
+            (3, "Books", "A", 50, 7, ("alpha", (1,))),
+            (4, "Electronics", "B", 300, 4, ("beta", (2,))),
+            (5, "Home", "C", 80, 2, ("gamma", (3,))),
+        ]
+        categories_data = [("Books", 1), ("Electronics", 2), ("Home", 3), ("Toys", 5)]
+        categories_cols = ["name", "priority"]
+
+        events = self.spark.createDataFrame(events_data, events_schema)
+        categories = self.spark.createDataFrame(categories_data, categories_cols)
+        # Layer 1: filter + semi-join (DataFrame-API equivalent of
+        # WHERE quantity > 1 AND category IN (SELECT ...)).
+        filtered = events.where(events.quantity > 1).join(
+            categories.where(categories.priority <= 3),
+            events.category == categories.name,
+            "left_semi",
+        )
+        # Layer 2: project with struct field access (struct subfields use
+        # bracket access since ``detail.name`` would hit ``Column.name``).
+        base = filtered.select(
+            filtered.id,
+            filtered.category,
+            filtered.status,
+            filtered.amount,
+            filtered.detail["name"].alias("detail_name"),
+            filtered.detail["nested"]["x"].alias("nx"),
+        )
+        # Layer 3: cube aggregation (mixed grouping levels - similar
+        # surface area to SQL GROUPING SETS without an exact equivalent
+        # in the DataFrame API).
+        agg = base.cube("category", "status", "detail_name").agg(
+            sf.sum(base.amount).alias("total"), sf.count(sf.lit(1)).alias("cnt")
+        )
+        grouped = agg.where(agg.category.isNotNull() & agg.status.isNotNull())
+        # Layer 4: NTILE window.
+        tiled = grouped.withColumn("tile", sf.ntile(2).over(Window.orderBy(grouped.total.desc())))
+        # Layer 5: outer filter.
+        layered = tiled.where(tiled.tile <= 2)
+
+        rows = (
+            layered.filter(layered.tile >= 1)
+            .select(
+                layered.category,
+                layered.status,
+                layered.detail_name,
+                layered.total,
+                layered.cnt,
+                layered.tile,
+            )
+            .collect()
+        )
+        # Cube emits one (category, status, detail_name) group per distinct
+        # combination plus one (category, status, NULL) subtotal per distinct
+        # (category, status) pair. The where filter keeps both.
+        self.assertEqual(len(rows), 6)
+        self.assertEqual({r.category for r in rows}, {"Books", "Electronics", "Home"})
+        self.assertEqual({r.total for r in rows}, {80, 150, 500})
+        self.assertEqual({r.tile for r in rows}, {1, 2})
+
+    def test_layered_window_window_udf(self):
+        # 4-layer DataFrame pipeline: filter -> running-total window ->
+        # per-partition max window -> UDF wrap. ``layered.col`` references
+        # appear in both filter and select at the outermost surface.
+        data = [
+            (1, "A", 100),
+            (2, "A", 200),
+            (3, "B", 150),
+            (4, "B", 250),
+            (5, "C", 50),
+        ]
+        cols = ["id", "category", "amount"]
+
+        df = self.spark.createDataFrame(data, cols)
+        # Layer 1: filter (DataFrame-API equivalent of WHERE amount > 0).
+        filtered = df.where(df.amount > 0)
+        # Layer 2: running total window.
+        run_w = Window.partitionBy("category").orderBy("id")
+        with_run = filtered.withColumn("run_amt", sf.sum(filtered.amount).over(run_w))
+        # Layer 3: per-category max window (replaces correlated subquery
+        # for cat_max).
+        cat_w = Window.partitionBy("category")
+        with_max = with_run.withColumn("cat_max", sf.max(with_run.amount).over(cat_w))
+        # Layer 4: UDF.
+        double = sf.udf(lambda x: x * 2 if x is not None else None, IntegerType())
+        layered = with_max.withColumn("doubled_amt", double(with_max.amount))
+
+        rows = (
+            layered.filter(layered.amount > 0)
+            .select(
+                layered.id,
+                layered.category,
+                layered.amount,
+                layered.run_amt,
+                layered.cat_max,
+                layered.doubled_amt,
+            )
+            .collect()
+        )
+        result = sorted(
+            (r.id, r.category, r.amount, r.run_amt, r.cat_max, r.doubled_amt) for r in rows
+        )
+        self.assertEqual(
+            result,
+            [
+                (1, "A", 100, 100, 200, 200),
+                (2, "A", 200, 300, 200, 400),
+                (3, "B", 150, 150, 250, 300),
+                (4, "B", 250, 400, 250, 500),
+                (5, "C", 50, 50, 50, 100),
+            ],
+        )
+
+    # --- Tagged DataFrame column resolution --------------------------------
+    #
+    # ``df.col`` / ``df["col"]`` carries the source DataFrame's plan id. These
+    # tests pin how that tagged reference resolves after assorted operators.
+    # The behavior is shared across Spark Classic and Spark Connect (both
+    # ``spark.sql.analyzer.strictDataFrameColumnResolution`` modes) except for
+    # a few diverging cases, which are overridden in the Connect parity suites
+    # (``ColumnParityTests`` / ``...WithNonStrictDFColResolution``):
+    #
+    #   * the shadowing trio - Classic and Connect strict raise, Connect
+    #     lenient resolves the shadowed name via name-based fallback;
+    #   * union - Classic resolves via attribute-id propagation, Connect
+    #     raises in both modes.
+
+    def test_resolve_after_chained_withcolumn_shadow(self):
+        # Two consecutive withColumn calls each shadow `c` with a new
+        # attribute of the same name, so the original `c` leaves the
+        # projection and the tagged `df.c` cannot resolve.
+        # Connect lenient diverges: name-based fallback resolves the
+        # shadowed name (overridden in the lenient parity suite).
+        df = self.spark.sql("SELECT 1 AS c")
+        with self.assertRaises(AnalysisException):
+            df.withColumn("c", sf.col("c").cast("string")).withColumn(
+                "c", sf.col("c").cast("int")
+            ).select(df.c).collect()
+
+    def test_resolve_after_select_alias_shadow(self):
+        # Same shadowing shape as withColumn but via select + alias.
+        # Connect lenient diverges: name-based fallback resolves the
+        # shadowed name (overridden in the lenient parity suite).
+        df = self.spark.sql("SELECT 1 AS c")
+        with self.assertRaises(AnalysisException):
+            df.select(df.c.cast("string").alias("c")).select(df.c).collect()
+
+    def test_resolve_after_withcolumnrenamed(self):
+        # withColumnRenamed drops the original `c` attribute and projects it
+        # as `c2`; the tagged `df.c` matches neither the original attribute
+        # nor a current column named `c`, so all modes raise.
+        df = self.spark.sql("SELECT 1 AS c")
+        with self.assertRaises(AnalysisException):
+            df.withColumnRenamed("c", "c2").select(df.c).collect()
+
+    def test_resolve_after_drop(self):
+        # drop("c") removes the column entirely; the tagged `df.c` cannot
+        # resolve under any mode.
+        df = self.spark.sql("SELECT 1 AS c, 2 AS d")
+        with self.assertRaises(AnalysisException):
+            df.drop("c").select(df.c).collect()
+
+    def test_resolve_through_filter(self):
+        # filter is a pass-through operator: the child Project's attributes
+        # flow through unchanged, so the tagged reference resolves.
+        df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c")
+        rows = df.filter(df.c > 0).select(df.c).collect()
+        self.assertEqual(sorted(r.c for r in rows), [1, 2])
+
+    def test_resolve_through_sort(self):
+        # sort is also a pass-through operator.
+        df = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 1 AS c")
+        rows = df.sort(df.c).select(df.c).collect()
+        self.assertEqual([r.c for r in rows], [1, 2])
+
+    def test_resolve_through_distinct(self):
+        # distinct preserves attribute identity for column resolution.
+        df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c")
+        rows = df.distinct().select(df.c).collect()
+        self.assertEqual([r.c for r in rows], [1])
+
+    def test_resolve_after_groupby_count(self):
+        # groupBy("c").count() preserves the grouping key's attribute id, so
+        # the tagged reference resolves.
+        df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c")
+        rows = df.groupBy("c").count().select(df.c).collect()
+        self.assertEqual(sorted(r.c for r in rows), [1, 2])
+
+    def test_resolve_after_agg_alias_shadow(self):
+        # An aggregate output aliased `c` collides by name with the source
+        # `c`, but the tagged `df.c` still references the aggregated-away
+        # source attribute, so it cannot resolve.
+        # Connect lenient diverges: name-based fallback resolves the
+        # aliased name (overridden in the lenient parity suite).
+        df = self.spark.sql("SELECT 1 AS c")
+        with self.assertRaises(AnalysisException):
+            df.groupBy().agg(sf.sum("c").alias("c")).select(df.c).collect()
+
+    def test_resolve_after_pivot(self):
+        # pivot preserves the grouping key's attribute id, so the tagged
+        # reference resolves.
+        df = self.spark.sql(
+            "SELECT 1 AS c, 'a' AS k, 10 AS v UNION ALL SELECT 2 AS c, 'b' AS k, 20 AS v"
+        )
+        rows = df.groupBy("c").pivot("k").sum("v").select(df.c).collect()
+        self.assertEqual(sorted(r.c for r in rows), [1, 2])
+
+    def test_resolve_after_union(self):
+        # Union's output keeps the left child's attribute ids
+        # (Union.mergeChildOutputs), so Classic resolves the tagged
+        # left-side reference directly against that output and succeeds.
+        # Connect resolves by walking the plan tree for the plan id but
+        # treats Union as a leaf (ColumnResolutionHelper), so the id below
+        # the Union is never found and it raises in both modes (overridden
+        # in the Connect parity suite).
+        df1 = self.spark.sql("SELECT 1 AS c")
+        df2 = self.spark.sql("SELECT 2 AS c")
+        rows = df1.union(df2).select(df1.c).collect()
+        self.assertEqual(sorted(r.c for r in rows), [1, 2])
+
+    def test_resolve_after_intersect(self):
+        # Intersect's output also keeps the left child's attribute ids
+        # (Intersect.mergeChildOutputs). Unlike Union, it is not treated as
+        # a leaf during plan-id resolution, so Connect's tree walk descends
+        # into the left child, finds the tagged node and resolves it; all
+        # modes succeed.
+        df1 = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c")
+        df2 = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c")
+        rows = df1.intersect(df2).select(df1.c).collect()
+        self.assertEqual([r.c for r in rows], [2])
+
+    def test_resolve_self_join_alias(self):
+        # Both self-join sides originate from the same plan-id-tagged
+        # ancestor, yielding two equal-depth candidates with the same
+        # attribute id. Disambiguation cannot tiebreak and all modes raise
+        # an ambiguous-reference error.
+        df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c")
+        a, b = df.alias("a"), df.alias("b")
+        with self.assertRaises(AnalysisException):
+            a.join(b, a.c == b.c).select(df.c).collect()
+
+    def test_resolve_after_subquery_view(self):
+        # Registering the DataFrame as a temp view and reading it back via
+        # table() produces a new plan; the tagged reference still resolves in
+        # all modes.
+        view = f"v_{uuid.uuid4().hex}"
+        df = self.spark.sql("SELECT 1 AS c")
+        df.createOrReplaceTempView(view)
+        try:
+            rows = self.spark.table(view).select(df.c).collect()
+            self.assertEqual([r.c for r in rows], [1])
+        finally:
+            self.spark.sql(f"DROP VIEW IF EXISTS {view}")
+
+    def test_resolve_cross_dataframe_illegal_reference(self):
+        # Referencing a column from a DataFrame whose plan id is not an
+        # ancestor of the target plan (`df1.select(df2.id)`) fails in all
+        # modes; the strict / lenient switch does not gate this throw.
+        df1 = self.spark.range(3)
+        df2 = self.spark.range(5)
+        with self.assertRaises(AnalysisException):
+            df1.select(df2.id).collect()
+
+    def test_resolve_df_star(self):
+        # `df["*"]` is an UnresolvedDataFrameStar carrying df's plan id; the
+        # analyzer expands it to the matched node's output in all modes.
+        df = self.spark.sql(
+            "SELECT 'Books' AS c, 100 AS v UNION ALL SELECT 'Electronics' AS c, 200 AS v"
+        )
+        rows = df.select(df["*"]).collect()
+        self.assertEqual(sorted((r.c, r.v) for r in rows), [("Books", 100), ("Electronics", 200)])
+
+    def test_resolve_self_join_withcolumnrenamed(self):
+        # Documented ColumnResolutionHelper case: df1 = range(10) + col `a`;
+        # df2 = df1 renamed `a` -> `b`; df1.join(df2, df1.a == df2.b). The
+        # node with df1's plan id is found on both Join sides; the right
+        # candidate is filtered out because its `a` is not in the renaming
+        # Project's output, so disambiguation succeeds in all modes.
+        df1 = self.spark.range(10).withColumn("a", sf.col("id"))
+        df2 = df1.withColumnRenamed("a", "b")
+        rows = df1.join(df2, df1.a == df2.b).select(df1.a, df2.b).collect()
+        self.assertEqual(len(rows), 10)
+
+    def test_resolve_sort_missing_attr_recovery(self):
+        # Documented ColumnResolutionHelper case: df.select(df.v).sort(df.id)
+        # where df.id is not in the select's output. The analyzer descends
+        # through the Project, resolves df.id via plan id at the source, and
+        # adds it back to the upstream projection. Works in all modes.
+        df = self.spark.range(10).withColumn("v", sf.col("id") + 1)
+        rows = df.select(df.v).sort(df.id).collect()
+        self.assertEqual(len(rows), 10)
+
 
 class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase):
     pass

