This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7b200898967 [SPARK-41321][CONNECT] Support target field for UnresolvedStar 7b200898967 is described below commit 7b20089896716a5fa7cad595bd560640d1b5afcf Author: dengziming <dengzim...@bytedance.com> AuthorDate: Thu Dec 1 10:40:00 2022 +0800 [SPARK-41321][CONNECT] Support target field for UnresolvedStar ### What changes were proposed in this pull request? 1. Support target field UnresolvedStar 2. UnresolvedStar can be used simultaneously with other expression. ### Why are the changes needed? This is a necessary feature for UnresolvedStar ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added 2 new unit tests. Closes #38838 from dengziming/SPARK-41321. Authored-by: dengziming <dengzim...@bytedance.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../main/protobuf/spark/connect/expressions.proto | 4 +- .../sql/connect/planner/SparkConnectPlanner.scala | 17 ++-- .../connect/planner/SparkConnectPlannerSuite.scala | 105 ++++++++++++++++++++- .../pyspark/sql/connect/proto/expressions_pb2.py | 55 ++++++----- .../pyspark/sql/connect/proto/expressions_pb2.pyi | 13 +++ 5 files changed, 158 insertions(+), 36 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto index 2a1159c1d04..b90f7619b8f 100644 --- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto @@ -18,7 +18,6 @@ syntax = 'proto3'; import "spark/connect/types.proto"; -import "google/protobuf/any.proto"; package spark.connect; @@ -142,6 +141,9 @@ message Expression { // UnresolvedStar is used to expand all the fields of a relation or struct. message UnresolvedStar { + // (Optional) The target of the expansion, either be a table name or struct name, this + // is a list of identifiers that is the path of the expansion. + repeated string target = 1; } message Alias { diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index d1d4c3d4fa9..5ebe7c7cce3 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -392,13 +392,8 @@ class SparkConnectPlanner(session: SparkSession) { } else { logical.OneRowRelation() } - // TODO: support the target field for *. val projection = - if (rel.getExpressionsCount == 1 && rel.getExpressions(0).hasUnresolvedStar) { - Seq(UnresolvedStar(Option.empty)) - } else { - rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_)) - } + rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_)) logical.Project(projectList = projection.toSeq, child = baseRel) } @@ -416,6 +411,8 @@ class SparkConnectPlanner(session: SparkSession) { case proto.Expression.ExprTypeCase.ALIAS => transformAlias(exp.getAlias) case proto.Expression.ExprTypeCase.EXPRESSION_STRING => transformExpressionString(exp.getExpressionString) + case proto.Expression.ExprTypeCase.UNRESOLVED_STAR => + transformUnresolvedStar(exp.getUnresolvedStar) case _ => throw InvalidPlanInput( s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported") @@ -573,6 +570,14 @@ class SparkConnectPlanner(session: SparkSession) { session.sessionState.sqlParser.parseExpression(expr.getExpression) } + private def transformUnresolvedStar(regex: proto.Expression.UnresolvedStar): Expression = { + if (regex.getTargetList.isEmpty) { + UnresolvedStar(Option.empty) + } else { + UnresolvedStar(Some(regex.getTargetList.asScala.toSeq)) + } + } + private def transformSetOperation(u: proto.SetOperation): LogicalPlan = { assert(u.hasLeftInput && u.hasRightInput, "Union must have 2 inputs") diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 8fbf2be3730..81e5ee3d0ce 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -23,7 +23,7 @@ import com.google.protobuf.ByteString import org.apache.spark.SparkFunSuite import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.Expression.UnresolvedStar +import org.apache.spark.connect.proto.Expression.{Alias, ExpressionString, UnresolvedStar} import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProjection} @@ -468,4 +468,107 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } assert(e.getMessage.contains("part1, part2")) } + + test("transform UnresolvedStar and ExpressionString") { + val sql = + "SELECT * FROM VALUES (1,'spark',1), (2,'hadoop',2), (3,'kafka',3) AS tab(id, name, value)" + val input = proto.Relation + .newBuilder() + .setSql( + proto.SQL + .newBuilder() + .setQuery(sql) + .build()) + + val project = + proto.Project + .newBuilder() + .setInput(input) + .addExpressions( + proto.Expression + .newBuilder() + .setUnresolvedStar(UnresolvedStar.newBuilder().build()) + .build()) + .addExpressions( + proto.Expression + .newBuilder() + .setExpressionString(ExpressionString.newBuilder().setExpression("name").build()) + .build()) + .build() + + val df = + Dataset.ofRows(spark, transform(proto.Relation.newBuilder.setProject(project).build())) + val array = df.collect() + assert(array.length == 3) + assert(array(0).toString == InternalRow(1, "spark", 1, "spark").toString) + assert(array(1).toString == InternalRow(2, "hadoop", 2, "hadoop").toString) + assert(array(2).toString == InternalRow(3, "kafka", 3, "kafka").toString) + } + + test("transform UnresolvedStar with target field") { + val rows = (0 until 10).map { i => + InternalRow(InternalRow(InternalRow(i, i + 1))) + } + + val schema = StructType( + Seq( + StructField( + "a", + StructType(Seq(StructField( + "b", + StructType(Seq(StructField("c", IntegerType), StructField("d", IntegerType))))))))) + val inputRows = rows.map { row => + val proj = UnsafeProjection.create(schema) + proj(row).copy() + } + + val localRelation = createLocalRelationProto(schema.toAttributes, inputRows) + + val project = + proto.Project + .newBuilder() + .setInput(localRelation) + .addExpressions( + proto.Expression + .newBuilder() + .setUnresolvedStar(UnresolvedStar.newBuilder().addTarget("a").addTarget("b").build()) + .build()) + .build() + + val df = + Dataset.ofRows(spark, transform(proto.Relation.newBuilder.setProject(project).build())) + assertResult(df.schema)( + StructType(Seq(StructField("c", IntegerType), StructField("d", IntegerType)))) + + val array = df.collect() + assert(array.length == 10) + for (i <- 0 until 10) { + assert(i == array(i).getInt(0)) + assert(i + 1 == array(i).getInt(1)) + } + } + + test("transform Project with Alias") { + val input = proto.Expression + .newBuilder() + .setLiteral( + proto.Expression.Literal + .newBuilder() + .setInteger(1) + .build()) + + val project = + proto.Project + .newBuilder() + .addExpressions( + proto.Expression + .newBuilder() + .setAlias(Alias.newBuilder().setExpr(input).addName("id").build()) + .build()) + .build() + + val df = + Dataset.ofRows(spark, transform(proto.Relation.newBuilder.setProject(project).build())) + assert(df.schema.fields.toSeq.map(_.name) == Seq("id")) + } } diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index afa783742d2..a1d9dcb91b0 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -30,11 +30,10 @@ _sym_db = _symbol_database.Default() from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2 -from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xa8\x12\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...] + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xc0\x12\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...] ) @@ -186,30 +185,30 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" - _EXPRESSION._serialized_start = 105 - _EXPRESSION._serialized_end = 2449 - _EXPRESSION_LITERAL._serialized_start = 613 - _EXPRESSION_LITERAL._serialized_end = 2071 - _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1509 - _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1626 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1628 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1726 - _EXPRESSION_LITERAL_STRUCT._serialized_start = 1728 - _EXPRESSION_LITERAL_STRUCT._serialized_end = 1795 - _EXPRESSION_LITERAL_ARRAY._serialized_start = 1797 - _EXPRESSION_LITERAL_ARRAY._serialized_end = 1863 - _EXPRESSION_LITERAL_MAP._serialized_start = 1866 - _EXPRESSION_LITERAL_MAP._serialized_end = 2055 - _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 1939 - _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2055 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2073 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2143 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2145 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2244 - _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2246 - _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2296 - _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2298 - _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2314 - _EXPRESSION_ALIAS._serialized_start = 2316 - _EXPRESSION_ALIAS._serialized_end = 2436 + _EXPRESSION._serialized_start = 78 + _EXPRESSION._serialized_end = 2446 + _EXPRESSION_LITERAL._serialized_start = 586 + _EXPRESSION_LITERAL._serialized_end = 2044 + _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1482 + _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1599 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1601 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1699 + _EXPRESSION_LITERAL_STRUCT._serialized_start = 1701 + _EXPRESSION_LITERAL_STRUCT._serialized_end = 1768 + _EXPRESSION_LITERAL_ARRAY._serialized_start = 1770 + _EXPRESSION_LITERAL_ARRAY._serialized_end = 1836 + _EXPRESSION_LITERAL_MAP._serialized_start = 1839 + _EXPRESSION_LITERAL_MAP._serialized_end = 2028 + _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 1912 + _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2028 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2046 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2116 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2118 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2217 + _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2219 + _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2269 + _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2271 + _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2311 + _EXPRESSION_ALIAS._serialized_start = 2313 + _EXPRESSION_ALIAS._serialized_end = 2433 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index f1c599964bf..ddd9338d85d 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -505,8 +505,21 @@ class Expression(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + TARGET_FIELD_NUMBER: builtins.int + @property + def target( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """(Optional) The target of the expansion, either be a table name or struct name, this + is a list of identifiers that is the path of the expansion. + """ def __init__( self, + *, + target: collections.abc.Iterable[builtins.str] | None = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["target", b"target"] ) -> None: ... class Alias(google.protobuf.message.Message): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org