This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cd3fa2f6b937 [SPARK-46677][SQL][CONNECT] Fix `dataframe["*"]` 
resolution
cd3fa2f6b937 is described below

commit cd3fa2f6b9373fed59985ef9383a718884595f5b
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Jan 16 09:25:10 2024 +0900

    [SPARK-46677][SQL][CONNECT] Fix `dataframe["*"]` resolution
    
    ### What changes were proposed in this pull request?
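    On Spark Connect, `df["*"]` (i.e. `df.col("*")`) should be resolved against the plan of the DataFrame it was taken from (identified by its plan id), instead of expanding to every column of the plan it is selected from.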
    
    ### Why are the changes needed?
    ```
    In [6]: df1 = spark.createDataFrame([{"id": 1}])
    
    In [7]: df2 = spark.createDataFrame([{"id": 1, "val": "v"}])
    
    In [8]: df1.join(df2)
    Out[8]: DataFrame[id: bigint, id: bigint, val: string]
    
    In [9]: df1.join(df2).select(df1["*"])
    Out[9]: DataFrame[id: bigint, id: bigint, val: string]
    ```
    
    The expected result is:
    ```
    In [3]: df1.join(df2).select(df1["*"])
    Out[3]: DataFrame[id: bigint]
    ```
    
    ### Does this PR introduce _any_ user-facing change?
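    Yes. On Spark Connect, `df["*"]` now expands only to the columns of `df` (e.g. `DataFrame[id: bigint]` in the example above).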
    
    ### How was this patch tested?
    Added unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #44689 from zhengruifeng/py_df_star.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../main/protobuf/spark/connect/expressions.proto  |  3 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 35 ++++++++------
 python/pyspark/sql/connect/dataframe.py            | 48 +++++++++++++------
 python/pyspark/sql/connect/expressions.py          |  8 +++-
 python/pyspark/sql/connect/functions/builtin.py    | 16 +++----
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 54 +++++++++++-----------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  | 27 ++++++++++-
 .../sql/tests/connect/test_connect_basic.py        | 38 +++++++++++++++
 python/pyspark/sql/tests/test_dataframe.py         | 35 ++++++++++++++
 .../catalyst/analysis/ColumnResolutionHelper.scala | 47 ++++++++++++++++++-
 .../spark/sql/catalyst/analysis/unresolved.scala   | 17 +++++++
 .../spark/sql/catalyst/trees/TreePatterns.scala    |  1 +
 .../spark/sql/errors/QueryCompilationErrors.scala  | 14 +++---
 .../main/scala/org/apache/spark/sql/Column.scala   |  2 +-
 14 files changed, 268 insertions(+), 77 deletions(-)

diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 4aac2bcc612b..c3333636bf68 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -261,6 +261,9 @@ message Expression {
     // If set, it should end with '.*' and will be parsed by 
'parseAttributeName'
     // in the server side.
     optional string unparsed_target = 1;
+
+    // (Optional) The id of corresponding connect plan.
+    optional int64 plan_id = 2;
   }
 
   // Represents all of the input attributes to a given relational operator, 
for example in
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index dc57adc90c42..25c78413170e 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -45,7 +45,7 @@ import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, 
Observation, RelationalGroupedDataset, SparkSession}
 import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, 
FunctionIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, 
MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, 
UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, 
UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, 
MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, 
UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, 
UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, 
UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, 
ExpressionEncoder, RowEncoder}
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
@@ -2109,19 +2109,28 @@ class SparkConnectPlanner(
     parser.parseExpression(expr.getExpression)
   }
 
-  private def transformUnresolvedStar(star: proto.Expression.UnresolvedStar): 
UnresolvedStar = {
-    if (star.hasUnparsedTarget) {
-      val target = star.getUnparsedTarget
-      if (!target.endsWith(".*")) {
-        throw InvalidPlanInput(
-          s"UnresolvedStar requires a unparsed target ending with '.*', " +
-            s"but got $target.")
-      }
+  private def transformUnresolvedStar(star: proto.Expression.UnresolvedStar): 
Expression = {
+    (star.hasUnparsedTarget, star.hasPlanId) match {
+      case (false, false) =>
+        // functions.col("*")
+        UnresolvedStar(None)
 
-      UnresolvedStar(
-        Some(UnresolvedAttribute.parseAttributeName(target.substring(0, 
target.length - 2))))
-    } else {
-      UnresolvedStar(None)
+      case (true, false) =>
+        // functions.col("s.*")
+        val target = star.getUnparsedTarget
+        if (!target.endsWith(".*")) {
+          throw InvalidPlanInput(
+            s"UnresolvedStar requires a unparsed target ending with '.*', but 
got $target.")
+        }
+        val parts = UnresolvedAttribute.parseAttributeName(target.dropRight(2))
+        UnresolvedStar(Some(parts))
+
+      case (false, true) =>
+        // dataframe.col("*")
+        UnresolvedDataFrameStar(star.getPlanId)
+
+      case _ =>
+        throw InvalidPlanInput("UnresolvedStar with both target and plan id is 
not supported.")
     }
   }
 
diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index 66059ad96ebb..7ee27065208c 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -71,9 +71,12 @@ from pyspark.sql.connect.group import GroupedData
 from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.connect.streaming.readwriter import DataStreamWriter
 from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import UnresolvedRegex
+from pyspark.sql.connect.expressions import (
+    ColumnReference,
+    UnresolvedRegex,
+    UnresolvedStar,
+)
 from pyspark.sql.connect.functions.builtin import (
-    _to_col_with_plan_id,
     _to_col,
     _invoke_function,
     col,
@@ -1702,9 +1705,11 @@ class DataFrame:
                 error_class="ATTRIBUTE_NOT_SUPPORTED", 
message_parameters={"attr_name": name}
             )
 
-        return _to_col_with_plan_id(
-            col=name,
-            plan_id=self._plan._plan_id,
+        return Column(
+            ColumnReference(
+                unparsed_identifier=name,
+                plan_id=self._plan._plan_id,
+            )
         )
 
     __getattr__.__doc__ = PySparkDataFrame.__getattr__.__doc__
@@ -1719,14 +1724,31 @@ class DataFrame:
 
     def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> 
Union[Column, "DataFrame"]:
         if isinstance(item, str):
-            # validate the column name
-            if not hasattr(self._session, "is_mock_session"):
-                self.select(item).isLocal()
-
-            return _to_col_with_plan_id(
-                col=item,
-                plan_id=self._plan._plan_id,
-            )
+            if item == "*":
+                return Column(
+                    UnresolvedStar(
+                        unparsed_target=None,
+                        plan_id=self._plan._plan_id,
+                    )
+                )
+            else:
+                # TODO: revisit vanilla Spark's Dataset.col
+                # if 
(sparkSession.sessionState.conf.supportQuotedRegexColumnName) {
+                #   colRegex(colName)
+                # } else {
+                #   Column(addDataFrameIdToCol(resolve(colName)))
+                # }
+
+                # validate the column name
+                if not hasattr(self._session, "is_mock_session"):
+                    self.select(item).isLocal()
+
+                return Column(
+                    ColumnReference(
+                        unparsed_identifier=item,
+                        plan_id=self._plan._plan_id,
+                    )
+                )
         elif isinstance(item, Column):
             return self.filter(item)
         elif isinstance(item, (list, tuple)):
diff --git a/python/pyspark/sql/connect/expressions.py 
b/python/pyspark/sql/connect/expressions.py
index 384422eed7d1..f985e88d0f23 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -494,19 +494,23 @@ class ColumnReference(Expression):
 
 
 class UnresolvedStar(Expression):
-    def __init__(self, unparsed_target: Optional[str]):
+    def __init__(self, unparsed_target: Optional[str], plan_id: Optional[int] 
= None):
         super().__init__()
 
         if unparsed_target is not None:
             assert isinstance(unparsed_target, str) and 
unparsed_target.endswith(".*")
-
         self._unparsed_target = unparsed_target
 
+        assert plan_id is None or isinstance(plan_id, int)
+        self._plan_id = plan_id
+
     def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
         expr = proto.Expression()
         expr.unresolved_star.SetInParent()
         if self._unparsed_target is not None:
             expr.unresolved_star.unparsed_target = self._unparsed_target
+        if self._plan_id is not None:
+            expr.unresolved_star.plan_id = self._plan_id
         return expr
 
     def __repr__(self) -> str:
diff --git a/python/pyspark/sql/connect/functions/builtin.py 
b/python/pyspark/sql/connect/functions/builtin.py
index c9bf5fadd91c..2eeefc9fae23 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -76,15 +76,6 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.udtf import UserDefinedTableFunction
 
 
-def _to_col_with_plan_id(col: str, plan_id: Optional[int]) -> Column:
-    if col == "*":
-        return Column(UnresolvedStar(unparsed_target=None))
-    elif col.endswith(".*"):
-        return Column(UnresolvedStar(unparsed_target=col))
-    else:
-        return Column(ColumnReference(unparsed_identifier=col, 
plan_id=plan_id))
-
-
 def _to_col(col: "ColumnOrName") -> Column:
     assert isinstance(col, (Column, str))
     return col if isinstance(col, Column) else column(col)
@@ -224,7 +215,12 @@ def _options_to_col(options: Dict[str, Any]) -> Column:
 
 
 def col(col: str) -> Column:
-    return _to_col_with_plan_id(col=col, plan_id=None)
+    if col == "*":
+        return Column(UnresolvedStar(unparsed_target=None))
+    elif col.endswith(".*"):
+        return Column(UnresolvedStar(unparsed_target=col))
+    else:
+        return Column(ColumnReference(unparsed_identifier=col))
 
 
 col.__doc__ = pysparkfuncs.col.__doc__
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py 
b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 1e943b8978c2..fb3ebf30d300 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as 
spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x8a-\n\nExpression\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
 
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
 [...]
+    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xb4-\n\nExpression\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
 
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -44,7 +44,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
         b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
     )
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 5875
+    _EXPRESSION._serialized_end = 5917
     _EXPRESSION_WINDOW._serialized_start = 1645
     _EXPRESSION_WINDOW._serialized_end = 2428
     _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1935
@@ -80,29 +80,29 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4968
     _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5018
     _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5020
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5102
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5104
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5190
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5193
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5325
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 5328
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 5515
-    _EXPRESSION_ALIAS._serialized_start = 5517
-    _EXPRESSION_ALIAS._serialized_end = 5637
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5640
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5798
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5800
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5862
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5878
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6242
-    _PYTHONUDF._serialized_start = 6245
-    _PYTHONUDF._serialized_end = 6400
-    _SCALARSCALAUDF._serialized_start = 6403
-    _SCALARSCALAUDF._serialized_end = 6587
-    _JAVAUDF._serialized_start = 6590
-    _JAVAUDF._serialized_end = 6739
-    _CALLFUNCTION._serialized_start = 6741
-    _CALLFUNCTION._serialized_end = 6849
-    _NAMEDARGUMENTEXPRESSION._serialized_start = 6851
-    _NAMEDARGUMENTEXPRESSION._serialized_end = 6943
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5144
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5146
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5232
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5235
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5367
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 5370
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 5557
+    _EXPRESSION_ALIAS._serialized_start = 5559
+    _EXPRESSION_ALIAS._serialized_end = 5679
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5682
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5840
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5842
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5904
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5920
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6284
+    _PYTHONUDF._serialized_start = 6287
+    _PYTHONUDF._serialized_end = 6442
+    _SCALARSCALAUDF._serialized_start = 6445
+    _SCALARSCALAUDF._serialized_end = 6629
+    _JAVAUDF._serialized_start = 6632
+    _JAVAUDF._serialized_end = 6781
+    _CALLFUNCTION._serialized_start = 6783
+    _CALLFUNCTION._serialized_end = 6891
+    _NAMEDARGUMENTEXPRESSION._serialized_start = 6893
+    _NAMEDARGUMENTEXPRESSION._serialized_end = 6985
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi 
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 93a431dcc860..e397880a73e4 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -880,29 +880,52 @@ class Expression(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
         UNPARSED_TARGET_FIELD_NUMBER: builtins.int
+        PLAN_ID_FIELD_NUMBER: builtins.int
         unparsed_target: builtins.str
         """(Optional) The target of the expansion.
 
         If set, it should end with '.*' and will be parsed by 
'parseAttributeName'
         in the server side.
         """
+        plan_id: builtins.int
+        """(Optional) The id of corresponding connect plan."""
         def __init__(
             self,
             *,
             unparsed_target: builtins.str | None = ...,
+            plan_id: builtins.int | None = ...,
         ) -> None: ...
         def HasField(
             self,
             field_name: typing_extensions.Literal[
-                "_unparsed_target", b"_unparsed_target", "unparsed_target", 
b"unparsed_target"
+                "_plan_id",
+                b"_plan_id",
+                "_unparsed_target",
+                b"_unparsed_target",
+                "plan_id",
+                b"plan_id",
+                "unparsed_target",
+                b"unparsed_target",
             ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
-                "_unparsed_target", b"_unparsed_target", "unparsed_target", 
b"unparsed_target"
+                "_plan_id",
+                b"_plan_id",
+                "_unparsed_target",
+                b"_unparsed_target",
+                "plan_id",
+                b"plan_id",
+                "unparsed_target",
+                b"unparsed_target",
             ],
         ) -> None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_plan_id", 
b"_plan_id"]
+        ) -> typing_extensions.Literal["plan_id"] | None: ...
+        @typing.overload
         def WhichOneof(
             self, oneof_group: typing_extensions.Literal["_unparsed_target", 
b"_unparsed_target"]
         ) -> typing_extensions.Literal["unparsed_target"] | None: ...
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index a1cd00e79e1a..5cd97c9bb7f4 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -558,6 +558,44 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         ):
             cdf1.select(cdf2.a).schema
 
+    def test_invalid_star(self):
+        data1 = [Row(a=1, b=2, c=3)]
+        cdf1 = self.connect.createDataFrame(data1)
+
+        data2 = [Row(a=2, b=0)]
+        cdf2 = self.connect.createDataFrame(data2)
+
+        # Can find the target plan node, but fail to resolve with it
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+        ):
+            cdf3 = cdf1.select(cdf1.a)
+            cdf3.select(cdf1["*"]).schema
+
+        # Can find the target plan node, but fail to resolve with it
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+        ):
+            # column 'a has been replaced
+            cdf3 = cdf1.withColumn("a", CF.lit(0))
+            cdf3.select(cdf1["*"]).schema
+
+        # Can not find the target plan node by plan id
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+        ):
+            cdf1.select(cdf2["*"]).schema
+
+        # cdf1["*"] exists on both side
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "AMBIGUOUS_COLUMN_REFERENCE",
+        ):
+            cdf1.join(cdf1).select(cdf1["*"]).schema
+
     def test_collect(self):
         cdf = self.connect.read.table(self.tbl_name)
         sdf = self.spark.read.table(self.tbl_name)
diff --git a/python/pyspark/sql/tests/test_dataframe.py 
b/python/pyspark/sql/tests/test_dataframe.py
index c77e7fd89d01..407ab22a088c 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -69,6 +69,41 @@ class DataFrameTestsMixin:
         self.assertEqual(self.spark.range(-2).count(), 0)
         self.assertEqual(self.spark.range(3).count(), 3)
 
+    def test_dataframe_star(self):
+        df1 = self.spark.createDataFrame([{"a": 1}])
+        df2 = self.spark.createDataFrame([{"a": 1, "b": "v"}])
+        df3 = df2.withColumnsRenamed({"a": "x", "b": "y"})
+
+        df = df1.join(df2)
+        self.assertEqual(df.columns, ["a", "a", "b"])
+        self.assertEqual(df.select(df1["*"]).columns, ["a"])
+        self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
+
+        df = df1.join(df2).withColumn("c", lit(0))
+        self.assertEqual(df.columns, ["a", "a", "b", "c"])
+        self.assertEqual(df.select(df1["*"]).columns, ["a"])
+        self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
+
+        df = df1.join(df2, "a")
+        self.assertEqual(df.columns, ["a", "b"])
+        self.assertEqual(df.select(df1["*"]).columns, ["a"])
+        self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
+
+        df = df1.join(df2, "a").withColumn("c", lit(0))
+        self.assertEqual(df.columns, ["a", "b", "c"])
+        self.assertEqual(df.select(df1["*"]).columns, ["a"])
+        self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
+
+        df = df2.join(df3)
+        self.assertEqual(df.columns, ["a", "b", "x", "y"])
+        self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
+        self.assertEqual(df.select(df3["*"]).columns, ["x", "y"])
+
+        df = df2.join(df3).withColumn("c", lit(0))
+        self.assertEqual(df.columns, ["a", "b", "x", "y", "c"])
+        self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
+        self.assertEqual(df.select(df3["*"]).columns, ["x", "y"])
+
     def test_self_join(self):
         df1 = self.spark.range(10).withColumn("a", lit(0))
         df2 = df1.withColumnRenamed("a", "b")
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 3261aa51b9be..bc56afa73d99 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -490,7 +490,9 @@ trait ColumnResolutionHelper extends Logging with 
DataTypeErrorsBase {
       q: Seq[LogicalPlan]): Expression = e match {
     case u: UnresolvedAttribute =>
       resolveDataFrameColumn(u, q).getOrElse(u)
-    case _ if e.containsPattern(UNRESOLVED_ATTRIBUTE) =>
+    case u: UnresolvedDataFrameStar =>
+      resolveDataFrameStar(u, q)
+    case _ if e.containsAnyPattern(UNRESOLVED_ATTRIBUTE, UNRESOLVED_DF_STAR) =>
       e.mapChildren(c => tryResolveDataFrameColumns(c, q))
     case _ => e
   }
@@ -510,7 +512,7 @@ trait ColumnResolutionHelper extends Logging with 
DataTypeErrorsBase {
       //  df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
       //  df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
       //  df1.select(df2.a)   <-   illegal reference df2.a
-      throw QueryCompilationErrors.cannotResolveColumn(u)
+      throw QueryCompilationErrors.cannotResolveDataFrameColumn(u)
     }
     resolved
   }
@@ -588,4 +590,45 @@ trait ColumnResolutionHelper extends Logging with 
DataTypeErrorsBase {
     }
     (filtered, matched)
   }
+
+  private def resolveDataFrameStar(
+      u: UnresolvedDataFrameStar,
+      q: Seq[LogicalPlan]): ResolvedStar = {
+    resolveDataFrameStarByPlanId(u, u.planId, q).getOrElse(
+      // Can not find the target plan node with plan id, e.g.
+      //  df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
+      //  df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
+      //  df1.select(df2["*"])   <-   illegal reference df2["*"]
+      throw QueryCompilationErrors.cannotResolveDataFrameColumn(u)
+    )
+  }
+
+  private def resolveDataFrameStarByPlanId(
+      u: UnresolvedDataFrameStar,
+      id: Long,
+      q: Seq[LogicalPlan]): Option[ResolvedStar] = {
+    q.iterator.map(resolveDataFrameStarRecursively(u, id, _))
+      .foldLeft(Option.empty[ResolvedStar]) {
+        case (r1, r2) =>
+          if (r1.nonEmpty && r2.nonEmpty) {
+            throw QueryCompilationErrors.ambiguousColumnReferences(u)
+          }
+          if (r1.nonEmpty) r1 else r2
+      }
+  }
+
+   private def resolveDataFrameStarRecursively(
+      u: UnresolvedDataFrameStar,
+      id: Long,
+      p: LogicalPlan): Option[ResolvedStar] = {
+     val resolved = if (p.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) {
+       Some(ResolvedStar(p.output))
+     } else {
+       resolveDataFrameStarByPlanId(u, id, p.children)
+     }
+     resolved.filter { r =>
+       val outputSet = AttributeSet(p.output ++ p.metadataOutput)
+       r.expressions.forall(_.references.subsetOf(outputSet))
+     }
+   }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index b32ff671b2b7..63d4cfeb83fe 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -696,6 +696,23 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) 
extends Star with Une
   override def toString: String = expressions.mkString("ResolvedStar(", ", ", 
")")
 }
 
+/**
+ * Represents all input attributes to a given relational operator.
+ * This is used in Spark Connect dataframe, for example:
+ *    df1 = spark.createDataFrame([{"id": 1}])
+ *    df2 = spark.createDataFrame([{"id": 1, "val": "v"}])
+ *    df1.join(df2, "id").select(df1["*"])
+ * @param planId the plan id of target node.
+ */
+case class UnresolvedDataFrameStar(planId: Long)
+  extends LeafExpression with Unevaluable {
+  override def nullable: Boolean = throw new UnresolvedException("nullable")
+  override def dataType: DataType = throw new UnresolvedException("dataType")
+  override lazy val resolved = false
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_DF_STAR)
+  override def toString: String = "UnresolvedDataFrameStar"
+}
+
 /**
  * Extracts a value or values from an Expression
  *
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index daa4ea0c8616..fcf65659c24f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -143,6 +143,7 @@ object TreePattern extends Enumeration  {
   val UNRESOLVED_ALIAS: Value = Value
   val UNRESOLVED_ATTRIBUTE: Value = Value
   val UNRESOLVED_DESERIALIZER: Value = Value
+  val UNRESOLVED_DF_STAR: Value = Value
   val UNRESOLVED_HAVING: Value = Value
   val UNRESOLVED_IDENTIFIER: Value = Value
   val UNRESOLVED_ORDINAL: Value = Value
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index e0740a325358..200edd97a88a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkThrowable, 
SparkUnsupportedOperationException}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, 
FunctionIdentifier, InternalRow, QualifiedTableName, TableIdentifier}
-import 
org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, 
FunctionAlreadyExistsException, NamespaceAlreadyExistsException, 
NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException, 
NoSuchTableException, ResolvedTable, Star, TableAlreadyExistsException, 
UnresolvedAttribute, UnresolvedRegex}
+import 
org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, 
FunctionAlreadyExistsException, NamespaceAlreadyExistsException, 
NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException, 
NoSuchTableException, ResolvedTable, Star, TableAlreadyExistsException, 
UnresolvedRegex}
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, 
InvalidUDFClassException}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
AttributeReference, AttributeSet, CreateMap, CreateStruct, Expression, 
GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, 
WindowSpecDefinition}
@@ -3952,19 +3952,19 @@ private[sql] object QueryCompilationErrors extends 
QueryErrorsBase with Compilat
         "expectedSchema" -> toSQLType(expectedSchema)))
   }
 
-  def cannotResolveColumn(u: UnresolvedAttribute): Throwable = {
+  def cannotResolveDataFrameColumn(e: Expression): Throwable = {
     new AnalysisException(
       errorClass = "CANNOT_RESOLVE_DATAFRAME_COLUMN",
-      messageParameters = Map("name" -> toSQLId(u.nameParts)),
-      origin = u.origin
+      messageParameters = Map("name" -> toSQLExpr(e)),
+      origin = e.origin
     )
   }
 
-  def ambiguousColumnReferences(u: UnresolvedAttribute): Throwable = {
+  def ambiguousColumnReferences(e: Expression): Throwable = {
     new AnalysisException(
       errorClass = "AMBIGUOUS_COLUMN_REFERENCE",
-      messageParameters = Map("name" -> toSQLId(u.nameParts)),
-      origin = u.origin
+      messageParameters = Map("name" -> toSQLExpr(e)),
+      origin = e.origin
     )
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index bfbc3287c63c..39d720c933a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -155,7 +155,7 @@ class Column(val expr: Expression) extends Logging {
     name match {
       case "*" => UnresolvedStar(None)
       case _ if name.endsWith(".*") =>
-        val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, 
name.length - 2))
+        val parts = UnresolvedAttribute.parseAttributeName(name.dropRight(2))
         UnresolvedStar(Some(parts))
       case _ => UnresolvedAttribute.quotedString(name)
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to