This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4e4a848c275 [SPARK-40707][CONNECT] Add groupby to connect DSL and test more than one grouping expressions
4e4a848c275 is described below

commit 4e4a848c2759577464f4c11c4ea938c7d931f214
Author: Rui Wang <[email protected]>
AuthorDate: Tue Oct 11 12:35:08 2022 +0800

    [SPARK-40707][CONNECT] Add groupby to connect DSL and test more than one grouping expressions
    
    ### What changes were proposed in this pull request?
    
    1. Add `groupBy` to the Connect DSL and test an aggregate with more than one grouping expression.
    2. Pass a limited set of data types (int and string) through the Connect proto for LocalRelation's attributes.
    3. Clean up the unused `SparkConnectSessionTest` trait in the testing code.
    
    ### Why are the changes needed?
    
    Enhance connect's support for GROUP BY.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    UT
    
    Closes #38155 from amaliujia/support_more_than_one_grouping_set.
    
    Authored-by: Rui Wang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../src/main/protobuf/spark/connect/commands.proto |  4 +-
 .../main/protobuf/spark/connect/expressions.proto  |  7 ++--
 .../main/protobuf/spark/connect/relations.proto    | 18 +--------
 .../src/main/protobuf/spark/connect/types.proto    | 12 +++---
 .../org/apache/spark/sql/connect/dsl/package.scala | 13 ++++++
 .../connect/planner/DataTypeProtoConverter.scala   | 46 ++++++++++++++++++++++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 19 ++++-----
 .../connect/planner/SparkConnectPlannerSuite.scala | 20 +++-------
 .../connect/planner/SparkConnectProtoSuite.scala   | 19 +++++++--
 9 files changed, 101 insertions(+), 57 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/commands.proto b/connector/connect/src/main/protobuf/spark/connect/commands.proto
