This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7b200898967 [SPARK-41321][CONNECT] Support target field for UnresolvedStar
7b200898967 is described below

commit 7b20089896716a5fa7cad595bd560640d1b5afcf
Author: dengziming <dengzim...@bytedance.com>
AuthorDate: Thu Dec 1 10:40:00 2022 +0800

    [SPARK-41321][CONNECT] Support target field for UnresolvedStar
    
    ### What changes were proposed in this pull request?
    1. Support the target field for UnresolvedStar.
    2. Allow UnresolvedStar to be used together with other expressions (a sketch follows).
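    
    A minimal sketch of point 2, using the generated proto builders exercised in
    the new tests below (`input` is an assumed placeholder for any proto.Relation):
    
        import org.apache.spark.connect.proto
        import org.apache.spark.connect.proto.Expression.{ExpressionString, UnresolvedStar}
    
        val project = proto.Project
          .newBuilder()
          .setInput(input) // hypothetical input relation
          .addExpressions(proto.Expression
            .newBuilder()
            .setUnresolvedStar(UnresolvedStar.newBuilder().addTarget("a").addTarget("b")))
          .addExpressions(proto.Expression
            .newBuilder()
            .setExpressionString(ExpressionString.newBuilder().setExpression("name")))
          .build()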
    
    ### Why are the changes needed?
    This is a necessary feature for UnresolvedStar: without a target, star expressions such as `tab.*` or `a.b.*` cannot be expressed over Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added three unit tests: two for UnresolvedStar and one for Project with Alias.
    
    Closes #38838 from dengziming/SPARK-41321.
    
    Authored-by: dengziming <dengzim...@bytedance.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto  |   4 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  17 ++--
 .../connect/planner/SparkConnectPlannerSuite.scala | 105 ++++++++++++++++++++-
 .../pyspark/sql/connect/proto/expressions_pb2.py   |  55 ++++++-----
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  13 +++
 5 files changed, 158 insertions(+), 36 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
index 2a1159c1d04..b90f7619b8f 100644
--- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
@@ -18,7 +18,6 @@
 syntax = 'proto3';
 
 import "spark/connect/types.proto";
-import "google/protobuf/any.proto";
 
 package spark.connect;
 
@@ -142,6 +141,9 @@ message Expression {
 
   // UnresolvedStar is used to expand all the fields of a relation or struct.
   message UnresolvedStar {
+    // (Optional) The target of the expansion, either a table name or a struct name. This
+    // is a list of identifiers that form the path of the expansion.
+    repeated string target = 1;
   }
 
   message Alias {
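
A hedged illustration of the new field (builder API as used in the test suite
below): an empty target list keeps the old bare `*` behaviour, while a non-empty
list names the path to expand.

    import org.apache.spark.connect.proto.Expression.UnresolvedStar

    // No target: expand all fields of the input relation, i.e. `*`.
    val bare = UnresolvedStar.newBuilder().build()
    // Target ["a", "b"]: expand the fields of the nested struct `a.b`, i.e. `a.b.*`.
    val targeted = UnresolvedStar.newBuilder().addTarget("a").addTarget("b").build()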
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d1d4c3d4fa9..5ebe7c7cce3 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -392,13 +392,8 @@ class SparkConnectPlanner(session: SparkSession) {
     } else {
       logical.OneRowRelation()
     }
-    // TODO: support the target field for *.
     val projection =
-      if (rel.getExpressionsCount == 1 && rel.getExpressions(0).hasUnresolvedStar) {
-        Seq(UnresolvedStar(Option.empty))
-      } else {
-        rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
-      }
+      rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
     logical.Project(projectList = projection.toSeq, child = baseRel)
   }
 
@@ -416,6 +411,8 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Expression.ExprTypeCase.ALIAS => transformAlias(exp.getAlias)
       case proto.Expression.ExprTypeCase.EXPRESSION_STRING =>
         transformExpressionString(exp.getExpressionString)
+      case proto.Expression.ExprTypeCase.UNRESOLVED_STAR =>
+        transformUnresolvedStar(exp.getUnresolvedStar)
       case _ =>
         throw InvalidPlanInput(
           s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not 
supported")
@@ -573,6 +570,14 @@ class SparkConnectPlanner(session: SparkSession) {
     session.sessionState.sqlParser.parseExpression(expr.getExpression)
   }
 
+  private def transformUnresolvedStar(star: proto.Expression.UnresolvedStar): Expression = {
+    if (star.getTargetList.isEmpty) {
+      UnresolvedStar(Option.empty)
+    } else {
+      UnresolvedStar(Some(star.getTargetList.asScala.toSeq))
+    }
+  }
+
   private def transformSetOperation(u: proto.SetOperation): LogicalPlan = {
     assert(u.hasLeftInput && u.hasRightInput, "Union must have 2 inputs")
 
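On the Catalyst side, transformUnresolvedStar above maps the proto field onto the
existing analysis node UnresolvedStar(target: Option[Seq[String]]); a sketch of the
two cases:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedStar

    UnresolvedStar(None)                // empty target list: a bare `*`
    UnresolvedStar(Some(Seq("a", "b"))) // the analyzer expands this as `a.b.*`
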
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 8fbf2be3730..81e5ee3d0ce 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -23,7 +23,7 @@ import com.google.protobuf.ByteString
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.Expression.UnresolvedStar
+import org.apache.spark.connect.proto.Expression.{Alias, ExpressionString, UnresolvedStar}
 import org.apache.spark.sql.{AnalysisException, Dataset}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProjection}
@@ -468,4 +468,107 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
     }
     assert(e.getMessage.contains("part1, part2"))
   }
+
+  test("transform UnresolvedStar and ExpressionString") {
+    val sql =
+      "SELECT * FROM VALUES (1,'spark',1), (2,'hadoop',2), (3,'kafka',3) AS 
tab(id, name, value)"
+    val input = proto.Relation
+      .newBuilder()
+      .setSql(
+        proto.SQL
+          .newBuilder()
+          .setQuery(sql)
+          .build())
+
+    val project =
+      proto.Project
+        .newBuilder()
+        .setInput(input)
+        .addExpressions(
+          proto.Expression
+            .newBuilder()
+            .setUnresolvedStar(UnresolvedStar.newBuilder().build())
+            .build())
+        .addExpressions(
+          proto.Expression
+            .newBuilder()
+            .setExpressionString(ExpressionString.newBuilder().setExpression("name").build())
+            .build())
+        .build()
+
+    val df =
+      Dataset.ofRows(spark, transform(proto.Relation.newBuilder.setProject(project).build()))
+    val array = df.collect()
+    assert(array.length == 3)
+    assert(array(0).toString == InternalRow(1, "spark", 1, "spark").toString)
+    assert(array(1).toString == InternalRow(2, "hadoop", 2, "hadoop").toString)
+    assert(array(2).toString == InternalRow(3, "kafka", 3, "kafka").toString)
+  }
+
+  test("transform UnresolvedStar with target field") {
+    val rows = (0 until 10).map { i =>
+      InternalRow(InternalRow(InternalRow(i, i + 1)))
+    }
+
+    val schema = StructType(
+      Seq(
+        StructField(
+          "a",
+          StructType(Seq(StructField(
+            "b",
+            StructType(Seq(StructField("c", IntegerType), StructField("d", IntegerType)))))))))
+    val inputRows = rows.map { row =>
+      val proj = UnsafeProjection.create(schema)
+      proj(row).copy()
+    }
+
+    val localRelation = createLocalRelationProto(schema.toAttributes, inputRows)
+
+    val project =
+      proto.Project
+        .newBuilder()
+        .setInput(localRelation)
+        .addExpressions(
+          proto.Expression
+            .newBuilder()
+            .setUnresolvedStar(UnresolvedStar.newBuilder().addTarget("a").addTarget("b").build())
+            .build())
+        .build()
+
+    val df =
+      Dataset.ofRows(spark, transform(proto.Relation.newBuilder.setProject(project).build()))
+    assertResult(df.schema)(
+      StructType(Seq(StructField("c", IntegerType), StructField("d", IntegerType))))
+
+    val array = df.collect()
+    assert(array.length == 10)
+    for (i <- 0 until 10) {
+      assert(i == array(i).getInt(0))
+      assert(i + 1 == array(i).getInt(1))
+    }
+  }
+
+  test("transform Project with Alias") {
+    val input = proto.Expression
+      .newBuilder()
+      .setLiteral(
+        proto.Expression.Literal
+          .newBuilder()
+          .setInteger(1)
+          .build())
+
+    val project =
+      proto.Project
+        .newBuilder()
+        .addExpressions(
+          proto.Expression
+            .newBuilder()
+            .setAlias(Alias.newBuilder().setExpr(input).addName("id").build())
+            .build())
+        .build()
+
+    val df =
+      Dataset.ofRows(spark, transform(proto.Relation.newBuilder.setProject(project).build()))
+    assert(df.schema.fields.toSeq.map(_.name) == Seq("id"))
+  }
 }
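
For reference, the targeted-star test above is the Connect-side equivalent of the
plain SQL form (the table name `tab` here is hypothetical, for illustration only):

    spark.sql("SELECT a.b.* FROM tab") // resolves to the struct fields `c` and `d`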
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index afa783742d2..a1d9dcb91b0 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -30,11 +30,10 @@ _sym_db = _symbol_database.Default()
 
 
 from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2
-from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xa8\x12\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xc0\x12\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
 )
 
 
@@ -186,30 +185,30 @@ if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
-    _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 2449
-    _EXPRESSION_LITERAL._serialized_start = 613
-    _EXPRESSION_LITERAL._serialized_end = 2071
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1509
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1626
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1628
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1726
-    _EXPRESSION_LITERAL_STRUCT._serialized_start = 1728
-    _EXPRESSION_LITERAL_STRUCT._serialized_end = 1795
-    _EXPRESSION_LITERAL_ARRAY._serialized_start = 1797
-    _EXPRESSION_LITERAL_ARRAY._serialized_end = 1863
-    _EXPRESSION_LITERAL_MAP._serialized_start = 1866
-    _EXPRESSION_LITERAL_MAP._serialized_end = 2055
-    _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 1939
-    _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2055
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2073
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2143
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2145
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2244
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2246
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2296
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2298
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2314
-    _EXPRESSION_ALIAS._serialized_start = 2316
-    _EXPRESSION_ALIAS._serialized_end = 2436
+    _EXPRESSION._serialized_start = 78
+    _EXPRESSION._serialized_end = 2446
+    _EXPRESSION_LITERAL._serialized_start = 586
+    _EXPRESSION_LITERAL._serialized_end = 2044
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1482
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1599
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1601
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1699
+    _EXPRESSION_LITERAL_STRUCT._serialized_start = 1701
+    _EXPRESSION_LITERAL_STRUCT._serialized_end = 1768
+    _EXPRESSION_LITERAL_ARRAY._serialized_start = 1770
+    _EXPRESSION_LITERAL_ARRAY._serialized_end = 1836
+    _EXPRESSION_LITERAL_MAP._serialized_start = 1839
+    _EXPRESSION_LITERAL_MAP._serialized_end = 2028
+    _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 1912
+    _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2028
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2046
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2116
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2118
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2217
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2219
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2269
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2271
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2311
+    _EXPRESSION_ALIAS._serialized_start = 2313
+    _EXPRESSION_ALIAS._serialized_end = 2433
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index f1c599964bf..ddd9338d85d 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -505,8 +505,21 @@ class Expression(google.protobuf.message.Message):
 
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
+        TARGET_FIELD_NUMBER: builtins.int
+        @property
+        def target(
+            self,
+        ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+            """(Optional) The target of the expansion, either be a table name 
or struct name, this
+            is a list of identifiers that is the path of the expansion.
+            """
         def __init__(
             self,
+            *,
+            target: collections.abc.Iterable[builtins.str] | None = ...,
+        ) -> None: ...
+        def ClearField(
+            self, field_name: typing_extensions.Literal["target", b"target"]
         ) -> None: ...
 
     class Alias(google.protobuf.message.Message):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
