This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4e4a848c275 [SPARK-40707][CONNECT] Add groupby to connect DSL and test more than one grouping expression
4e4a848c275 is described below
commit 4e4a848c2759577464f4c11c4ea938c7d931f214
Author: Rui Wang <[email protected]>
AuthorDate: Tue Oct 11 12:35:08 2022 +0800
    [SPARK-40707][CONNECT] Add groupby to connect DSL and test more than one grouping expression
### What changes were proposed in this pull request?
1. Add `groupby` to the connect DSL and test more than one grouping expression (a short usage sketch follows this list).
2. Pass limited data types through the connect proto for LocalRelation's attributes.
3. Clean up an unused `Trait` in the testing code.
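For reviewers, a rough usage sketch of the new DSL method (illustrative only, not part of the patch; the `input` relation below is a stand-in built the same way the test suite builds its local relations):

```scala
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._

// A placeholder input relation; in SparkConnectProtoSuite this comes from
// createLocalRelationProto(Seq($"id".int, $"name".string)).
val input: proto.Relation = proto.Relation.newBuilder()
  .setLocalRelation(proto.LocalRelation.newBuilder().build())
  .build()

// Group by two attributes; the second argument list stays empty because the DSL
// does not pass aggregate expressions through yet (see the TODO in package.scala).
val aggregatePlan: proto.Relation =
  input.groupBy("id".protoAttr, "name".protoAttr)()
```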
### Why are the changes needed?
Enhance connect's support for GROUP BY.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UT
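The new `SparkConnectProtoSuite` case compares the connect plan against the equivalent Catalyst plan. For reference, a minimal sketch of the type round trip that local-relation attributes now go through via `DataTypeProtoConverter` (only integer and string types are handled at this point; other types raise `InvalidPlanInput`):

```scala
import org.apache.spark.sql.connect.planner.DataTypeProtoConverter
import org.apache.spark.sql.types.{IntegerType, StringType}

// Catalyst -> connect proto: IntegerType maps to I32, StringType to String.
val protoInt = DataTypeProtoConverter.toConnectProtoType(IntegerType)
val protoStr = DataTypeProtoConverter.toConnectProtoType(StringType)

// Connect proto -> Catalyst: the conversion round-trips for supported types.
assert(DataTypeProtoConverter.toCatalystType(protoInt) == IntegerType)
assert(DataTypeProtoConverter.toCatalystType(protoStr) == StringType)
```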
Closes #38155 from amaliujia/support_more_than_one_grouping_set.
Authored-by: Rui Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/protobuf/spark/connect/commands.proto | 4 +-
.../main/protobuf/spark/connect/expressions.proto | 7 ++--
.../main/protobuf/spark/connect/relations.proto | 18 +--------
.../src/main/protobuf/spark/connect/types.proto | 12 +++---
.../org/apache/spark/sql/connect/dsl/package.scala | 13 ++++++
.../connect/planner/DataTypeProtoConverter.scala | 46 ++++++++++++++++++++++
.../sql/connect/planner/SparkConnectPlanner.scala | 19 ++++-----
.../connect/planner/SparkConnectPlannerSuite.scala | 20 +++-------
.../connect/planner/SparkConnectProtoSuite.scala | 19 +++++++--
9 files changed, 101 insertions(+), 57 deletions(-)
diff --git a/connector/connect/src/main/protobuf/spark/connect/commands.proto b/connector/connect/src/main/protobuf/spark/connect/commands.proto
index 425857b842e..0a83e4543f5 100644
--- a/connector/connect/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/commands.proto
@@ -44,8 +44,8 @@ message CreateScalarFunction {
repeated string parts = 1;
FunctionLanguage language = 2;
bool temporary = 3;
- repeated Type argument_types = 4;
- Type return_type = 5;
+ repeated DataType argument_types = 4;
+ DataType return_type = 5;
// How the function body is defined:
oneof function_definition {
diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
index 9b3029a32b0..791b1b5887b 100644
--- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
@@ -65,10 +65,10 @@ message Expression {
// Timestamp in units of microseconds since the UNIX epoch.
int64 timestamp_tz = 27;
bytes uuid = 28;
- Type null = 29; // a typed null literal
+ DataType null = 29; // a typed null literal
List list = 30;
- Type.List empty_list = 31;
- Type.Map empty_map = 32;
+ DataType.List empty_list = 31;
+ DataType.Map empty_map = 32;
UserDefined user_defined = 33;
}
@@ -164,5 +164,6 @@ message Expression {
// by the analyzer.
message QualifiedAttribute {
string name = 1;
+ DataType type = 2;
}
}
diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 25bc4e8a16b..30f36fa6ceb 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -130,22 +130,8 @@ message Fetch {
// Relation of type [[Aggregate]].
message Aggregate {
Relation input = 1;
-
- // Grouping sets are used in rollups
- repeated GroupingSet grouping_sets = 2;
-
- // Measures
- repeated Measure measures = 3;
-
- message GroupingSet {
- repeated Expression aggregate_expressions = 1;
- }
-
- message Measure {
- AggregateFunction function = 1;
- // Conditional filter for SUM(x FILTER WHERE x < 10)
- Expression filter = 2;
- }
+ repeated Expression grouping_expressions = 2;
+ repeated AggregateFunction result_expressions = 3;
message AggregateFunction {
string name = 1;
diff --git a/connector/connect/src/main/protobuf/spark/connect/types.proto b/connector/connect/src/main/protobuf/spark/connect/types.proto
index c46afa2afc6..98b0c48b1e0 100644
--- a/connector/connect/src/main/protobuf/spark/connect/types.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/types.proto
@@ -22,9 +22,9 @@ package spark.connect;
option java_multiple_files = true;
option java_package = "org.apache.spark.connect.proto";
-// This message describes the logical [[Type]] of something. It does not carry the value
+// This message describes the logical [[DataType]] of something. It does not carry the value
// itself but only describes it.
-message Type {
+message DataType {
oneof kind {
Boolean bool = 1;
I8 i8 = 2;
@@ -168,20 +168,20 @@ message Type {
}
message Struct {
- repeated Type types = 1;
+ repeated DataType types = 1;
uint32 type_variation_reference = 2;
Nullability nullability = 3;
}
message List {
- Type type = 1;
+ DataType DataType = 1;
uint32 type_variation_reference = 2;
Nullability nullability = 3;
}
message Map {
- Type key = 1;
- Type value = 2;
+ DataType key = 1;
+ DataType value = 2;
uint32 type_variation_reference = 3;
Nullability nullability = 4;
}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 234b423a803..3ccf71c26b7 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -67,6 +67,19 @@ package object dsl {
}
relation.setJoin(join).build()
}
+
+ def groupBy(
+        groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = {
+ val agg = proto.Aggregate.newBuilder()
+ agg.setInput(logicalPlan)
+
+ for (groupingExpr <- groupingExprs) {
+ agg.addGroupingExpressions(groupingExpr)
+ }
+      // TODO: support aggregateExprs, which is blocked by supporting any builtin function
+ // resolution only by name in the analyzer.
+ proto.Relation.newBuilder().setAggregate(agg.build()).build()
+ }
}
}
}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
new file mode 100644
index 00000000000..b31855bfca9
--- /dev/null
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
+
+/**
+ * This object offers methods to convert to/from connect proto to catalyst types.
+ */
+object DataTypeProtoConverter {
+ def toCatalystType(t: proto.DataType): DataType = {
+ t.getKindCase match {
+ case proto.DataType.KindCase.I32 => IntegerType
+ case proto.DataType.KindCase.STRING => StringType
+ case _ =>
+        throw InvalidPlanInput(s"Does not support convert ${t.getKindCase} to catalyst types.")
+ }
+ }
+
+ def toConnectProtoType(t: DataType): proto.DataType = {
+ t match {
+ case IntegerType =>
+        proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build()
+ case StringType =>
+        proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build()
+ case _ =>
+        throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
+ }
+ }
+}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e3bb7e29322..66560f5e62f 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -77,8 +77,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
}
  private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = {
- // TODO: use data type from the proto.
- AttributeReference(exp.getName, IntegerType)()
+    AttributeReference(exp.getName, DataTypeProtoConverter.toCatalystType(exp.getType))()
}
private def transformReadRel(
@@ -271,11 +270,9 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
private def transformAggregate(rel: proto.Aggregate): LogicalPlan = {
assert(rel.hasInput)
- assert(rel.getGroupingSetsCount == 1, "Only one grouping set is supported")
- val groupingSet = rel.getGroupingSetsList.asScala.take(1)
- val ge = groupingSet
- .flatMap(f => f.getAggregateExpressionsList.asScala)
+ val groupingExprs =
+ rel.getGroupingExpressionsList.asScala
.map(transformExpression)
.map {
case x @ UnresolvedAttribute(_) => x
@@ -284,18 +281,18 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
logical.Aggregate(
child = transformRelation(rel.getInput),
- groupingExpressions = ge.toSeq,
+ groupingExpressions = groupingExprs.toSeq,
aggregateExpressions =
-        (rel.getMeasuresList.asScala.map(transformAggregateExpression) ++ ge).toSeq)
+        rel.getResultExpressionsList.asScala.map(transformAggregateExpression).toSeq)
}
private def transformAggregateExpression(
- exp: proto.Aggregate.Measure): expressions.NamedExpression = {
- val fun = exp.getFunction.getName
+ exp: proto.Aggregate.AggregateFunction): expressions.NamedExpression = {
+ val fun = exp.getName
UnresolvedAlias(
UnresolvedFunction(
name = fun,
-        arguments = exp.getFunction.getArgumentsList.asScala.map(transformExpression).toSeq,
+        arguments = exp.getArgumentsList.asScala.map(transformExpression).toSeq,
isDistinct = false))
}
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 37d80e01f72..10e17f121f0 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -45,11 +45,6 @@ trait SparkConnectPlanTest {
.build()
}
-trait SparkConnectSessionTest {
- protected var spark: SparkSession
-
-}
-
/**
 * This is a rudimentary test class for SparkConnect. The main goal of these basic tests is to
 * ensure that the transformation from Proto to LogicalPlan works and that the right nodes are
@@ -222,16 +217,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
val agg = proto.Aggregate.newBuilder
.setInput(readRel)
- .addAllMeasures(
- Seq(
- proto.Aggregate.Measure.newBuilder
- .setFunction(proto.Aggregate.AggregateFunction.newBuilder
- .setName("sum")
- .addArguments(unresolvedAttribute))
- .build()).asJava)
- .addGroupingSets(proto.Aggregate.GroupingSet.newBuilder
- .addAggregateExpressions(unresolvedAttribute)
- .build())
+ .addResultExpressions(
+ proto.Aggregate.AggregateFunction.newBuilder
+ .setName("sum")
+ .addArguments(unresolvedAttribute))
+ .addGroupingExpressions(unresolvedAttribute)
.build()
val res = transform(proto.Relation.newBuilder.setAggregate(agg).build())
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 4f3f0fea387..441a3a9f1e4 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -31,11 +31,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
*/
class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
- lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int))
+  lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int, $"name".string))
  lazy val connectTestRelation2 = createLocalRelationProto(Seq($"key".int, $"value".int))
- lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int)
+  lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int, $"name".string)
  lazy val sparkTestRelation2: LocalRelation = LocalRelation($"key".int, $"value".int)
@@ -81,12 +81,23 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
}
}
+ test("Aggregate with more than 1 grouping expressions") {
+ val connectPlan = {
+ import org.apache.spark.sql.connect.dsl.expressions._
+ import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.groupBy("id".protoAttr, "name".protoAttr)())
+ }
+ val sparkPlan = sparkTestRelation.groupBy($"id", $"name")()
+ comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+ }
+
  private def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = {
val localRelationBuilder = proto.LocalRelation.newBuilder()
-    // TODO: set data types for each local relation attribute one proto supports data type.
for (attr <- attrs) {
localRelationBuilder.addAttributes(
-        proto.Expression.QualifiedAttribute.newBuilder().setName(attr.name).build()
+ proto.Expression.QualifiedAttribute.newBuilder()
+ .setName(attr.name)
+ .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType))
)
}
proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]