index 425857b842e..0a83e4543f5 100644
--- a/connector/connect/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/commands.proto
@@ -44,8 +44,8 @@ message CreateScalarFunction {
   repeated string parts = 1;
   FunctionLanguage language = 2;
   bool temporary = 3;
-  repeated Type argument_types = 4;
-  Type return_type = 5;
+  repeated DataType argument_types = 4;
+  DataType return_type = 5;
 
   // How the function body is defined:
   oneof function_definition {
diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
index 9b3029a32b0..791b1b5887b 100644
--- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
@@ -65,10 +65,10 @@ message Expression {
       // Timestamp in units of microseconds since the UNIX epoch.
       int64 timestamp_tz = 27;
       bytes uuid = 28;
-      Type null = 29; // a typed null literal
+      DataType null = 29; // a typed null literal
       List list = 30;
-      Type.List empty_list = 31;
-      Type.Map empty_map = 32;
+      DataType.List empty_list = 31;
+      DataType.Map empty_map = 32;
       UserDefined user_defined = 33;
     }
 
@@ -164,5 +164,6 @@ message Expression {
   // by the analyzer.
   message QualifiedAttribute {
     string name = 1;
+    DataType type = 2;
   }
 }
diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 25bc4e8a16b..30f36fa6ceb 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -130,22 +130,8 @@ message Fetch {
 // Relation of type [[Aggregate]].
 message Aggregate {
   Relation input = 1;
-
-  // Grouping sets are used in rollups
-  repeated GroupingSet grouping_sets = 2;
-
-  // Measures
-  repeated Measure measures = 3;
-
-  message GroupingSet {
-    repeated Expression aggregate_expressions = 1;
-  }
-
-  message Measure {
-    AggregateFunction function = 1;
-    // Conditional filter for SUM(x FILTER WHERE x < 10)
-    Expression filter = 2;
-  }
+  repeated Expression grouping_expressions = 2;
+  repeated AggregateFunction result_expressions = 3;
 
   message AggregateFunction {
     string name = 1;
diff --git a/connector/connect/src/main/protobuf/spark/connect/types.proto b/connector/connect/src/main/protobuf/spark/connect/types.proto
index c46afa2afc6..98b0c48b1e0 100644
--- a/connector/connect/src/main/protobuf/spark/connect/types.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/types.proto
@@ -22,9 +22,9 @@ package spark.connect;
 option java_multiple_files = true;
 option java_package = "org.apache.spark.connect.proto";
 
-// This message describes the logical [[Type]] of something. It does not carry the value
+// This message describes the logical [[DataType]] of something. It does not carry the value
 // itself but only describes it.
-message Type {
+message DataType {
   oneof kind {
     Boolean bool = 1;
     I8 i8 = 2;
@@ -168,20 +168,20 @@ message Type {
     }
 
   message Struct {
-    repeated Type types = 1;
+    repeated DataType types = 1;
     uint32 type_variation_reference = 2;
     Nullability nullability = 3;
   }
 
   message List {
-    Type type = 1;
+    DataType type = 1;
     uint32 type_variation_reference = 2;
     Nullability nullability = 3;
   }
 
   message Map {
-    Type key = 1;
-    Type value = 2;
+    DataType key = 1;
+    DataType value = 2;
     uint32 type_variation_reference = 3;
     Nullability nullability = 4;
   }
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 234b423a803..3ccf71c26b7 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -67,6 +67,19 @@ package object dsl {
         }
         relation.setJoin(join).build()
       }
+
+      def groupBy(
+          groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = {
+        val agg = proto.Aggregate.newBuilder()
+        agg.setInput(logicalPlan)
+
+        for (groupingExpr <- groupingExprs) {
+          agg.addGroupingExpressions(groupingExpr)
+        }
+        // TODO: support aggregateExprs, which is blocked by supporting any builtin function
+        // resolution only by name in the analyzer.
+        proto.Relation.newBuilder().setAggregate(agg.build()).build()
+      }
     }
   }
 }
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
new file mode 100644
index 00000000000..b31855bfca9
--- /dev/null
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
+
+/**
+ * This object offers methods to convert to/from connect proto to catalyst types.
+ */
+object DataTypeProtoConverter {
+  def toCatalystType(t: proto.DataType): DataType = {
+    t.getKindCase match {
+      case proto.DataType.KindCase.I32 => IntegerType
+      case proto.DataType.KindCase.STRING => StringType
+      case _ =>
+        throw InvalidPlanInput(s"Does not support convert ${t.getKindCase} to catalyst types.")
+    }
+  }
+
+  def toConnectProtoType(t: DataType): proto.DataType = {
+    t match {
+      case IntegerType =>
+        proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build()
+      case StringType =>
+        proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build()
+      case _ =>
+        throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
+    }
+  }
+}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e3bb7e29322..66560f5e62f 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -77,8 +77,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
   }
 
   private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = {
-    // TODO: use data type from the proto.
-    AttributeReference(exp.getName, IntegerType)()
+    AttributeReference(exp.getName, DataTypeProtoConverter.toCatalystType(exp.getType))()
   }
 
   private def transformReadRel(
@@ -271,11 +270,9 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
 
   private def transformAggregate(rel: proto.Aggregate): LogicalPlan = {
     assert(rel.hasInput)
-    assert(rel.getGroupingSetsCount == 1, "Only one grouping set is supported")
 
-    val groupingSet = rel.getGroupingSetsList.asScala.take(1)
-    val ge = groupingSet
-      .flatMap(f => f.getAggregateExpressionsList.asScala)
+    val groupingExprs =
+      rel.getGroupingExpressionsList.asScala
       .map(transformExpression)
       .map {
         case x @ UnresolvedAttribute(_) => x
@@ -284,18 +281,18 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
 
     logical.Aggregate(
       child = transformRelation(rel.getInput),
-      groupingExpressions = ge.toSeq,
+      groupingExpressions = groupingExprs.toSeq,
       aggregateExpressions =
-        (rel.getMeasuresList.asScala.map(transformAggregateExpression) ++ ge).toSeq)
+        rel.getResultExpressionsList.asScala.map(transformAggregateExpression).toSeq)
   }
 
   private def transformAggregateExpression(
-      exp: proto.Aggregate.Measure): expressions.NamedExpression = {
-    val fun = exp.getFunction.getName
+      exp: proto.Aggregate.AggregateFunction): expressions.NamedExpression = {
+    val fun = exp.getName
     UnresolvedAlias(
       UnresolvedFunction(
         name = fun,
-        arguments = exp.getFunction.getArgumentsList.asScala.map(transformExpression).toSeq,
+        arguments = exp.getArgumentsList.asScala.map(transformExpression).toSeq,
         isDistinct = false))
   }
 
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 37d80e01f72..10e17f121f0 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -45,11 +45,6 @@ trait SparkConnectPlanTest {
       .build()
 }
 
-trait SparkConnectSessionTest {
-  protected var spark: SparkSession
-
-}
-
 /**
  * This is a rudimentary test class for SparkConnect. The main goal of these basic tests is to
  * ensure that the transformation from Proto to LogicalPlan works and that the right nodes are
@@ -222,16 +217,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
 
     val agg = proto.Aggregate.newBuilder
       .setInput(readRel)
-      .addAllMeasures(
-        Seq(
-          proto.Aggregate.Measure.newBuilder
-            .setFunction(proto.Aggregate.AggregateFunction.newBuilder
-              .setName("sum")
-              .addArguments(unresolvedAttribute))
-            .build()).asJava)
-      .addGroupingSets(proto.Aggregate.GroupingSet.newBuilder
-        .addAggregateExpressions(unresolvedAttribute)
-        .build())
+      .addResultExpressions(
+        proto.Aggregate.AggregateFunction.newBuilder
+          .setName("sum")
+          .addArguments(unresolvedAttribute))
+      .addGroupingExpressions(unresolvedAttribute)
       .build()
 
     val res = transform(proto.Relation.newBuilder.setAggregate(agg).build())
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 4f3f0fea387..441a3a9f1e4 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -31,11 +31,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
  */
 class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
 
-  lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int))
+  lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int, $"name".string))
 
   lazy val connectTestRelation2 = createLocalRelationProto(Seq($"key".int, $"value".int))
 
-  lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int)
+  lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int, $"name".string)
 
   lazy val sparkTestRelation2: LocalRelation = LocalRelation($"key".int, $"value".int)
 
@@ -81,12 +81,23 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     }
   }
 
+  test("Aggregate with more than 1 grouping expressions") {
+    val connectPlan = {
+      import org.apache.spark.sql.connect.dsl.expressions._
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.groupBy("id".protoAttr, "name".protoAttr)())
+    }
+    val sparkPlan = sparkTestRelation.groupBy($"id", $"name")()
+    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+  }
+
   private def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
-    // TODO: set data types for each local relation attribute one proto supports data type.
     for (attr <- attrs) {
       localRelationBuilder.addAttributes(
-        proto.Expression.QualifiedAttribute.newBuilder().setName(attr.name).build()
+        proto.Expression.QualifiedAttribute.newBuilder()
+          .setName(attr.name)
+          .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType))
       )
     }
     proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to