This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new b14da8b1b65  [SPARK-40812][CONNECT] Add Deduplicate to Connect proto and DSL
b14da8b1b65 is described below

commit b14da8b1b65d9f00f49fab87f738715089bc43e8
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Fri Oct 21 13:14:43 2022 +0800

    [SPARK-40812][CONNECT] Add Deduplicate to Connect proto and DSL
    
    ### What changes were proposed in this pull request?
    
    This PR adds `Deduplicate` to the Connect proto and DSL.
    
    Note that `Deduplicate` cannot be replaced by SQL's `SELECT DISTINCT col_list`. The difference is that `Deduplicate` removes duplicate rows based on a subset of columns while still returning all the columns, whereas SQL's `SELECT DISTINCT col_list` can only return the columns in `col_list`.
    
    ### Why are the changes needed?
    
    1. To improve proto API coverage.
    2. `Deduplicate` blocks https://github.com/apache/spark/pull/38166: we want to support `Union(isAll=false)`, which will be implemented as `Union().Distinct()` to match the existing DataFrame API, and `Deduplicate` is needed to write test cases for `Union(isAll=false)`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    UT

Closes #38276 from amaliujia/supportDropDuplicates.

Authored-by: Rui Wang <rui.w...@databricks.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../main/protobuf/spark/connect/relations.proto    |  9 ++
 .../org/apache/spark/sql/connect/dsl/package.scala | 20 +++++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 35 +++++++-
 .../planner/SparkConnectDeduplicateSuite.scala     | 68 +++++++++++++++
 .../connect/planner/SparkConnectPlannerSuite.scala | 29 ++++++-
 python/pyspark/sql/connect/proto/relations_pb2.py  | 98 +++++++++++-----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi | 50 +++++++++++
 7 files changed, 257 insertions(+), 52 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index eadedf495d3..6adf0831ea2 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -43,6 +43,7 @@ message Relation {
     LocalRelation local_relation = 11;
     Sample sample = 12;
     Offset offset = 13;
+    Deduplicate deduplicate = 14;
 
     Unknown unknown = 999;
   }
@@ -181,6 +182,14 @@ message Sort {
   }
 }
 
+// Relation of type [[Deduplicate]], which has duplicate rows removed. It can consider either
+// only a subset of the columns or all of the columns.
+message Deduplicate {
+  Relation input = 1;
+  repeated string column_names = 2;
+  bool all_columns_as_keys = 3;
+}
+
 message LocalRelation {
   repeated Expression.QualifiedAttribute attributes = 1;
   // TODO: support local data.
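The distinction called out in the description above is easiest to see in DataFrame terms. A minimal sketch for illustration (the DataFrame `df` and table `t` are hypothetical, both assumed to have columns `id`, `key`, `value`):

```scala
// Deduplicate models dropDuplicates: deduplicate on a subset of columns,
// but keep every column in the output.
val deduped = df.dropDuplicates(Seq("key", "value")) // schema: id, key, value

// SELECT DISTINCT can only return the columns it deduplicates on.
val distinctKeys = spark.sql("SELECT DISTINCT key, value FROM t") // schema: key, value
```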
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 8a267dff7d7..68bbc0487f9 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -215,6 +215,26 @@ package object dsl {
           .build()
       }
 
+      def deduplicate(colNames: Seq[String]): proto.Relation =
+        proto.Relation
+          .newBuilder()
+          .setDeduplicate(
+            proto.Deduplicate
+              .newBuilder()
+              .setInput(logicalPlan)
+              .addAllColumnNames(colNames.asJava))
+          .build()
+
+      def distinct(): proto.Relation =
+        proto.Relation
+          .newBuilder()
+          .setDeduplicate(
+            proto.Deduplicate
+              .newBuilder()
+              .setInput(logicalPlan)
+              .setAllColumnsAsKeys(true))
+          .build()
+
       def join(
           otherPlan: proto.Relation,
           joinType: JoinType = JoinType.JOIN_TYPE_INNER,
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 450283a9b81..92c8bf01cba 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -27,8 +27,9 @@ import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sample, SubqueryAlias}
+import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, LogicalPlan, Sample, SubqueryAlias}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.types._
 
 final case class InvalidPlanInput(
@@ -60,6 +61,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset)
       case proto.Relation.RelTypeCase.JOIN => transformJoin(rel.getJoin)
       case proto.Relation.RelTypeCase.UNION => transformUnion(rel.getUnion)
+      case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate)
       case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
       case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
       case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
@@ -91,6 +93,37 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       transformRelation(rel.getInput))
   }
 
+  private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
+    if (!rel.hasInput) {
+      throw InvalidPlanInput("Deduplicate needs a plan input")
+    }
+    if (rel.getAllColumnsAsKeys && rel.getColumnNamesCount > 0) {
+      throw InvalidPlanInput("Cannot deduplicate on both all columns and a subset of columns")
+    }
+    if (!rel.getAllColumnsAsKeys && rel.getColumnNamesCount == 0) {
+      throw InvalidPlanInput(
+        "Deduplicate requires to either deduplicate on all columns or a subset of columns")
+    }
+    val queryExecution = new QueryExecution(session, transformRelation(rel.getInput))
+    val resolver = session.sessionState.analyzer.resolver
+    val allColumns = queryExecution.analyzed.output
+    if (rel.getAllColumnsAsKeys) {
+      Deduplicate(allColumns, queryExecution.analyzed)
+    } else {
+      val toGroupColumnNames = rel.getColumnNamesList.asScala.toSeq
+      val groupCols = toGroupColumnNames.flatMap { (colName: String) =>
+        // It is possible that more than one column shares the same name,
+        // so we call filter instead of find.
+        val cols = allColumns.filter(col => resolver(col.name, colName))
+        if (cols.isEmpty) {
+          throw InvalidPlanInput(s"Invalid deduplicate column ${colName}")
+        }
+        cols
+      }
+      Deduplicate(groupCols, queryExecution.analyzed)
+    }
+  }
+
   private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
     val attributes = rel.getAttributesList.asScala.map(transformAttribute(_)).toSeq
     new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes)
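One subtlety in `transformDeduplicate` above: key columns are matched against the input's output attributes with the session's analyzer resolver, using `filter` rather than `find`, so a requested name can legitimately match several attributes, and all of them become deduplication keys. A self-contained sketch of that matching, using hypothetical stand-in attributes and the default case-insensitive resolution:

```scala
import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.StringType

// Two attributes whose names differ only by case, as can occur after a join.
val allColumns: Seq[Attribute] =
  Seq(AttributeReference("key", StringType)(), AttributeReference("KEY", StringType)())

// filter keeps every attribute the resolver matches; find would keep only the first.
val groupCols = allColumns.filter(col => caseInsensitiveResolution(col.name, "key"))
assert(groupCols.size == 2)
```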
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectDeduplicateSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectDeduplicateSuite.scala
new file mode 100644
index 00000000000..88af60581ba
--- /dev/null
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectDeduplicateSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.planner
+
+import org.apache.spark.sql.{Dataset, Row, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+
+/**
+ * [[SparkConnectPlanTestWithSparkSession]] provides a SparkSession for the Connect planner.
+ *
+ * It is not recommended to use the Catalyst DSL along with this trait, because
+ * `SharedSparkSession` also defines implicits over the Catalyst LogicalPlan, which would be
+ * ambiguous with the implicits defined in the Catalyst DSL.
+ */
+trait SparkConnectPlanTestWithSparkSession extends SharedSparkSession with SparkConnectPlanTest {
+  override def getSession(): SparkSession = spark
+}
+
+class SparkConnectDeduplicateSuite extends SparkConnectPlanTestWithSparkSession {
+
+  lazy val connectTestRelation = createLocalRelationProto(
+    Seq(
+      AttributeReference("id", IntegerType)(),
+      AttributeReference("key", StringType)(),
+      AttributeReference("value", StringType)()))
+
+  lazy val sparkTestRelation = {
+    spark.createDataFrame(
+      new java.util.ArrayList[Row](),
+      StructType(
+        Seq(
+          StructField("id", IntegerType),
+          StructField("key", StringType),
+          StructField("value", StringType))))
+  }
+
+  test("Test basic deduplicate") {
+    val connectPlan = {
+      import org.apache.spark.sql.connect.dsl.plans._
+      Dataset.ofRows(spark, transform(connectTestRelation.distinct()))
+    }
+
+    val sparkPlan = sparkTestRelation.distinct()
+    comparePlans(connectPlan.queryExecution.analyzed, sparkPlan.queryExecution.analyzed, false)
+
+    val connectPlan2 = {
+      import org.apache.spark.sql.connect.dsl.plans._
+      Dataset.ofRows(spark, transform(connectTestRelation.deduplicate(Seq("key", "value"))))
+    }
+    val sparkPlan2 = sparkTestRelation.dropDuplicates(Seq("key", "value"))
+    comparePlans(connectPlan2.queryExecution.analyzed, sparkPlan2.queryExecution.analyzed, false)
+  }
+}
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 83bf76efce1..980e899c26e 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -31,8 +31,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
  * test cases.
  */
 trait SparkConnectPlanTest {
+
+  def getSession(): SparkSession = None.orNull
+
   def transform(rel: proto.Relation): LogicalPlan = {
-    new SparkConnectPlanner(rel, None.orNull).transform()
+    new SparkConnectPlanner(rel, getSession()).transform()
   }
 
   def readRel: proto.Relation =
@@ -72,8 +75,6 @@ trait SparkConnectPlanTest {
  */
 class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
 
-  protected var spark: SparkSession = null
-
   test("Simple Limit") {
     assertThrows[IndexOutOfBoundsException] {
       new SparkConnectPlanner(
@@ -266,4 +267,26 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
         .build()))
     assert(e.getMessage.contains("DataSource requires a format"))
   }
+
+  test("Test invalid deduplicate") {
+    val deduplicate = proto.Deduplicate
+      .newBuilder()
+      .setInput(readRel)
+      .setAllColumnsAsKeys(true)
+      .addColumnNames("test")
+
+    val e = intercept[InvalidPlanInput] {
+      transform(proto.Relation.newBuilder.setDeduplicate(deduplicate).build())
+    }
+    assert(
+      e.getMessage.contains("Cannot deduplicate on both all columns and a subset of columns"))
+
+    val deduplicate2 = proto.Deduplicate
+      .newBuilder()
+      .setInput(readRel)
+    val e2 = intercept[InvalidPlanInput] {
+      transform(proto.Relation.newBuilder.setDeduplicate(deduplicate2).build())
+    }
+    assert(e2.getMessage.contains("either deduplicate on all columns or a subset of columns"))
+  }
 }
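Taken together with the planner checks, the tests above pin down the only two `Deduplicate` shapes the planner accepts. A quick sketch of the builder combinations (reusing `readRel` from the test trait; a message with no input at all fails earlier with "Deduplicate needs a plan input"):

```scala
// Valid: deduplicate on all columns, i.e. df.distinct().
proto.Deduplicate.newBuilder().setInput(readRel).setAllColumnsAsKeys(true)

// Valid: deduplicate on an explicit subset, i.e. df.dropDuplicates(Seq("key")).
proto.Deduplicate.newBuilder().setInput(readRel).addColumnNames("key")

// Invalid: both set -> "Cannot deduplicate on both all columns and a subset of columns".
proto.Deduplicate.newBuilder().setInput(readRel).setAllColumnsAsKeys(true).addColumnNames("key")

// Invalid: neither set -> "Deduplicate requires to either deduplicate on all columns
// or a subset of columns".
proto.Deduplicate.newBuilder().setInput(readRel)
```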
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index b244cdf8dcb..1c868bcf411 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xcf\x05\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -44,51 +44,53 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _READ_DATASOURCE_OPTIONSENTRY._options = None
     _READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 82
-    _RELATION._serialized_end = 801
-    _UNKNOWN._serialized_start = 803
-    _UNKNOWN._serialized_end = 812
-    _RELATIONCOMMON._serialized_start = 814
-    _RELATIONCOMMON._serialized_end = 885
-    _SQL._serialized_start = 887
-    _SQL._serialized_end = 914
-    _READ._serialized_start = 917
-    _READ._serialized_end = 1327
-    _READ_NAMEDTABLE._serialized_start = 1059
-    _READ_NAMEDTABLE._serialized_end = 1120
-    _READ_DATASOURCE._serialized_start = 1123
-    _READ_DATASOURCE._serialized_end = 1314
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1256
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1314
-    _PROJECT._serialized_start = 1329
-    _PROJECT._serialized_end = 1446
-    _FILTER._serialized_start = 1448
-    _FILTER._serialized_end = 1560
-    _JOIN._serialized_start = 1563
-    _JOIN._serialized_end = 1976
-    _JOIN_JOINTYPE._serialized_start = 1789
-    _JOIN_JOINTYPE._serialized_end = 1976
-    _UNION._serialized_start = 1979
-    _UNION._serialized_end = 2184
-    _UNION_UNIONTYPE._serialized_start = 2100
-    _UNION_UNIONTYPE._serialized_end = 2184
-    _LIMIT._serialized_start = 2186
-    _LIMIT._serialized_end = 2262
-    _OFFSET._serialized_start = 2264
-    _OFFSET._serialized_end = 2343
-    _AGGREGATE._serialized_start = 2346
-    _AGGREGATE._serialized_end = 2671
-    _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2575
-    _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2671
-    _SORT._serialized_start = 2674
-    _SORT._serialized_end = 3176
-    _SORT_SORTFIELD._serialized_start = 2794
-    _SORT_SORTFIELD._serialized_end = 2982
-    _SORT_SORTDIRECTION._serialized_start = 2984
-    _SORT_SORTDIRECTION._serialized_end = 3092
-    _SORT_SORTNULLS._serialized_start = 3094
-    _SORT_SORTNULLS._serialized_end = 3176
-    _LOCALRELATION._serialized_start = 3178
-    _LOCALRELATION._serialized_end = 3271
-    _SAMPLE._serialized_start = 3274
-    _SAMPLE._serialized_end = 3458
+    _RELATION._serialized_end = 865
+    _UNKNOWN._serialized_start = 867
+    _UNKNOWN._serialized_end = 876
+    _RELATIONCOMMON._serialized_start = 878
+    _RELATIONCOMMON._serialized_end = 949
+    _SQL._serialized_start = 951
+    _SQL._serialized_end = 978
+    _READ._serialized_start = 981
+    _READ._serialized_end = 1391
+    _READ_NAMEDTABLE._serialized_start = 1123
+    _READ_NAMEDTABLE._serialized_end = 1184
+    _READ_DATASOURCE._serialized_start = 1187
+    _READ_DATASOURCE._serialized_end = 1378
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1320
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1378
+    _PROJECT._serialized_start = 1393
+    _PROJECT._serialized_end = 1510
+    _FILTER._serialized_start = 1512
+    _FILTER._serialized_end = 1624
+    _JOIN._serialized_start = 1627
+    _JOIN._serialized_end = 2040
+    _JOIN_JOINTYPE._serialized_start = 1853
+    _JOIN_JOINTYPE._serialized_end = 2040
+    _UNION._serialized_start = 2043
+    _UNION._serialized_end = 2248
+    _UNION_UNIONTYPE._serialized_start = 2164
+    _UNION_UNIONTYPE._serialized_end = 2248
+    _LIMIT._serialized_start = 2250
+    _LIMIT._serialized_end = 2326
+    _OFFSET._serialized_start = 2328
+    _OFFSET._serialized_end = 2407
+    _AGGREGATE._serialized_start = 2410
+    _AGGREGATE._serialized_end = 2735
+    _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2639
+    _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2735
+    _SORT._serialized_start = 2738
+    _SORT._serialized_end = 3240
+    _SORT_SORTFIELD._serialized_start = 2858
+    _SORT_SORTFIELD._serialized_end = 3046
+    _SORT_SORTDIRECTION._serialized_start = 3048
+    _SORT_SORTDIRECTION._serialized_end = 3156
+    _SORT_SORTNULLS._serialized_start = 3158
+    _SORT_SORTNULLS._serialized_end = 3240
+    _DEDUPLICATE._serialized_start = 3243
+    _DEDUPLICATE._serialized_end = 3385
+    _LOCALRELATION._serialized_start = 3387
+    _LOCALRELATION._serialized_end = 3480
+    _SAMPLE._serialized_start = 3483
+    _SAMPLE._serialized_end = 3667
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index f0a8b6412b5..fc135c559a6 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -72,6 +72,7 @@ class Relation(google.protobuf.message.Message):
     LOCAL_RELATION_FIELD_NUMBER: builtins.int
     SAMPLE_FIELD_NUMBER: builtins.int
     OFFSET_FIELD_NUMBER: builtins.int
+    DEDUPLICATE_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
     @property
     def common(self) -> global___RelationCommon: ...
@@ -100,6 +101,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def offset(self) -> global___Offset: ...
     @property
+    def deduplicate(self) -> global___Deduplicate: ...
+    @property
     def unknown(self) -> global___Unknown: ...
     def __init__(
         self,
@@ -117,6 +120,7 @@ class Relation(google.protobuf.message.Message):
         local_relation: global___LocalRelation | None = ...,
         sample: global___Sample | None = ...,
         offset: global___Offset | None = ...,
+        deduplicate: global___Deduplicate | None = ...,
         unknown: global___Unknown | None = ...,
     ) -> None: ...
     def HasField(
@@ -126,6 +130,8 @@ class Relation(google.protobuf.message.Message):
             b"aggregate",
             "common",
             b"common",
+            "deduplicate",
+            b"deduplicate",
             "filter",
             b"filter",
             "join",
@@ -161,6 +167,8 @@ class Relation(google.protobuf.message.Message):
             b"aggregate",
             "common",
             b"common",
+            "deduplicate",
+            b"deduplicate",
             "filter",
             b"filter",
             "join",
@@ -204,6 +212,7 @@ class Relation(google.protobuf.message.Message):
             "local_relation",
             "sample",
             "offset",
+            "deduplicate",
             "unknown",
         ]
         | None: ...
@@ -759,6 +768,47 @@ class Sort(google.protobuf.message.Message):
 
 global___Sort = Sort
 
+class Deduplicate(google.protobuf.message.Message):
+    """Relation of type [[Deduplicate]], which has duplicate rows removed. It can consider either
+    only a subset of the columns or all of the columns.
+    """
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    COLUMN_NAMES_FIELD_NUMBER: builtins.int
+    ALL_COLUMNS_AS_KEYS_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation: ...
+    @property
+    def column_names(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
+    all_columns_as_keys: builtins.bool
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        column_names: collections.abc.Iterable[builtins.str] | None = ...,
+        all_columns_as_keys: builtins.bool = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "all_columns_as_keys",
+            b"all_columns_as_keys",
+            "column_names",
+            b"column_names",
+            "input",
+            b"input",
+        ],
+    ) -> None: ...
+
+global___Deduplicate = Deduplicate
+
 class LocalRelation(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor