This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b9ca91dde94c [SPARK-47712][CONNECT] Allow connect plugins to create and process Datasets b9ca91dde94c is described below commit b9ca91dde94c5ac6eeae9bb5818099adbc93006c Author: Tom van Bussel <tom.vanbus...@databricks.com> AuthorDate: Fri Apr 5 10:42:43 2024 +0900 [SPARK-47712][CONNECT] Allow connect plugins to create and process Datasets ### What changes were proposed in this pull request? This PR adds new versions of `SparkSession.createDataset` and `SparkSession.createDataFrame` that take an `Array[Byte]` as input. The older versions that take a `protobuf.Any` are deprecated. This PR also adds new versions of `SparkConnectPlanner.transformRelation` and `SparkConnectPlanner.transformExpression` that take an `Array[Byte]`. ### Why are the changes needed? Without these changes it's difficult to create plugins for Spark Connect. The methods above used to take a protobuf class that is shaded as input, meaning that plugins had to shade these classes in the exact same way. Now they can just serialize the protobuf object to bytes and pass that in instead. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Tests were added ### Was this patch authored or co-authored using generative AI tooling? No Closes #45850 from tomvanbussel/SPARK-47712. 
Authored-by: Tom van Bussel <tom.vanbus...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../main/scala/org/apache/spark/sql/Column.scala | 6 +++++ .../scala/org/apache/spark/sql/SparkSession.scala | 14 ++++++++++- .../org/apache/spark/sql/ClientDatasetSuite.scala | 14 ++++++++++- .../apache/spark/sql/PlanGenerationTestSuite.scala | 26 +++++++++++++++++++-- .../expression_extension_deprecated.explain | 2 ++ .../relation_extension_deprecated.explain | 1 + .../queries/expression_extension_deprecated.json | 26 +++++++++++++++++++++ .../expression_extension_deprecated.proto.bin | Bin 0 -> 127 bytes .../queries/relation_extension_deprecated.json | 16 +++++++++++++ .../relation_extension_deprecated.proto.bin | Bin 0 -> 108 bytes .../sql/connect/planner/SparkConnectPlanner.scala | 11 +++++++++ .../plugin/SparkConnectPluginRegistrySuite.scala | 5 ++-- 12 files changed, 114 insertions(+), 7 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala index dec699f4f1a8..c23d49440248 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala @@ -1351,10 +1351,16 @@ private[sql] object Column { } @DeveloperApi + @deprecated("Use forExtension(Array[Byte]) instead", "4.0.0") def apply(extension: com.google.protobuf.Any): Column = { apply(_.setExtension(extension)) } + @DeveloperApi + def forExtension(extension: Array[Byte]): Column = { + apply(_.setExtension(com.google.protobuf.Any.parseFrom(extension))) + } + private[sql] def fn(name: String, inputs: Column*): Column = { fn(name, isDistinct = false, inputs: _*) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 
adee5b33fb4e..1e467a864442 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -496,17 +496,29 @@ class SparkSession private[sql] ( } @DeveloperApi + @deprecated("Use newDataFrame(Array[Byte]) instead", "4.0.0") def newDataFrame(extension: com.google.protobuf.Any): DataFrame = { - newDataset(extension, UnboundRowEncoder) + newDataFrame(_.setExtension(extension)) } @DeveloperApi + @deprecated("Use newDataFrame(Array[Byte], AgnosticEncoder[T]) instead", "4.0.0") def newDataset[T]( extension: com.google.protobuf.Any, encoder: AgnosticEncoder[T]): Dataset[T] = { newDataset(encoder)(_.setExtension(extension)) } + @DeveloperApi + def newDataFrame(extension: Array[Byte]): DataFrame = { + newDataFrame(_.setExtension(com.google.protobuf.Any.parseFrom(extension))) + } + + @DeveloperApi + def newDataset[T](extension: Array[Byte], encoder: AgnosticEncoder[T]): Dataset[T] = { + newDataset(encoder)(_.setExtension(com.google.protobuf.Any.parseFrom(extension))) + } + private[sql] def newCommand[T](f: proto.Command.Builder => Unit): proto.Command = { val builder = proto.Command.newBuilder() f(builder) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala index 041b09283658..4a32b8460bce 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala @@ -162,7 +162,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { } } - test("command extension") { + test("command extension deprecated") { val extension = proto.ExamplePluginCommand.newBuilder().setCustomField("abc").build() val command = proto.Command .newBuilder() @@ -174,6 +174,18 @@ class 
ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { assert(actualPlan.equals(expectedPlan)) } + test("command extension") { + val extension = proto.ExamplePluginCommand.newBuilder().setCustomField("abc").build() + val command = proto.Command + .newBuilder() + .setExtension(com.google.protobuf.Any.pack(extension)) + .build() + val expectedPlan = proto.Plan.newBuilder().setCommand(command).build() + ss.execute(com.google.protobuf.Any.pack(extension).toByteArray) + val actualPlan = service.getAndClearLatestInputPlan() + assert(actualPlan.equals(expectedPlan)) + } + test("serialize as null") { val session = newSparkSession() val ds = session.range(10) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 5fde8b04735b..5844df8a4889 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -3191,7 +3191,7 @@ class PlanGenerationTestSuite } /* Extensions */ - test("relation extension") { + test("relation extension deprecated") { val input = proto.ExamplePluginRelation .newBuilder() .setInput(simple.plan.getRoot) @@ -3199,7 +3199,7 @@ class PlanGenerationTestSuite session.newDataFrame(com.google.protobuf.Any.pack(input)) } - test("expression extension") { + test("expression extension deprecated") { val extension = proto.ExamplePluginExpression .newBuilder() .setChild( @@ -3213,6 +3213,28 @@ class PlanGenerationTestSuite simple.select(Column(com.google.protobuf.Any.pack(extension))) } + test("relation extension") { + val input = proto.ExamplePluginRelation + .newBuilder() + .setInput(simple.plan.getRoot) + .build() + session.newDataFrame(com.google.protobuf.Any.pack(input).toByteArray) + } + + test("expression extension") { + val extension = 
proto.ExamplePluginExpression + .newBuilder() + .setChild( + proto.Expression + .newBuilder() + .setUnresolvedAttribute(proto.Expression.UnresolvedAttribute + .newBuilder() + .setUnparsedIdentifier("id"))) + .setCustomField("abc") + .build() + simple.select(Column.forExtension(com.google.protobuf.Any.pack(extension).toByteArray)) + } + test("crosstab") { simple.stat.crosstab("a", "b") } diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/expression_extension_deprecated.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/expression_extension_deprecated.explain new file mode 100644 index 000000000000..7426332004a8 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/expression_extension_deprecated.explain @@ -0,0 +1,2 @@ +Project [id#0L AS abc#0L] ++- LocalRelation <empty>, [id#0L, a#0, b#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/relation_extension_deprecated.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/relation_extension_deprecated.explain new file mode 100644 index 000000000000..df724a7dd185 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/relation_extension_deprecated.explain @@ -0,0 +1 @@ +LocalRelation <empty>, [id#0L, a#0, b#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/expression_extension_deprecated.json b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension_deprecated.json new file mode 100644 index 000000000000..acfb3cc2333d --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension_deprecated.json @@ -0,0 +1,26 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double\u003e" + } + }, + "expressions": [{ + 
"extension": { + "@type": "type.googleapis.com/spark.connect.ExamplePluginExpression", + "child": { + "unresolvedAttribute": { + "unparsedIdentifier": "id" + } + }, + "customField": "abc" + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/expression_extension_deprecated.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension_deprecated.proto.bin new file mode 100644 index 000000000000..24669eba6423 Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension_deprecated.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/relation_extension_deprecated.json b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension_deprecated.json new file mode 100644 index 000000000000..47ceba13ca7e --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension_deprecated.json @@ -0,0 +1,16 @@ +{ + "common": { + "planId": "1" + }, + "extension": { + "@type": "type.googleapis.com/spark.connect.ExamplePluginRelation", + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double\u003e" + } + } + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/relation_extension_deprecated.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension_deprecated.proto.bin new file mode 100644 index 000000000000..680bb550eca5 Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension_deprecated.proto.bin differ diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 
1894ab984490..40dc7f88255e 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -30,6 +30,7 @@ import io.grpc.stub.StreamObserver import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.spark.{Partition, SparkEnv, TaskContext} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult} @@ -202,6 +203,11 @@ class SparkConnectPlanner( plan } + @DeveloperApi + def transformRelation(bytes: Array[Byte]): LogicalPlan = { + transformRelation(proto.Relation.parseFrom(bytes)) + } + private def transformRelationPlugin(extension: ProtoAny): LogicalPlan = { SparkConnectPluginRegistry.relationRegistry // Lazily traverse the collection. 
@@ -1470,6 +1476,11 @@ class SparkConnectPlanner( } } + @DeveloperApi + def transformExpression(bytes: Array[Byte]): Expression = { + transformExpression(proto.Expression.parseFrom(bytes)) + } + private def toNamedExpression(expr: Expression): NamedExpression = expr match { case named: NamedExpression => named case expr => UnresolvedAlias(expr) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala index ff8cac7a35d6..a213a36168e8 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala @@ -68,7 +68,7 @@ class ExampleRelationPlugin extends RelationPlugin { return Optional.empty() } val plugin = rel.unpack(classOf[proto.ExamplePluginRelation]) - Optional.of(planner.transformRelation(plugin.getInput)) + Optional.of(planner.transformRelation(plugin.getInput.toByteArray)) } } @@ -82,8 +82,7 @@ class ExampleExpressionPlugin extends ExpressionPlugin { } val exp = rel.unpack(classOf[proto.ExamplePluginExpression]) Optional.of( - Alias(planner.transformExpression(exp.getChild), exp.getCustomField)(explicitMetadata = - None)) + Alias(planner.transformExpression(exp.getChild.toByteArray), exp.getCustomField)()) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org