This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 5750bdc553b  [SPARK-42585][CONNECT][FOLLOWUP] Store cached local relations as proto
5750bdc553b is described below

commit 5750bdc553be6f78ceb81a16d124ecc11481db8d
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Wed May 10 14:29:17 2023 +0300

    [SPARK-42585][CONNECT][FOLLOWUP] Store cached local relations as proto

    ### What changes were proposed in this pull request?
    In this PR, I propose to store cached local relations in the proto format, the same as `LocalRelation`. I also reverted `transformLocalRelation()` to its state before commit https://github.com/apache/spark/commit/0d7618a2ca847cf9577659e50409dd5a383d66d3.

    ### Why are the changes needed?
    To explicitly specify the format of cached local relations in the proto API.

    ### Does this PR introduce _any_ user-facing change?
    Yes, but the feature of cached local relations has not been released yet.

    ### How was this patch tested?
    By running the existing tests:
    ```
    $ build/sbt "test:testOnly *.ArtifactManagerSuite"
    $ build/sbt "test:testOnly *.ClientE2ETestSuite"
    $ build/sbt "test:testOnly *.ArtifactStatusesHandlerSuite"
    ```

    Closes #41107 from MaxGekk/cached-blob-in-proto.

    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
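For illustration, a minimal sketch of the client-side serialization under this change, assuming the generated `spark.connect` proto classes; the standalone helper name is hypothetical, and in the diff below this logic lives inside `SparkConnectClient.cacheLocalRelation`:

```
import com.google.protobuf.ByteString
import org.apache.spark.connect.proto

// Sketch: serialize a local relation as the standard LocalRelation proto
// message, i.e. the same wire format as Relation.local_relation.
def serializeLocalRelation(data: ByteString, schemaJson: String): Array[Byte] = {
  proto.Relation
    .newBuilder()
    .getLocalRelationBuilder // nested builder for the local_relation field
    .setSchema(schemaJson)   // schema as a JSON string
    .setData(data)           // Arrow IPC batch bytes
    .build()                 // yields a proto.LocalRelation
    .toByteArray
}
```

This replaces the previous hand-rolled layout (a 4-byte size prefix, then raw Arrow bytes, then UTF-8 schema bytes) with a wire format both client and server already understand.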
---
 .../scala/org/apache/spark/sql/SparkSession.scala  |   2 +-
 .../sql/connect/client/SparkConnectClient.scala    |  18 ++-
 .../main/protobuf/spark/connect/relations.proto    |   2 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 133 ++++++++++-----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |   2 +-
 5 files changed, 71 insertions(+), 86 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 4e5474a33b7..7395bb5f16c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -131,7 +131,7 @@ class SparkSession private[sql] (
         .setSchema(encoder.schema.json)
         .setData(arrowData)
     } else {
-      val hash = client.cacheLocalRelation(arrowDataSize, arrowData, encoder.schema.json)
+      val hash = client.cacheLocalRelation(arrowData, encoder.schema.json)
       builder.getCachedLocalRelationBuilder
         .setUserId(client.userId)
         .setSessionId(client.sessionId)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index a5aabe62ae4..0c3b1cae091 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -18,8 +18,6 @@ package org.apache.spark.sql.connect.client
 
 import java.net.URI
-import java.nio.ByteBuffer
-import java.nio.charset.StandardCharsets
 import java.util.UUID
 import java.util.concurrent.Executor
 
@@ -237,14 +235,14 @@ private[sql] class SparkConnectClient(
   /**
    * Cache the given local relation at the server, and return its key in the remote cache.
    */
-  def cacheLocalRelation(size: Int, data: ByteString, schema: String): String = {
-    val schemaBytes = schema.getBytes(StandardCharsets.UTF_8)
-    val locRelData = data.toByteArray
-    val locRel = ByteBuffer.allocate(4 + locRelData.length + schemaBytes.length)
-    locRel.putInt(size)
-    locRel.put(locRelData)
-    locRel.put(schemaBytes)
-    artifactManager.cacheArtifact(locRel.array())
+  def cacheLocalRelation(data: ByteString, schema: String): String = {
+    val localRelation = proto.Relation
+      .newBuilder()
+      .getLocalRelationBuilder
+      .setSchema(schema)
+      .setData(data)
+      .build()
+    artifactManager.cacheArtifact(localRelation.toByteArray)
   }
 }
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 984b7d3166c..68133f509f3 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -390,7 +390,7 @@ message CachedLocalRelation {
   // (Required) An identifier of the Spark SQL session in which the user created the local relation.
   string sessionId = 2;
 
-  // (Required) A sha-256 hash of the serialized local relation.
+  // (Required) A sha-256 hash of the serialized local relation in proto, see LocalRelation.
  string hash = 3;
 }
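The updated comment pins down what the cache key is: a SHA-256 digest of the serialized proto bytes. A minimal sketch of such a key derivation, assuming hex encoding; the helper is hypothetical, since the actual key comes back from `artifactManager.cacheArtifact`:

```
import java.security.MessageDigest

// Sketch: hex-encoded SHA-256 over the serialized LocalRelation proto bytes,
// matching the contract documented on CachedLocalRelation.hash.
def localRelationCacheKey(serialized: Array[Byte]): String =
  MessageDigest.getInstance("SHA-256")
    .digest(serialized)
    .map(b => f"$b%02x")
    .mkString
```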
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 01f1e890630..b86ed866d6e 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -17,9 +17,6 @@
 
 package org.apache.spark.sql.connect.planner
 
-import java.nio.ByteBuffer
-import java.nio.charset.StandardCharsets
-
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
@@ -795,13 +792,12 @@ class SparkConnectPlanner(val session: SparkSession) {
     bytes
       .map { blockData =>
         try {
-          val blob = blockData.toByteBuffer().array()
-          val blobSize = blockData.size.toInt
-          val size = ByteBuffer.wrap(blob).getInt
-          val intSize = 4
-          val data = blob.slice(intSize, intSize + size)
-          val schema = new String(blob.slice(intSize + size, blobSize), StandardCharsets.UTF_8)
-          transformLocalRelation(Option(schema), Option(data))
+          val localRelation = proto.Relation
+            .newBuilder()
+            .getLocalRelation
+            .getParserForType
+            .parseFrom(blockData.toInputStream())
+          transformLocalRelation(localRelation)
         } finally {
           blockManager.releaseLock(blockId)
         }
@@ -944,79 +940,70 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
 
-  private def transformLocalRelation(
-      schema: Option[String],
-      data: Option[Array[Byte]]): LogicalPlan = {
-    val optStruct = schema.map { schemaStr =>
-      val schemaType = DataType.parseTypeWithFallback(
-        schemaStr,
-        parseDatatypeString,
-        fallbackParser = DataType.fromJson)
-      schemaType match {
-        case s: StructType => s
-        case d => StructType(Seq(StructField("value", d)))
-      }
-    }
-    data
-      .map { dataBytes =>
-        val (rows, structType) =
-          ArrowConverters.fromBatchWithSchemaIterator(Iterator(dataBytes), TaskContext.get())
-        if (structType == null) {
-          throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.")
-        }
-        val attributes = structType.toAttributes
-        val proj = UnsafeProjection.create(attributes, attributes)
-        val data = rows.map(proj)
-        optStruct
-          .map { struct =>
-            def normalize(dt: DataType): DataType = dt match {
-              case udt: UserDefinedType[_] => normalize(udt.sqlType)
-              case StructType(fields) =>
-                val newFields = fields.zipWithIndex.map {
-                  case (StructField(_, dataType, nullable, metadata), i) =>
-                    StructField(s"col_$i", normalize(dataType), nullable, metadata)
-                }
-                StructType(newFields)
-              case ArrayType(elementType, containsNull) =>
-                ArrayType(normalize(elementType), containsNull)
-              case MapType(keyType, valueType, valueContainsNull) =>
-                MapType(normalize(keyType), normalize(valueType), valueContainsNull)
-              case _ => dt
-            }
-            val normalized = normalize(struct).asInstanceOf[StructType]
-            val project = Dataset
-              .ofRows(
-                session,
-                logicalPlan = logical.LocalRelation(
-                  normalize(structType).asInstanceOf[StructType].toAttributes))
-              .toDF(normalized.names: _*)
-              .to(normalized)
-              .logicalPlan
-              .asInstanceOf[Project]
-
-            val proj = UnsafeProjection.create(project.projectList, project.child.output)
-            logical.LocalRelation(struct.toAttributes, data.map(proj).map(_.copy()).toSeq)
-          }
-          .getOrElse {
-            logical.LocalRelation(attributes, data.map(_.copy()).toSeq)
-          }
-      }
-      .getOrElse {
-        optStruct
-          .map { struct =>
-            LocalRelation(struct.toAttributes, data = Seq.empty)
-          }
-          .getOrElse {
-            throw InvalidPlanInput(
-              s"Schema for LocalRelation is required when the input data is not provided.")
-          }
-      }
-  }
-
-  private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
-    transformLocalRelation(
-      if (rel.hasSchema) Some(rel.getSchema) else None,
-      if (rel.hasData) Some(rel.getData.toByteArray) else None)
+  private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
+    var schema: StructType = null
+    if (rel.hasSchema) {
+      val schemaType = DataType.parseTypeWithFallback(
+        rel.getSchema,
+        parseDatatypeString,
+        fallbackParser = DataType.fromJson)
+      schema = schemaType match {
+        case s: StructType => s
+        case d => StructType(Seq(StructField("value", d)))
+      }
+    }
+
+    if (rel.hasData) {
+      val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
+        Iterator(rel.getData.toByteArray),
+        TaskContext.get())
+      if (structType == null) {
+        throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.")
+      }
+      val attributes = structType.toAttributes
+      val proj = UnsafeProjection.create(attributes, attributes)
+      val data = rows.map(proj)
+
+      if (schema == null) {
+        logical.LocalRelation(attributes, data.map(_.copy()).toSeq)
+      } else {
+        def normalize(dt: DataType): DataType = dt match {
+          case udt: UserDefinedType[_] => normalize(udt.sqlType)
+          case StructType(fields) =>
+            val newFields = fields.zipWithIndex.map {
+              case (StructField(_, dataType, nullable, metadata), i) =>
+                StructField(s"col_$i", normalize(dataType), nullable, metadata)
+            }
+            StructType(newFields)
+          case ArrayType(elementType, containsNull) =>
+            ArrayType(normalize(elementType), containsNull)
+          case MapType(keyType, valueType, valueContainsNull) =>
+            MapType(normalize(keyType), normalize(valueType), valueContainsNull)
+          case _ => dt
+        }
+
+        val normalized = normalize(schema).asInstanceOf[StructType]
+
+        val project = Dataset
+          .ofRows(
+            session,
+            logicalPlan =
+              logical.LocalRelation(normalize(structType).asInstanceOf[StructType].toAttributes))
+          .toDF(normalized.names: _*)
+          .to(normalized)
+          .logicalPlan
+          .asInstanceOf[Project]
+
+        val proj = UnsafeProjection.create(project.projectList, project.child.output)
+        logical.LocalRelation(schema.toAttributes, data.map(proj).map(_.copy()).toSeq)
+      }
+    } else {
+      if (schema == null) {
+        throw InvalidPlanInput(
+          s"Schema for LocalRelation is required when the input data is not provided.")
+      }
+      LocalRelation(schema.toAttributes, data = Seq.empty)
+    }
   }
 
   /** Parse as DDL, with a fallback to JSON. Throws an exception if it fails to parse. */
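On the server side, the cached bytes round-trip straight back into the same proto message, as the new `parseFrom` call above shows. A minimal sketch of that inverse step, assuming the generated `proto.LocalRelation` class; the helper name is hypothetical:

```
import java.io.ByteArrayInputStream
import org.apache.spark.connect.proto

// Sketch: deserialize cached bytes back into the LocalRelation proto message,
// mirroring the parseFrom call added to SparkConnectPlanner.
def deserializeLocalRelation(cached: Array[Byte]): proto.LocalRelation =
  proto.LocalRelation.parseFrom(new ByteArrayInputStream(cached))
```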
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 7898645dca5..69a4d6b9ccc 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -1576,7 +1576,7 @@ class CachedLocalRelation(google.protobuf.message.Message):
     sessionId: builtins.str
     """(Required) An identifier of the Spark SQL session in which the user created the local relation."""
     hash: builtins.str
-    """(Required) A sha-256 hash of the serialized local relation."""
+    """(Required) A sha-256 hash of the serialized local relation in proto, see LocalRelation."""
     def __init__(
         self,
         *,

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org