This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new cd3fa2f6b937 [SPARK-46677][SQL][CONNECT] Fix `dataframe["*"]`
resolution
cd3fa2f6b937 is described below
commit cd3fa2f6b9373fed59985ef9383a718884595f5b
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Jan 16 09:25:10 2024 +0900
[SPARK-46677][SQL][CONNECT] Fix `dataframe["*"]` resolution
### What changes were proposed in this pull request?
On Spark Connect, `df.col("*")` should be resolved against the target plan
### Why are the changes needed?
```
In [6]: df1 = spark.createDataFrame([{"id": 1}])
In [7]: df2 = spark.createDataFrame([{"id": 1, "val": "v"}])
In [8]: df1.join(df2)
Out[8]: DataFrame[id: bigint, id: bigint, val: string]
In [9]: df1.join(df2).select(df1["*"])
Out[9]: DataFrame[id: bigint, id: bigint, val: string]
```
it should be
```
In [3]: df1.join(df2).select(df1["*"])
Out[3]: DataFrame[id: bigint]
```
### Does this PR introduce _any_ user-facing change?
yes
### How was this patch tested?
added unit tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44689 from zhengruifeng/py_df_star.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/expressions.proto | 3 ++
.../sql/connect/planner/SparkConnectPlanner.scala | 35 ++++++++------
python/pyspark/sql/connect/dataframe.py | 48 +++++++++++++------
python/pyspark/sql/connect/expressions.py | 8 +++-
python/pyspark/sql/connect/functions/builtin.py | 16 +++----
.../pyspark/sql/connect/proto/expressions_pb2.py | 54 +++++++++++-----------
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 27 ++++++++++-
.../sql/tests/connect/test_connect_basic.py | 38 +++++++++++++++
python/pyspark/sql/tests/test_dataframe.py | 35 ++++++++++++++
.../catalyst/analysis/ColumnResolutionHelper.scala | 47 ++++++++++++++++++-
.../spark/sql/catalyst/analysis/unresolved.scala | 17 +++++++
.../spark/sql/catalyst/trees/TreePatterns.scala | 1 +
.../spark/sql/errors/QueryCompilationErrors.scala | 14 +++---
.../main/scala/org/apache/spark/sql/Column.scala | 2 +-
14 files changed, 268 insertions(+), 77 deletions(-)
diff --git
a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 4aac2bcc612b..c3333636bf68 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -261,6 +261,9 @@ message Expression {
// If set, it should end with '.*' and will be parsed by
'parseAttributeName'
// in the server side.
optional string unparsed_target = 1;
+
+ // (Optional) The id of corresponding connect plan.
+ optional int64 plan_id = 2;
}
// Represents all of the input attributes to a given relational operator,
for example in
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index dc57adc90c42..25c78413170e 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -45,7 +45,7 @@ import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter,
Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier,
FunctionIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView,
MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias,
UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue,
UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView,
MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias,
UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer,
UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex,
UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder,
ExpressionEncoder, RowEncoder}
import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
@@ -2109,19 +2109,28 @@ class SparkConnectPlanner(
parser.parseExpression(expr.getExpression)
}
- private def transformUnresolvedStar(star: proto.Expression.UnresolvedStar):
UnresolvedStar = {
- if (star.hasUnparsedTarget) {
- val target = star.getUnparsedTarget
- if (!target.endsWith(".*")) {
- throw InvalidPlanInput(
- s"UnresolvedStar requires a unparsed target ending with '.*', " +
- s"but got $target.")
- }
+ private def transformUnresolvedStar(star: proto.Expression.UnresolvedStar):
Expression = {
+ (star.hasUnparsedTarget, star.hasPlanId) match {
+ case (false, false) =>
+ // functions.col("*")
+ UnresolvedStar(None)
- UnresolvedStar(
- Some(UnresolvedAttribute.parseAttributeName(target.substring(0,
target.length - 2))))
- } else {
- UnresolvedStar(None)
+ case (true, false) =>
+ // functions.col("s.*")
+ val target = star.getUnparsedTarget
+ if (!target.endsWith(".*")) {
+ throw InvalidPlanInput(
+ s"UnresolvedStar requires a unparsed target ending with '.*', but
got $target.")
+ }
+ val parts = UnresolvedAttribute.parseAttributeName(target.dropRight(2))
+ UnresolvedStar(Some(parts))
+
+ case (false, true) =>
+ // dataframe.col("*")
+ UnresolvedDataFrameStar(star.getPlanId)
+
+ case _ =>
+ throw InvalidPlanInput("UnresolvedStar with both target and plan id is
not supported.")
}
}
diff --git a/python/pyspark/sql/connect/dataframe.py
b/python/pyspark/sql/connect/dataframe.py
index 66059ad96ebb..7ee27065208c 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -71,9 +71,12 @@ from pyspark.sql.connect.group import GroupedData
from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.connect.streaming.readwriter import DataStreamWriter
from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import UnresolvedRegex
+from pyspark.sql.connect.expressions import (
+ ColumnReference,
+ UnresolvedRegex,
+ UnresolvedStar,
+)
from pyspark.sql.connect.functions.builtin import (
- _to_col_with_plan_id,
_to_col,
_invoke_function,
col,
@@ -1702,9 +1705,11 @@ class DataFrame:
error_class="ATTRIBUTE_NOT_SUPPORTED",
message_parameters={"attr_name": name}
)
- return _to_col_with_plan_id(
- col=name,
- plan_id=self._plan._plan_id,
+ return Column(
+ ColumnReference(
+ unparsed_identifier=name,
+ plan_id=self._plan._plan_id,
+ )
)
__getattr__.__doc__ = PySparkDataFrame.__getattr__.__doc__
@@ -1719,14 +1724,31 @@ class DataFrame:
def __getitem__(self, item: Union[int, str, Column, List, Tuple]) ->
Union[Column, "DataFrame"]:
if isinstance(item, str):
- # validate the column name
- if not hasattr(self._session, "is_mock_session"):
- self.select(item).isLocal()
-
- return _to_col_with_plan_id(
- col=item,
- plan_id=self._plan._plan_id,
- )
+ if item == "*":
+ return Column(
+ UnresolvedStar(
+ unparsed_target=None,
+ plan_id=self._plan._plan_id,
+ )
+ )
+ else:
+ # TODO: revisit vanilla Spark's Dataset.col
+ # if
(sparkSession.sessionState.conf.supportQuotedRegexColumnName) {
+ # colRegex(colName)
+ # } else {
+ # Column(addDataFrameIdToCol(resolve(colName)))
+ # }
+
+ # validate the column name
+ if not hasattr(self._session, "is_mock_session"):
+ self.select(item).isLocal()
+
+ return Column(
+ ColumnReference(
+ unparsed_identifier=item,
+ plan_id=self._plan._plan_id,
+ )
+ )
elif isinstance(item, Column):
return self.filter(item)
elif isinstance(item, (list, tuple)):
diff --git a/python/pyspark/sql/connect/expressions.py
b/python/pyspark/sql/connect/expressions.py
index 384422eed7d1..f985e88d0f23 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -494,19 +494,23 @@ class ColumnReference(Expression):
class UnresolvedStar(Expression):
- def __init__(self, unparsed_target: Optional[str]):
+ def __init__(self, unparsed_target: Optional[str], plan_id: Optional[int]
= None):
super().__init__()
if unparsed_target is not None:
assert isinstance(unparsed_target, str) and
unparsed_target.endswith(".*")
-
self._unparsed_target = unparsed_target
+ assert plan_id is None or isinstance(plan_id, int)
+ self._plan_id = plan_id
+
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr = proto.Expression()
expr.unresolved_star.SetInParent()
if self._unparsed_target is not None:
expr.unresolved_star.unparsed_target = self._unparsed_target
+ if self._plan_id is not None:
+ expr.unresolved_star.plan_id = self._plan_id
return expr
def __repr__(self) -> str:
diff --git a/python/pyspark/sql/connect/functions/builtin.py
b/python/pyspark/sql/connect/functions/builtin.py
index c9bf5fadd91c..2eeefc9fae23 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -76,15 +76,6 @@ if TYPE_CHECKING:
from pyspark.sql.connect.udtf import UserDefinedTableFunction
-def _to_col_with_plan_id(col: str, plan_id: Optional[int]) -> Column:
- if col == "*":
- return Column(UnresolvedStar(unparsed_target=None))
- elif col.endswith(".*"):
- return Column(UnresolvedStar(unparsed_target=col))
- else:
- return Column(ColumnReference(unparsed_identifier=col,
plan_id=plan_id))
-
-
def _to_col(col: "ColumnOrName") -> Column:
assert isinstance(col, (Column, str))
return col if isinstance(col, Column) else column(col)
@@ -224,7 +215,12 @@ def _options_to_col(options: Dict[str, Any]) -> Column:
def col(col: str) -> Column:
- return _to_col_with_plan_id(col=col, plan_id=None)
+ if col == "*":
+ return Column(UnresolvedStar(unparsed_target=None))
+ elif col.endswith(".*"):
+ return Column(UnresolvedStar(unparsed_target=col))
+ else:
+ return Column(ColumnReference(unparsed_identifier=col))
col.__doc__ = pysparkfuncs.col.__doc__
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py
b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 1e943b8978c2..fb3ebf30d300 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as
spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x8a-\n\nExpression\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
[...]
+
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xb4-\n\nExpression\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
[...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -44,7 +44,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
)
_EXPRESSION._serialized_start = 105
- _EXPRESSION._serialized_end = 5875
+ _EXPRESSION._serialized_end = 5917
_EXPRESSION_WINDOW._serialized_start = 1645
_EXPRESSION_WINDOW._serialized_end = 2428
_EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1935
@@ -80,29 +80,29 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXPRESSION_EXPRESSIONSTRING._serialized_start = 4968
_EXPRESSION_EXPRESSIONSTRING._serialized_end = 5018
_EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5020
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5102
- _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5104
- _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5190
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5193
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5325
- _EXPRESSION_UPDATEFIELDS._serialized_start = 5328
- _EXPRESSION_UPDATEFIELDS._serialized_end = 5515
- _EXPRESSION_ALIAS._serialized_start = 5517
- _EXPRESSION_ALIAS._serialized_end = 5637
- _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5640
- _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5798
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5800
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5862
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5878
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6242
- _PYTHONUDF._serialized_start = 6245
- _PYTHONUDF._serialized_end = 6400
- _SCALARSCALAUDF._serialized_start = 6403
- _SCALARSCALAUDF._serialized_end = 6587
- _JAVAUDF._serialized_start = 6590
- _JAVAUDF._serialized_end = 6739
- _CALLFUNCTION._serialized_start = 6741
- _CALLFUNCTION._serialized_end = 6849
- _NAMEDARGUMENTEXPRESSION._serialized_start = 6851
- _NAMEDARGUMENTEXPRESSION._serialized_end = 6943
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5144
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5146
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5232
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5235
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5367
+ _EXPRESSION_UPDATEFIELDS._serialized_start = 5370
+ _EXPRESSION_UPDATEFIELDS._serialized_end = 5557
+ _EXPRESSION_ALIAS._serialized_start = 5559
+ _EXPRESSION_ALIAS._serialized_end = 5679
+ _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5682
+ _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5840
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5842
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5904
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5920
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6284
+ _PYTHONUDF._serialized_start = 6287
+ _PYTHONUDF._serialized_end = 6442
+ _SCALARSCALAUDF._serialized_start = 6445
+ _SCALARSCALAUDF._serialized_end = 6629
+ _JAVAUDF._serialized_start = 6632
+ _JAVAUDF._serialized_end = 6781
+ _CALLFUNCTION._serialized_start = 6783
+ _CALLFUNCTION._serialized_end = 6891
+ _NAMEDARGUMENTEXPRESSION._serialized_start = 6893
+ _NAMEDARGUMENTEXPRESSION._serialized_end = 6985
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 93a431dcc860..e397880a73e4 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -880,29 +880,52 @@ class Expression(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
UNPARSED_TARGET_FIELD_NUMBER: builtins.int
+ PLAN_ID_FIELD_NUMBER: builtins.int
unparsed_target: builtins.str
"""(Optional) The target of the expansion.
If set, it should end with '.*' and will be parsed by
'parseAttributeName'
in the server side.
"""
+ plan_id: builtins.int
+ """(Optional) The id of corresponding connect plan."""
def __init__(
self,
*,
unparsed_target: builtins.str | None = ...,
+ plan_id: builtins.int | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
- "_unparsed_target", b"_unparsed_target", "unparsed_target",
b"unparsed_target"
+ "_plan_id",
+ b"_plan_id",
+ "_unparsed_target",
+ b"_unparsed_target",
+ "plan_id",
+ b"plan_id",
+ "unparsed_target",
+ b"unparsed_target",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "_unparsed_target", b"_unparsed_target", "unparsed_target",
b"unparsed_target"
+ "_plan_id",
+ b"_plan_id",
+ "_unparsed_target",
+ b"_unparsed_target",
+ "plan_id",
+ b"plan_id",
+ "unparsed_target",
+ b"unparsed_target",
],
) -> None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_plan_id",
b"_plan_id"]
+ ) -> typing_extensions.Literal["plan_id"] | None: ...
+ @typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_unparsed_target",
b"_unparsed_target"]
) -> typing_extensions.Literal["unparsed_target"] | None: ...
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index a1cd00e79e1a..5cd97c9bb7f4 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -558,6 +558,44 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
):
cdf1.select(cdf2.a).schema
+ def test_invalid_star(self):
+ data1 = [Row(a=1, b=2, c=3)]
+ cdf1 = self.connect.createDataFrame(data1)
+
+ data2 = [Row(a=2, b=0)]
+ cdf2 = self.connect.createDataFrame(data2)
+
+ # Can find the target plan node, but fail to resolve with it
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+ ):
+ cdf3 = cdf1.select(cdf1.a)
+ cdf3.select(cdf1["*"]).schema
+
+ # Can find the target plan node, but fail to resolve with it
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+ ):
+ # column 'a has been replaced
+ cdf3 = cdf1.withColumn("a", CF.lit(0))
+ cdf3.select(cdf1["*"]).schema
+
+ # Can not find the target plan node by plan id
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+ ):
+ cdf1.select(cdf2["*"]).schema
+
+ # cdf1["*"] exists on both side
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "AMBIGUOUS_COLUMN_REFERENCE",
+ ):
+ cdf1.join(cdf1).select(cdf1["*"]).schema
+
def test_collect(self):
cdf = self.connect.read.table(self.tbl_name)
sdf = self.spark.read.table(self.tbl_name)
diff --git a/python/pyspark/sql/tests/test_dataframe.py
b/python/pyspark/sql/tests/test_dataframe.py
index c77e7fd89d01..407ab22a088c 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -69,6 +69,41 @@ class DataFrameTestsMixin:
self.assertEqual(self.spark.range(-2).count(), 0)
self.assertEqual(self.spark.range(3).count(), 3)
+ def test_dataframe_star(self):
+ df1 = self.spark.createDataFrame([{"a": 1}])
+ df2 = self.spark.createDataFrame([{"a": 1, "b": "v"}])
+ df3 = df2.withColumnsRenamed({"a": "x", "b": "y"})
+
+ df = df1.join(df2)
+ self.assertEqual(df.columns, ["a", "a", "b"])
+ self.assertEqual(df.select(df1["*"]).columns, ["a"])
+ self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
+
+ df = df1.join(df2).withColumn("c", lit(0))
+ self.assertEqual(df.columns, ["a", "a", "b", "c"])
+ self.assertEqual(df.select(df1["*"]).columns, ["a"])
+ self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
+
+ df = df1.join(df2, "a")
+ self.assertEqual(df.columns, ["a", "b"])
+ self.assertEqual(df.select(df1["*"]).columns, ["a"])
+ self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
+
+ df = df1.join(df2, "a").withColumn("c", lit(0))
+ self.assertEqual(df.columns, ["a", "b", "c"])
+ self.assertEqual(df.select(df1["*"]).columns, ["a"])
+ self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
+
+ df = df2.join(df3)
+ self.assertEqual(df.columns, ["a", "b", "x", "y"])
+ self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
+ self.assertEqual(df.select(df3["*"]).columns, ["x", "y"])
+
+ df = df2.join(df3).withColumn("c", lit(0))
+ self.assertEqual(df.columns, ["a", "b", "x", "y", "c"])
+ self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
+ self.assertEqual(df.select(df3["*"]).columns, ["x", "y"])
+
def test_self_join(self):
df1 = self.spark.range(10).withColumn("a", lit(0))
df2 = df1.withColumnRenamed("a", "b")
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 3261aa51b9be..bc56afa73d99 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -490,7 +490,9 @@ trait ColumnResolutionHelper extends Logging with
DataTypeErrorsBase {
q: Seq[LogicalPlan]): Expression = e match {
case u: UnresolvedAttribute =>
resolveDataFrameColumn(u, q).getOrElse(u)
- case _ if e.containsPattern(UNRESOLVED_ATTRIBUTE) =>
+ case u: UnresolvedDataFrameStar =>
+ resolveDataFrameStar(u, q)
+ case _ if e.containsAnyPattern(UNRESOLVED_ATTRIBUTE, UNRESOLVED_DF_STAR) =>
e.mapChildren(c => tryResolveDataFrameColumns(c, q))
case _ => e
}
@@ -510,7 +512,7 @@ trait ColumnResolutionHelper extends Logging with
DataTypeErrorsBase {
// df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
// df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
// df1.select(df2.a) <- illegal reference df2.a
- throw QueryCompilationErrors.cannotResolveColumn(u)
+ throw QueryCompilationErrors.cannotResolveDataFrameColumn(u)
}
resolved
}
@@ -588,4 +590,45 @@ trait ColumnResolutionHelper extends Logging with
DataTypeErrorsBase {
}
(filtered, matched)
}
+
+ private def resolveDataFrameStar(
+ u: UnresolvedDataFrameStar,
+ q: Seq[LogicalPlan]): ResolvedStar = {
+ resolveDataFrameStarByPlanId(u, u.planId, q).getOrElse(
+ // Can not find the target plan node with plan id, e.g.
+ // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
+ // df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
+ // df1.select(df2["*"]) <- illegal reference df2.a
+ throw QueryCompilationErrors.cannotResolveDataFrameColumn(u)
+ )
+ }
+
+ private def resolveDataFrameStarByPlanId(
+ u: UnresolvedDataFrameStar,
+ id: Long,
+ q: Seq[LogicalPlan]): Option[ResolvedStar] = {
+ q.iterator.map(resolveDataFrameStarRecursively(u, id, _))
+ .foldLeft(Option.empty[ResolvedStar]) {
+ case (r1, r2) =>
+ if (r1.nonEmpty && r2.nonEmpty) {
+ throw QueryCompilationErrors.ambiguousColumnReferences(u)
+ }
+ if (r1.nonEmpty) r1 else r2
+ }
+ }
+
+ private def resolveDataFrameStarRecursively(
+ u: UnresolvedDataFrameStar,
+ id: Long,
+ p: LogicalPlan): Option[ResolvedStar] = {
+ val resolved = if (p.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) {
+ Some(ResolvedStar(p.output))
+ } else {
+ resolveDataFrameStarByPlanId(u, id, p.children)
+ }
+ resolved.filter { r =>
+ val outputSet = AttributeSet(p.output ++ p.metadataOutput)
+ r.expressions.forall(_.references.subsetOf(outputSet))
+ }
+ }
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index b32ff671b2b7..63d4cfeb83fe 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -696,6 +696,23 @@ case class ResolvedStar(expressions: Seq[NamedExpression])
extends Star with Une
override def toString: String = expressions.mkString("ResolvedStar(", ", ",
")")
}
+/**
+ * Represents all input attributes to a given relational operator.
+ * This is used in Spark Connect dataframe, for example:
+ * df1 = spark.createDataFrame([{"id": 1}])
+ * df2 = spark.createDataFrame([{"id": 1, "val": "v"}])
+ * df1.join(df2, "id").select(df1["*"])
+ * @param planId the plan id of target node.
+ */
+case class UnresolvedDataFrameStar(planId: Long)
+ extends LeafExpression with Unevaluable {
+ override def nullable: Boolean = throw new UnresolvedException("nullable")
+ override def dataType: DataType = throw new UnresolvedException("dataType")
+ override lazy val resolved = false
+ final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_DF_STAR)
+ override def toString: String = "UnresolvedDataFrameStar"
+}
+
/**
* Extracts a value or values from an Expression
*
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index daa4ea0c8616..fcf65659c24f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -143,6 +143,7 @@ object TreePattern extends Enumeration {
val UNRESOLVED_ALIAS: Value = Value
val UNRESOLVED_ATTRIBUTE: Value = Value
val UNRESOLVED_DESERIALIZER: Value = Value
+ val UNRESOLVED_DF_STAR: Value = Value
val UNRESOLVED_HAVING: Value = Value
val UNRESOLVED_IDENTIFIER: Value = Value
val UNRESOLVED_ORDINAL: Value = Value
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index e0740a325358..200edd97a88a 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkThrowable,
SparkUnsupportedOperationException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{ExtendedAnalysisException,
FunctionIdentifier, InternalRow, QualifiedTableName, TableIdentifier}
-import
org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException,
FunctionAlreadyExistsException, NamespaceAlreadyExistsException,
NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException,
NoSuchTableException, ResolvedTable, Star, TableAlreadyExistsException,
UnresolvedAttribute, UnresolvedRegex}
+import
org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException,
FunctionAlreadyExistsException, NamespaceAlreadyExistsException,
NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException,
NoSuchTableException, ResolvedTable, Star, TableAlreadyExistsException,
UnresolvedRegex}
import org.apache.spark.sql.catalyst.catalog.{CatalogTable,
InvalidUDFClassException}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute,
AttributeReference, AttributeSet, CreateMap, CreateStruct, Expression,
GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction,
WindowSpecDefinition}
@@ -3952,19 +3952,19 @@ private[sql] object QueryCompilationErrors extends
QueryErrorsBase with Compilat
"expectedSchema" -> toSQLType(expectedSchema)))
}
- def cannotResolveColumn(u: UnresolvedAttribute): Throwable = {
+ def cannotResolveDataFrameColumn(e: Expression): Throwable = {
new AnalysisException(
errorClass = "CANNOT_RESOLVE_DATAFRAME_COLUMN",
- messageParameters = Map("name" -> toSQLId(u.nameParts)),
- origin = u.origin
+ messageParameters = Map("name" -> toSQLExpr(e)),
+ origin = e.origin
)
}
- def ambiguousColumnReferences(u: UnresolvedAttribute): Throwable = {
+ def ambiguousColumnReferences(e: Expression): Throwable = {
new AnalysisException(
errorClass = "AMBIGUOUS_COLUMN_REFERENCE",
- messageParameters = Map("name" -> toSQLId(u.nameParts)),
- origin = u.origin
+ messageParameters = Map("name" -> toSQLExpr(e)),
+ origin = e.origin
)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index bfbc3287c63c..39d720c933a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -155,7 +155,7 @@ class Column(val expr: Expression) extends Logging {
name match {
case "*" => UnresolvedStar(None)
case _ if name.endsWith(".*") =>
- val parts = UnresolvedAttribute.parseAttributeName(name.substring(0,
name.length - 2))
+ val parts = UnresolvedAttribute.parseAttributeName(name.dropRight(2))
UnresolvedStar(Some(parts))
case _ => UnresolvedAttribute.quotedString(name)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]