This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 330b03cddb6 [SPARK-41725][CONNECT] Eager Execution of DF.sql() 330b03cddb6 is described below commit 330b03cddb6e30e0097c754e12d52e8768bfb52a Author: Martin Grund <martin.gr...@databricks.com> AuthorDate: Wed Mar 1 16:55:57 2023 +0800 [SPARK-41725][CONNECT] Eager Execution of DF.sql() ### What changes were proposed in this pull request? This patch allows for eager execution of SQL statements using the Spark Connect DataFrame API. The implementation of the patch is as follows: When `spark.sql` is called, the client sends a command to the server including the SQL statement. The server will evaluate the query and execute the side-effects if necessary. If the query was a command, the server returns the results to the client as a `Relation.LocalRelation`; otherwise, it returns a `Relation.SQL` to the client. The clien [...] ### Why are the changes needed? Compatibility ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #40160 from grundprinzip/eager_sql_v2. 
Authored-by: Martin Grund <martin.gr...@databricks.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> (cherry picked from commit 51a87ac549120d9fe1fe4503ca8825785d9e886d) Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../scala/org/apache/spark/sql/SparkSession.scala | 14 ++- .../org/apache/spark/sql/ClientE2ETestSuite.scala | 23 ++++- .../apache/spark/sql/PlanGenerationTestSuite.scala | 8 -- .../connect/client/util/RemoteSparkSession.scala | 4 + .../src/main/protobuf/spark/connect/base.proto | 17 ++- .../src/main/protobuf/spark/connect/commands.proto | 15 +++ .../explain-results/parameterized_sql.explain | 2 - .../query-tests/explain-results/sql.explain | 2 - .../query-tests/queries/parameterized_sql.json | 12 --- .../queries/parameterized_sql.proto.bin | Bin 41 -> 0 bytes .../test/resources/query-tests/queries/sql.json | 8 -- .../resources/query-tests/queries/sql.proto.bin | Bin 16 -> 0 bytes .../sql/connect/planner/SparkConnectPlanner.scala | 85 ++++++++++++++- .../service/SparkConnectStreamHandler.scala | 66 ++++++------ .../connect/planner/SparkConnectPlannerSuite.scala | 10 +- .../plugin/SparkConnectPluginRegistrySuite.scala | 2 +- python/pyspark/sql/connect/client.py | 32 ++++-- python/pyspark/sql/connect/plan.py | 19 ++++ python/pyspark/sql/connect/proto/base_pb2.py | 115 ++++++++++++--------- python/pyspark/sql/connect/proto/base_pb2.pyi | 61 ++++++++++- python/pyspark/sql/connect/proto/commands_pb2.py | 77 +++++++++----- python/pyspark/sql/connect/proto/commands_pb2.pyi | 56 ++++++++++ python/pyspark/sql/connect/session.py | 9 +- python/pyspark/sql/tests/connect/test_client.py | 17 +++ .../sql/execution/arrow/ArrowConverters.scala | 6 +- 25 files changed, 495 insertions(+), 165 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 84731072ebc..e72dc264727 100644 --- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -123,8 +123,18 @@ class SparkSession private[sql] ( @Experimental def sql(sqlText: String, args: java.util.Map[String, String]): DataFrame = newDataFrame { builder => - builder - .setSql(proto.SQL.newBuilder().setQuery(sqlText).putAllArgs(args)) + // Send the SQL once to the server and then check the output. + val cmd = newCommand(b => + b.setSqlCommand(proto.SqlCommand.newBuilder().setSql(sqlText).putAllArgs(args))) + val plan = proto.Plan.newBuilder().setCommand(cmd) + val responseIter = client.execute(plan.build()) + + val response = responseIter.asScala + .find(_.hasSqlCommandResult) + .getOrElse(throw new RuntimeException("SQLCommandResult must be present")) + + // Update the builder with the values from the result. + builder.mergeFrom(response.getSqlCommandResult.getRelation) } /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index d47cc3858ab..274424f0f0d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -52,6 +52,23 @@ class ClientE2ETestSuite extends RemoteSparkSession { assert(result(1).getString(0) == "World") } + test("eager execution of sql") { + withTable("test_martin") { + // Fails, because table does not exist. + assertThrows[StatusRuntimeException] { + spark.sql("select * from test_martin").collect() + } + // Execute eager, DML + spark.sql("create table test_martin (id int)") + // Execute read again. 
+ val rows = spark.sql("select * from test_martin").collect() + assert(rows.length == 0) + spark.sql("insert into test_martin values (1), (2)") + val rows_new = spark.sql("select * from test_martin").collect() + assert(rows_new.length == 2) + } + } + test("simple dataset") { val df = spark.range(10).limit(3) val result = df.collect() @@ -189,10 +206,8 @@ class ClientE2ETestSuite extends RemoteSparkSession { // e.g. spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) test("writeTo with create") { withTable("myTableV2") { - assertThrows[StatusRuntimeException] { - // Failed to create as Hive support is required. - spark.range(3).writeTo("myTableV2").create() - } + // Failed to create as Hive support is required. + spark.range(3).writeTo("myTableV2").create() } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index bc7111e9bf8..0b198ab8f70 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -230,14 +230,6 @@ class PlanGenerationTestSuite private def temporals = createLocalRelation(temporalsSchemaString) /* Spark Session API */ - test("sql") { - session.sql("select 1") - } - - test("parameterized sql") { - session.sql("select 1", Map("minId" -> "7", "maxId" -> "20")) - } - test("range") { session.range(1, 10, 1, 2) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala index 96b3ab4e9ef..0ec31ee9943 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala +++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala @@ -69,6 +69,10 @@ object SparkConnectServerUtils { jar, "--conf", s"spark.connect.grpc.binding.port=$port", + "--conf", + "spark.sql.catalog.testcat=org.apache.spark.sql.connect.catalog.InMemoryTableCatalog", + "--conf", + "spark.sql.catalogImplementation=hive", "--class", "org.apache.spark.sql.connect.SimpleSparkConnectService", jar), diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index 3eacd0cc482..066a63d58ba 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -221,12 +221,27 @@ message ExecutePlanRequest { message ExecutePlanResponse { string client_id = 1; - ArrowBatch arrow_batch = 2; + // Union type for the different response messages. + oneof response_type { + ArrowBatch arrow_batch = 2; + + // Special case for executing SQL commands. + SqlCommandResult sql_command_result = 5; + + // Support arbitrary result objects. + google.protobuf.Any extension = 999; + } // Metrics for the query execution. Typically, this field is only present in the last // batch of results and then represent the overall state of the query execution. Metrics metrics = 4; + // A SQL command returns an opaque Relation that can be directly used as input for the next + // call. + message SqlCommandResult { + Relation relation = 1; + } + // Batch results of metrics. 
message ArrowBatch { int64 row_count = 1; diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto index 1f2f473a050..e553dcb1bc4 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto @@ -35,6 +35,7 @@ message Command { WriteOperation write_operation = 2; CreateDataFrameViewCommand create_dataframe_view = 3; WriteOperationV2 write_operation_v2 = 4; + SqlCommand sql_command = 5; // This field is used to mark extensions to the protocol. When plugins generate arbitrary // Commands they can add them here. During the planning the correct resolution is done. @@ -43,6 +44,20 @@ message Command { } } +// A SQL Command is used to trigger the eager evaluation of SQL commands in Spark. +// +// When the SQL provide as part of the message is a command it will be immediately evaluated +// and the result will be collected and returned as part of a LocalRelation. If the result is +// not a command, the operation will simply return a SQL Relation. This allows the client to be +// almost oblivious to the server-side behavior. +message SqlCommand { + // (Required) SQL Query. + string sql = 1; + + // (Optional) A map of parameter names to literal values. + map<string, string> args = 2; +} + // A command that can create DataFrame global temp view or local temp view. message CreateDataFrameViewCommand { // (Required) The relation that this view will be built on. 
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/parameterized_sql.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/parameterized_sql.explain deleted file mode 100644 index 7f5aafb1943..00000000000 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/parameterized_sql.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [1 AS 1#0] -+- OneRowRelation diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/sql.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/sql.explain deleted file mode 100644 index 7f5aafb1943..00000000000 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/sql.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [1 AS 1#0] -+- OneRowRelation diff --git a/connector/connect/common/src/test/resources/query-tests/queries/parameterized_sql.json b/connector/connect/common/src/test/resources/query-tests/queries/parameterized_sql.json deleted file mode 100644 index 5ceb1d5a087..00000000000 --- a/connector/connect/common/src/test/resources/query-tests/queries/parameterized_sql.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "common": { - "planId": "0" - }, - "sql": { - "query": "select 1", - "args": { - "minId": "7", - "maxId": "20" - } - } -} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/parameterized_sql.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/parameterized_sql.proto.bin deleted file mode 100644 index 50bc8457f31..00000000000 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/parameterized_sql.proto.bin and /dev/null differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/sql.json b/connector/connect/common/src/test/resources/query-tests/queries/sql.json deleted file mode 100644 index c4bc9b2c082..00000000000 --- 
a/connector/connect/common/src/test/resources/query-tests/queries/sql.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "common": { - "planId": "0" - }, - "sql": { - "query": "select 1" - } -} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/sql.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/sql.proto.bin deleted file mode 100644 index 3d4394f23af..00000000000 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/sql.proto.bin and /dev/null differ diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index d52117b469c..c8b1b3125f9 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -21,11 +21,14 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import com.google.common.collect.{Lists, Maps} -import com.google.protobuf.{Any => ProtoAny} +import com.google.protobuf.{Any => ProtoAny, ByteString} +import io.grpc.stub.StreamObserver -import org.apache.spark.TaskContext +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand} +import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier} import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, 
UnresolvedRegex, UnresolvedRelation, UnresolvedStar} @@ -34,11 +37,13 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, UdfPacket} +import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystExpression, toCatalystValue} import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry +import org.apache.spark.sql.connect.service.SparkConnectStreamHandler import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.arrow.ArrowConverters @@ -1438,7 +1443,10 @@ class SparkConnectPlanner(val session: SparkSession) { } } - def process(command: proto.Command): Unit = { + def process( + command: proto.Command, + clientId: String, + responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { command.getCommandTypeCase match { case proto.Command.CommandTypeCase.REGISTER_FUNCTION => handleRegisterUserDefinedFunction(command.getRegisterFunction) @@ -1450,10 +1458,79 @@ class SparkConnectPlanner(val session: SparkSession) { handleWriteOperationV2(command.getWriteOperationV2) case 
proto.Command.CommandTypeCase.EXTENSION => handleCommandPlugin(command.getExtension) + case proto.Command.CommandTypeCase.SQL_COMMAND => + handleSqlCommand(command.getSqlCommand, clientId, responseObserver) case _ => throw new UnsupportedOperationException(s"$command not supported.") } } + def handleSqlCommand( + getSqlCommand: SqlCommand, + clientId: String, + responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + // Eagerly execute commands of the provided SQL string. + val df = session.sql(getSqlCommand.getSql, getSqlCommand.getArgsMap) + // Check if commands have been executed. + val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult] + val rows = df.logicalPlan match { + case lr: LocalRelation => lr.data + case cr: CommandResult => cr.rows + case _ => Seq.empty + } + + // Convert the results to Arrow. + val schema = df.schema + val maxRecordsPerBatch = session.sessionState.conf.arrowMaxRecordsPerBatch + val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong + val timeZoneId = session.sessionState.conf.sessionLocalTimeZone + + // Convert the data. + val bytes = if (rows.isEmpty) { + ArrowConverters.createEmptyArrowBatch(schema, timeZoneId) + } else { + val batches = ArrowConverters.toBatchWithSchemaIterator( + rows.iterator, + schema, + maxRecordsPerBatch, + maxBatchSize, + timeZoneId) + assert(batches.size == 1) + batches.next() + } + + // To avoid explicit handling of the result on the client, we build the expected input + // of the relation on the server. The client has to simply forward the result. 
+ val result = SqlCommandResult.newBuilder() + if (isCommand) { + result.setRelation( + proto.Relation + .newBuilder() + .setLocalRelation( + proto.LocalRelation + .newBuilder() + .setData(ByteString.copyFrom(bytes)))) + } else { + result.setRelation( + proto.Relation + .newBuilder() + .setSql( + proto.SQL + .newBuilder() + .setQuery(getSqlCommand.getSql) + .putAllArgs(getSqlCommand.getArgsMap))) + } + // Exactly one SQL Command Result Batch + responseObserver.onNext( + ExecutePlanResponse + .newBuilder() + .setClientId(clientId) + .setSqlCommandResult(result) + .build()) + + // Send Metrics + SparkConnectStreamHandler.sendMetricsToResponse(clientId, df) + } + private def handleRegisterUserDefinedFunction( fun: proto.CommonInlineUserDefinedFunction): Unit = { fun.getFunctionCase match { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index f46aca9c8cf..41ca564e6d3 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.connect.service.SparkConnectStreamHandler.processAsArrowBatches import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} import org.apache.spark.sql.execution.arrow.ArrowConverters @@ -58,10 +59,41 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[ExecutePlanResp // Extract the plan from the request and convert it to a logical plan val planner = new SparkConnectPlanner(session) val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot)) - processAsArrowBatches(request.getClientId, dataframe) + processAsArrowBatches(request.getClientId, dataframe, responseObserver) + responseObserver.onNext( + SparkConnectStreamHandler.sendMetricsToResponse(request.getClientId, dataframe)) + responseObserver.onCompleted() + } + + private def handleCommand(session: SparkSession, request: ExecutePlanRequest): Unit = { + val command = request.getPlan.getCommand + val planner = new SparkConnectPlanner(session) + planner.process(command, request.getClientId, responseObserver) + responseObserver.onCompleted() + } +} + +object SparkConnectStreamHandler { + type Batch = (Array[Byte], Long) + + def rowToArrowConverter( + schema: StructType, + maxRecordsPerBatch: Int, + maxBatchSize: Long, + timeZoneId: String): Iterator[InternalRow] => Iterator[Batch] = { rows => + val batches = ArrowConverters.toBatchWithSchemaIterator( + rows, + schema, + maxRecordsPerBatch, + maxBatchSize, + timeZoneId) + batches.map(b => b -> batches.rowCountInLastBatch) } - private def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = { + def processAsArrowBatches( + clientId: String, + dataframe: DataFrame, + responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { val spark = dataframe.sparkSession val schema = dataframe.schema val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch @@ -163,13 +195,10 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp response.setArrowBatch(batch) responseObserver.onNext(response.build()) } - - responseObserver.onNext(sendMetricsToResponse(clientId, dataframe)) - responseObserver.onCompleted() } } - private def sendMetricsToResponse(clientId: String, rows: DataFrame): ExecutePlanResponse = { + def 
sendMetricsToResponse(clientId: String, rows: DataFrame): ExecutePlanResponse = { // Send a last batch with the metrics ExecutePlanResponse .newBuilder() @@ -177,31 +206,6 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp .setMetrics(MetricGenerator.buildMetrics(rows.queryExecution.executedPlan)) .build() } - - private def handleCommand(session: SparkSession, request: ExecutePlanRequest): Unit = { - val command = request.getPlan.getCommand - val planner = new SparkConnectPlanner(session) - planner.process(command) - responseObserver.onCompleted() - } -} - -object SparkConnectStreamHandler { - type Batch = (Array[Byte], Long) - - private def rowToArrowConverter( - schema: StructType, - maxRecordsPerBatch: Int, - maxBatchSize: Long, - timeZoneId: String): Iterator[InternalRow] => Iterator[Batch] = { rows => - val batches = ArrowConverters.toBatchWithSchemaIterator( - rows, - schema, - maxRecordsPerBatch, - maxBatchSize, - timeZoneId) - batches.map(b => b -> batches.rowCountInLastBatch) - } } object MetricGenerator extends AdaptiveSparkPlanHelper { diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 83056c27729..b79d91d2d10 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql.connect.planner import scala.collection.JavaConverters._ import com.google.protobuf.ByteString +import io.grpc.stub.StreamObserver import org.apache.spark.SparkFunSuite import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.connect.proto.Expression.{Alias, ExpressionString, UnresolvedStar} 
import org.apache.spark.sql.{AnalysisException, Dataset, Row} import org.apache.spark.sql.catalyst.InternalRow @@ -41,12 +43,18 @@ import org.apache.spark.unsafe.types.UTF8String */ trait SparkConnectPlanTest extends SharedSparkSession { + class MockObserver extends StreamObserver[proto.ExecutePlanResponse] { + override def onNext(value: ExecutePlanResponse): Unit = {} + override def onError(t: Throwable): Unit = {} + override def onCompleted(): Unit = {} + } + def transform(rel: proto.Relation): logical.LogicalPlan = { new SparkConnectPlanner(spark).transformRelation(rel) } def transform(cmd: proto.Command): Unit = { - new SparkConnectPlanner(spark).process(cmd) + new SparkConnectPlanner(spark).process(cmd, "clientId", new MockObserver()) } def readRel: proto.Relation = diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala index 7abe4a4d085..39fc90fd002 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala @@ -195,7 +195,7 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne .build())) .build() - new SparkConnectPlanner(spark).process(plan) + new SparkConnectPlanner(spark).process(plan, "clientId", new MockObserver()) assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Martin")) } } diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 8046da409d7..3f70ca6ad15 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -568,7 +568,8 @@ class SparkConnectClient(object): logger.info(f"Executing plan {self._proto_to_string(plan)}") req = 
self._execute_plan_request_with_metadata() req.plan.CopyFrom(plan) - table, _ = self._execute_and_fetch(req) + table, _, _2 = self._execute_and_fetch(req) + assert table is not None return table def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame": @@ -578,7 +579,8 @@ class SparkConnectClient(object): logger.info(f"Executing plan {self._proto_to_string(plan)}") req = self._execute_plan_request_with_metadata() req.plan.CopyFrom(plan) - table, metrics = self._execute_and_fetch(req) + table, metrics, _ = self._execute_and_fetch(req) + assert table is not None column_names = table.column_names table = table.rename_columns([f"col_{i}" for i in range(len(column_names))]) pdf = table.to_pandas() @@ -641,7 +643,9 @@ class SparkConnectClient(object): assert result is not None return result - def execute_command(self, command: pb2.Command) -> None: + def execute_command( + self, command: pb2.Command + ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any]]: """ Execute given command. """ @@ -650,8 +654,11 @@ class SparkConnectClient(object): if self._user_id: req.user_context.user_id = self._user_id req.plan.command.CopyFrom(command) - self._execute(req) - return + data, _, properties = self._execute_and_fetch(req) + if data is not None: + return (data.to_pandas(), properties) + else: + return (None, properties) def close(self) -> None: """ @@ -774,12 +781,12 @@ class SparkConnectClient(object): def _execute_and_fetch( self, req: pb2.ExecutePlanRequest - ) -> Tuple["pa.Table", List[PlanMetrics]]: + ) -> Tuple[Optional["pa.Table"], List[PlanMetrics], Dict[str, Any]]: logger.info("ExecuteAndFetch") m: Optional[pb2.ExecutePlanResponse.Metrics] = None batches: List[pa.RecordBatch] = [] - + properties = {} try: for attempt in Retrying( can_retry=SparkConnectClient.retry_exception, **self._retry_policy @@ -795,6 +802,8 @@ class SparkConnectClient(object): if b.metrics is not None: logger.debug("Received metric batch.") m = b.metrics + if b.HasField("sql_command_result"): + 
properties["sql_command_result"] = b.sql_command_result.relation if b.HasField("arrow_batch"): logger.debug( f"Received arrow batch rows={b.arrow_batch.row_count} " @@ -807,10 +816,13 @@ class SparkConnectClient(object): batches.append(batch) except grpc.RpcError as rpc_error: self._handle_error(rpc_error) - assert len(batches) > 0 - table = pa.Table.from_batches(batches=batches) metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else [] - return table, metrics + + if len(batches) > 0: + table = pa.Table.from_batches(batches=batches) + return table, metrics, properties + else: + return None, metrics, properties def _config_request_with_metadata(self) -> pb2.ConfigRequest: req = pb2.ConfigRequest() diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index f82cf9167cb..7e767885793 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -955,6 +955,14 @@ class SQL(LogicalPlan): return plan + def command(self, session: "SparkConnectClient") -> proto.Command: + cmd = proto.Command() + cmd.sql_command.sql = self._query + if self._args is not None and len(self._args) > 0: + for k, v in self._args.items(): + cmd.sql_command.args[k] = v + return cmd + class Range(LogicalPlan): def __init__( @@ -1880,3 +1888,14 @@ class FrameMap(LogicalPlan): plan.frame_map.input.CopyFrom(self._child.plan(session)) plan.frame_map.func.CopyFrom(self._func.to_plan_udf(session)) return plan + + +class CachedRelation(LogicalPlan): + def __init__(self, plan: proto.Relation) -> None: + super(CachedRelation, self).__init__(None) + self._plan = plan + # Update the plan ID based on the incremented counter. 
+ self._plan.common.plan_id = self._plan_id + + def plan(self, session: "SparkConnectClient") -> proto.Relation: + return self._plan diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index c43619facb6..628f7ebdd46 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x0 [...] + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x0 [...] 
) @@ -62,6 +62,9 @@ _ANALYZEPLANRESPONSE_SPARKVERSION = _ANALYZEPLANRESPONSE.nested_types_by_name["S _ANALYZEPLANRESPONSE_DDLPARSE = _ANALYZEPLANRESPONSE.nested_types_by_name["DDLParse"] _EXECUTEPLANREQUEST = DESCRIPTOR.message_types_by_name["ExecutePlanRequest"] _EXECUTEPLANRESPONSE = DESCRIPTOR.message_types_by_name["ExecutePlanResponse"] +_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT = _EXECUTEPLANRESPONSE.nested_types_by_name[ + "SqlCommandResult" +] _EXECUTEPLANRESPONSE_ARROWBATCH = _EXECUTEPLANRESPONSE.nested_types_by_name["ArrowBatch"] _EXECUTEPLANRESPONSE_METRICS = _EXECUTEPLANRESPONSE.nested_types_by_name["Metrics"] _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT = _EXECUTEPLANRESPONSE_METRICS.nested_types_by_name[ @@ -319,6 +322,15 @@ ExecutePlanResponse = _reflection.GeneratedProtocolMessageType( "ExecutePlanResponse", (_message.Message,), { + "SqlCommandResult": _reflection.GeneratedProtocolMessageType( + "SqlCommandResult", + (_message.Message,), + { + "DESCRIPTOR": _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT, + "__module__": "spark.connect.base_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.SqlCommandResult) + }, + ), "ArrowBatch": _reflection.GeneratedProtocolMessageType( "ArrowBatch", (_message.Message,), @@ -370,6 +382,7 @@ ExecutePlanResponse = _reflection.GeneratedProtocolMessageType( }, ) _sym_db.RegisterMessage(ExecutePlanResponse) +_sym_db.RegisterMessage(ExecutePlanResponse.SqlCommandResult) _sym_db.RegisterMessage(ExecutePlanResponse.ArrowBatch) _sym_db.RegisterMessage(ExecutePlanResponse.Metrics) _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject) @@ -613,53 +626,55 @@ if _descriptor._USE_C_DESCRIPTORS == False: _EXECUTEPLANREQUEST._serialized_start = 2919 _EXECUTEPLANREQUEST._serialized_end = 3126 _EXECUTEPLANRESPONSE._serialized_start = 3129 - _EXECUTEPLANRESPONSE._serialized_end = 3912 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 3331 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 3392 - 
_EXECUTEPLANRESPONSE_METRICS._serialized_start = 3395 - _EXECUTEPLANRESPONSE_METRICS._serialized_end = 3912 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 3490 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 3822 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 3699 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 3822 - _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 3824 - _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 3912 - _KEYVALUE._serialized_start = 3914 - _KEYVALUE._serialized_end = 3979 - _CONFIGREQUEST._serialized_start = 3982 - _CONFIGREQUEST._serialized_end = 5008 - _CONFIGREQUEST_OPERATION._serialized_start = 4200 - _CONFIGREQUEST_OPERATION._serialized_end = 4698 - _CONFIGREQUEST_SET._serialized_start = 4700 - _CONFIGREQUEST_SET._serialized_end = 4752 - _CONFIGREQUEST_GET._serialized_start = 4754 - _CONFIGREQUEST_GET._serialized_end = 4779 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 4781 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 4844 - _CONFIGREQUEST_GETOPTION._serialized_start = 4846 - _CONFIGREQUEST_GETOPTION._serialized_end = 4877 - _CONFIGREQUEST_GETALL._serialized_start = 4879 - _CONFIGREQUEST_GETALL._serialized_end = 4927 - _CONFIGREQUEST_UNSET._serialized_start = 4929 - _CONFIGREQUEST_UNSET._serialized_end = 4956 - _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 4958 - _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 4992 - _CONFIGRESPONSE._serialized_start = 5010 - _CONFIGRESPONSE._serialized_end = 5130 - _ADDARTIFACTSREQUEST._serialized_start = 5133 - _ADDARTIFACTSREQUEST._serialized_end = 5948 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 5480 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 5533 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 5535 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 5646 - _ADDARTIFACTSREQUEST_BATCH._serialized_start = 5648 - 
_ADDARTIFACTSREQUEST_BATCH._serialized_end = 5741 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 5744 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 5937 - _ADDARTIFACTSRESPONSE._serialized_start = 5951 - _ADDARTIFACTSRESPONSE._serialized_end = 6139 - _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 6058 - _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 6139 - _SPARKCONNECTSERVICE._serialized_start = 6142 - _SPARKCONNECTSERVICE._serialized_end = 6507 + _EXECUTEPLANRESPONSE._serialized_end = 4160 + _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 3489 + _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 3560 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 3562 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 3623 + _EXECUTEPLANRESPONSE_METRICS._serialized_start = 3626 + _EXECUTEPLANRESPONSE_METRICS._serialized_end = 4143 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 3721 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 4053 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 3930 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 4053 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 4055 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 4143 + _KEYVALUE._serialized_start = 4162 + _KEYVALUE._serialized_end = 4227 + _CONFIGREQUEST._serialized_start = 4230 + _CONFIGREQUEST._serialized_end = 5256 + _CONFIGREQUEST_OPERATION._serialized_start = 4448 + _CONFIGREQUEST_OPERATION._serialized_end = 4946 + _CONFIGREQUEST_SET._serialized_start = 4948 + _CONFIGREQUEST_SET._serialized_end = 5000 + _CONFIGREQUEST_GET._serialized_start = 5002 + _CONFIGREQUEST_GET._serialized_end = 5027 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 5029 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 5092 + _CONFIGREQUEST_GETOPTION._serialized_start = 5094 + _CONFIGREQUEST_GETOPTION._serialized_end = 
5125 + _CONFIGREQUEST_GETALL._serialized_start = 5127 + _CONFIGREQUEST_GETALL._serialized_end = 5175 + _CONFIGREQUEST_UNSET._serialized_start = 5177 + _CONFIGREQUEST_UNSET._serialized_end = 5204 + _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 5206 + _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 5240 + _CONFIGRESPONSE._serialized_start = 5258 + _CONFIGRESPONSE._serialized_end = 5378 + _ADDARTIFACTSREQUEST._serialized_start = 5381 + _ADDARTIFACTSREQUEST._serialized_end = 6196 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 5728 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 5781 + _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 5783 + _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 5894 + _ADDARTIFACTSREQUEST_BATCH._serialized_start = 5896 + _ADDARTIFACTSREQUEST_BATCH._serialized_end = 5989 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 5992 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 6185 + _ADDARTIFACTSRESPONSE._serialized_start = 6199 + _ADDARTIFACTSRESPONSE._serialized_end = 6387 + _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 6306 + _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 6387 + _SPARKCONNECTSERVICE._serialized_start = 6390 + _SPARKCONNECTSERVICE._serialized_end = 6755 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 677f101aa47..0e800947975 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -759,6 +759,28 @@ class ExecutePlanResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class SqlCommandResult(google.protobuf.message.Message): + """A SQL command returns an opaque Relation that can be directly used as input for the next + call. 
+ """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + RELATION_FIELD_NUMBER: builtins.int + @property + def relation(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: ... + def __init__( + self, + *, + relation: pyspark.sql.connect.proto.relations_pb2.Relation | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["relation", b"relation"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["relation", b"relation"] + ) -> None: ... + class ArrowBatch(google.protobuf.message.Message): """Batch results of metrics.""" @@ -885,11 +907,19 @@ class ExecutePlanResponse(google.protobuf.message.Message): CLIENT_ID_FIELD_NUMBER: builtins.int ARROW_BATCH_FIELD_NUMBER: builtins.int + SQL_COMMAND_RESULT_FIELD_NUMBER: builtins.int + EXTENSION_FIELD_NUMBER: builtins.int METRICS_FIELD_NUMBER: builtins.int client_id: builtins.str @property def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ... @property + def sql_command_result(self) -> global___ExecutePlanResponse.SqlCommandResult: + """Special case for executing SQL commands.""" + @property + def extension(self) -> google.protobuf.any_pb2.Any: + """Support arbitrary result objects.""" + @property def metrics(self) -> global___ExecutePlanResponse.Metrics: """Metrics for the query execution. Typically, this field is only present in the last batch of results and then represent the overall state of the query execution. @@ -899,18 +929,45 @@ class ExecutePlanResponse(google.protobuf.message.Message): *, client_id: builtins.str = ..., arrow_batch: global___ExecutePlanResponse.ArrowBatch | None = ..., + sql_command_result: global___ExecutePlanResponse.SqlCommandResult | None = ..., + extension: google.protobuf.any_pb2.Any | None = ..., metrics: global___ExecutePlanResponse.Metrics | None = ..., ) -> None: ... 
def HasField( self, - field_name: typing_extensions.Literal["arrow_batch", b"arrow_batch", "metrics", b"metrics"], + field_name: typing_extensions.Literal[ + "arrow_batch", + b"arrow_batch", + "extension", + b"extension", + "metrics", + b"metrics", + "response_type", + b"response_type", + "sql_command_result", + b"sql_command_result", + ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "arrow_batch", b"arrow_batch", "client_id", b"client_id", "metrics", b"metrics" + "arrow_batch", + b"arrow_batch", + "client_id", + b"client_id", + "extension", + b"extension", + "metrics", + b"metrics", + "response_type", + b"response_type", + "sql_command_result", + b"sql_command_result", ], ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["response_type", b"response_type"] + ) -> typing_extensions.Literal["arrow_batch", "sql_command_result", "extension"] | None: ... global___ExecutePlanResponse = ExecutePlanResponse diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py index c8ade1ea81b..823ed81aa07 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.py +++ b/python/pyspark/sql/connect/proto/commands_pb2.py @@ -36,11 +36,13 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...] 
+ b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xe9\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...] ) _COMMAND = DESCRIPTOR.message_types_by_name["Command"] +_SQLCOMMAND = DESCRIPTOR.message_types_by_name["SqlCommand"] +_SQLCOMMAND_ARGSENTRY = _SQLCOMMAND.nested_types_by_name["ArgsEntry"] _CREATEDATAFRAMEVIEWCOMMAND = DESCRIPTOR.message_types_by_name["CreateDataFrameViewCommand"] _WRITEOPERATION = DESCRIPTOR.message_types_by_name["WriteOperation"] _WRITEOPERATION_OPTIONSENTRY = _WRITEOPERATION.nested_types_by_name["OptionsEntry"] @@ -67,6 +69,27 @@ Command = _reflection.GeneratedProtocolMessageType( ) _sym_db.RegisterMessage(Command) +SqlCommand = _reflection.GeneratedProtocolMessageType( + "SqlCommand", + (_message.Message,), + { + "ArgsEntry": _reflection.GeneratedProtocolMessageType( + "ArgsEntry", + (_message.Message,), + { + "DESCRIPTOR": _SQLCOMMAND_ARGSENTRY, + "__module__": "spark.connect.commands_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.SqlCommand.ArgsEntry) + }, + ), + "DESCRIPTOR": _SQLCOMMAND, + "__module__": "spark.connect.commands_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.SqlCommand) + }, +) +_sym_db.RegisterMessage(SqlCommand) +_sym_db.RegisterMessage(SqlCommand.ArgsEntry) + CreateDataFrameViewCommand = _reflection.GeneratedProtocolMessageType( "CreateDataFrameViewCommand", (_message.Message,), @@ -154,6 +177,8 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" + _SQLCOMMAND_ARGSENTRY._options = None + 
_SQLCOMMAND_ARGSENTRY._serialized_options = b"8\001" _WRITEOPERATION_OPTIONSENTRY._options = None _WRITEOPERATION_OPTIONSENTRY._serialized_options = b"8\001" _WRITEOPERATIONV2_OPTIONSENTRY._options = None @@ -161,27 +186,31 @@ if _descriptor._USE_C_DESCRIPTORS == False: _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._options = None _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_options = b"8\001" _COMMAND._serialized_start = 166 - _COMMAND._serialized_end = 593 - _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 596 - _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 746 - _WRITEOPERATION._serialized_start = 749 - _WRITEOPERATION._serialized_end = 1800 - _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1224 - _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1282 - _WRITEOPERATION_SAVETABLE._serialized_start = 1285 - _WRITEOPERATION_SAVETABLE._serialized_end = 1543 - _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_start = 1419 - _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_end = 1543 - _WRITEOPERATION_BUCKETBY._serialized_start = 1545 - _WRITEOPERATION_BUCKETBY._serialized_end = 1636 - _WRITEOPERATION_SAVEMODE._serialized_start = 1639 - _WRITEOPERATION_SAVEMODE._serialized_end = 1776 - _WRITEOPERATIONV2._serialized_start = 1803 - _WRITEOPERATIONV2._serialized_end = 2616 - _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1224 - _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1282 - _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2375 - _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2441 - _WRITEOPERATIONV2_MODE._serialized_start = 2444 - _WRITEOPERATIONV2_MODE._serialized_end = 2603 + _COMMAND._serialized_end = 655 + _SQLCOMMAND._serialized_start = 658 + _SQLCOMMAND._serialized_end = 802 + _SQLCOMMAND_ARGSENTRY._serialized_start = 747 + _SQLCOMMAND_ARGSENTRY._serialized_end = 802 + _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 805 + _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 955 + _WRITEOPERATION._serialized_start = 958 + 
_WRITEOPERATION._serialized_end = 2009 + _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1433 + _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1491 + _WRITEOPERATION_SAVETABLE._serialized_start = 1494 + _WRITEOPERATION_SAVETABLE._serialized_end = 1752 + _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_start = 1628 + _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_end = 1752 + _WRITEOPERATION_BUCKETBY._serialized_start = 1754 + _WRITEOPERATION_BUCKETBY._serialized_end = 1845 + _WRITEOPERATION_SAVEMODE._serialized_start = 1848 + _WRITEOPERATION_SAVEMODE._serialized_end = 1985 + _WRITEOPERATIONV2._serialized_start = 2012 + _WRITEOPERATIONV2._serialized_end = 2825 + _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1433 + _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1491 + _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2584 + _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2650 + _WRITEOPERATIONV2_MODE._serialized_start = 2653 + _WRITEOPERATIONV2_MODE._serialized_end = 2812 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi index fb767ead329..d2bfaf9ed89 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.pyi +++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi @@ -63,6 +63,7 @@ class Command(google.protobuf.message.Message): WRITE_OPERATION_FIELD_NUMBER: builtins.int CREATE_DATAFRAME_VIEW_FIELD_NUMBER: builtins.int WRITE_OPERATION_V2_FIELD_NUMBER: builtins.int + SQL_COMMAND_FIELD_NUMBER: builtins.int EXTENSION_FIELD_NUMBER: builtins.int @property def register_function( @@ -75,6 +76,8 @@ class Command(google.protobuf.message.Message): @property def write_operation_v2(self) -> global___WriteOperationV2: ... @property + def sql_command(self) -> global___SqlCommand: ... + @property def extension(self) -> google.protobuf.any_pb2.Any: """This field is used to mark extensions to the protocol. 
When plugins generate arbitrary Commands they can add them here. During the planning the correct resolution is done. @@ -87,6 +90,7 @@ class Command(google.protobuf.message.Message): write_operation: global___WriteOperation | None = ..., create_dataframe_view: global___CreateDataFrameViewCommand | None = ..., write_operation_v2: global___WriteOperationV2 | None = ..., + sql_command: global___SqlCommand | None = ..., extension: google.protobuf.any_pb2.Any | None = ..., ) -> None: ... def HasField( @@ -100,6 +104,8 @@ class Command(google.protobuf.message.Message): b"extension", "register_function", b"register_function", + "sql_command", + b"sql_command", "write_operation", b"write_operation", "write_operation_v2", @@ -117,6 +123,8 @@ class Command(google.protobuf.message.Message): b"extension", "register_function", b"register_function", + "sql_command", + b"sql_command", "write_operation", b"write_operation", "write_operation_v2", @@ -130,11 +138,59 @@ class Command(google.protobuf.message.Message): "write_operation", "create_dataframe_view", "write_operation_v2", + "sql_command", "extension", ] | None: ... global___Command = Command +class SqlCommand(google.protobuf.message.Message): + """A SQL Command is used to trigger the eager evaluation of SQL commands in Spark. + + When the SQL provide as part of the message is a command it will be immediately evaluated + and the result will be collected and returned as part of a LocalRelation. If the result is + not a command, the operation will simply return a SQL Relation. This allows the client to be + almost oblivious to the server-side behavior. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class ArgsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... 
+ def ClearField( + self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] + ) -> None: ... + + SQL_FIELD_NUMBER: builtins.int + ARGS_FIELD_NUMBER: builtins.int + sql: builtins.str + """(Required) SQL Query.""" + @property + def args(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: + """(Optional) A map of parameter names to literal values.""" + def __init__( + self, + *, + sql: builtins.str = ..., + args: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["args", b"args", "sql", b"sql"] + ) -> None: ... + +global___SqlCommand = SqlCommand + class CreateDataFrameViewCommand(google.protobuf.message.Message): """A command that can create DataFrame global temp view or local temp view.""" diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index d82dbcb2db0..6b501b7c375 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -50,7 +50,7 @@ from pyspark import SparkContext, SparkConf, __version__ from pyspark.sql.connect.client import SparkConnectClient from pyspark.sql.connect.conf import RuntimeConf from pyspark.sql.connect.dataframe import DataFrame -from pyspark.sql.connect.plan import SQL, Range, LocalRelation +from pyspark.sql.connect.plan import SQL, Range, LocalRelation, CachedRelation from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer from pyspark.sql.pandas.types import to_arrow_type, _get_local_timezone @@ -347,7 +347,12 @@ class SparkSession: createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__ def sql(self, sqlQuery: str, args: Optional[Dict[str, str]] = None) -> "DataFrame": - return DataFrame.withPlan(SQL(sqlQuery, args), self) + cmd = SQL(sqlQuery, args) + data, properties = self.client.execute_command(cmd.command(self._client)) + if 
"sql_command_result" in properties: + return DataFrame.withPlan(CachedRelation(properties["sql_command_result"]), self) + else: + return DataFrame.withPlan(SQL(sqlQuery, args), self) sql.__doc__ = PySparkSession.sql.__doc__ diff --git a/python/pyspark/sql/tests/connect/test_client.py b/python/pyspark/sql/tests/connect/test_client.py index 41b2888eb74..84281a6764f 100644 --- a/python/pyspark/sql/tests/connect/test_client.py +++ b/python/pyspark/sql/tests/connect/test_client.py @@ -20,6 +20,11 @@ from typing import Optional from pyspark.sql.connect.client import SparkConnectClient import pyspark.sql.connect.proto as proto +from pyspark.testing.connectutils import should_test_connect + +if should_test_connect: + import pandas as pd + import pyarrow as pa class SparkConnectClientTestCase(unittest.TestCase): @@ -60,6 +65,18 @@ class MockService: self.req = req resp = proto.ExecutePlanResponse() resp.client_id = self._session_id + + pdf = pd.DataFrame(data={"col1": [1, 2]}) + schema = pa.Schema.from_pandas(pdf) + table = pa.Table.from_pandas(pdf) + sink = pa.BufferOutputStream() + + writer = pa.ipc.new_stream(sink, schema=schema) + writer.write(table) + writer.close() + + buf = sink.getvalue() + resp.arrow_batch.data = buf.to_pybytes() return [resp] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 40de117f6f6..b22c80d17e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -158,7 +158,11 @@ private[sql] object ArrowConverters extends Logging { rowCountInLastBatch < maxRecordsPerBatch)) { val row = rowIter.next() arrowWriter.write(row) - estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes + estimatedBatchSize += (row match { + case ur: UnsafeRow => ur.getSizeInBytes + // Trying to estimate 
the size of the current row, assuming 16 bytes per value. + case ir: InternalRow => ir.numFields * 16 + }) rowCountInLastBatch += 1 } arrowWriter.finish() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org