This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5750bdc553b [SPARK-42585][CONNECT][FOLLOWUP] Store cached local relations as proto
5750bdc553b is described below

commit 5750bdc553be6f78ceb81a16d124ecc11481db8d
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Wed May 10 14:29:17 2023 +0300

    [SPARK-42585][CONNECT][FOLLOWUP] Store cached local relations as proto
    
    ### What changes were proposed in this pull request?
    In the PR, I propose to store the cached local relations in the proto format, the same as `LocalRelation`. Also I reverted `transformLocalRelation()` to the state before the commit https://github.com/apache/spark/commit/0d7618a2ca847cf9577659e50409dd5a383d66d3.
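
    A minimal sketch (not part of this commit) of the round trip implied by the new format, using the generated `spark.connect` proto classes that appear in the diff. The SHA-256 step is only illustrative of the kind of key `CachedLocalRelation.hash` refers to (the actual key is computed by the artifact manager), and the schema/data values are placeholders:
    ```scala
    import java.security.MessageDigest

    import com.google.protobuf.ByteString
    import org.apache.spark.connect.proto

    // Build the LocalRelation message that is now cached as-is.
    val local = proto.Relation
      .newBuilder()
      .getLocalRelationBuilder
      .setSchema("""{"type":"struct","fields":[]}""") // placeholder schema JSON
      .setData(ByteString.EMPTY)                      // Arrow IPC batch bytes in practice
      .build()

    // The bytes handed to the cache on the client side ...
    val bytes = local.toByteArray
    // ... and a sha-256 over them, illustrating what CachedLocalRelation.hash points at.
    val hash = MessageDigest
      .getInstance("SHA-256")
      .digest(bytes)
      .map("%02x".format(_))
      .mkString

    // The server can parse the cached bytes straight back into the same proto message.
    val parsed = proto.LocalRelation.parseFrom(bytes)
    assert(parsed.getSchema == local.getSchema)
    ```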
    
    ### Why are the changes needed?
    To explicitly specify the format of cached local relations in the proto API.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, but the feature of cached local relations hasn't been released yet.
    
    ### How was this patch tested?
    By running the existing tests:
    ```
    $ build/sbt "test:testOnly *.ArtifactManagerSuite"
    $ build/sbt "test:testOnly *.ClientE2ETestSuite"
    $ build/sbt "test:testOnly *.ArtifactStatusesHandlerSuite"
    ```
    
    Closes #41107 from MaxGekk/cached-blob-in-proto.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  |   2 +-
 .../sql/connect/client/SparkConnectClient.scala    |  18 ++-
 .../main/protobuf/spark/connect/relations.proto    |   2 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 133 ++++++++++-----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |   2 +-
 5 files changed, 71 insertions(+), 86 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 4e5474a33b7..7395bb5f16c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -131,7 +131,7 @@ class SparkSession private[sql] (
             .setSchema(encoder.schema.json)
             .setData(arrowData)
         } else {
-          val hash = client.cacheLocalRelation(arrowDataSize, arrowData, encoder.schema.json)
+          val hash = client.cacheLocalRelation(arrowData, encoder.schema.json)
           builder.getCachedLocalRelationBuilder
             .setUserId(client.userId)
             .setSessionId(client.sessionId)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index a5aabe62ae4..0c3b1cae091 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -18,8 +18,6 @@
 package org.apache.spark.sql.connect.client
 
 import java.net.URI
-import java.nio.ByteBuffer
-import java.nio.charset.StandardCharsets
 import java.util.UUID
 import java.util.concurrent.Executor
 
@@ -237,14 +235,14 @@ private[sql] class SparkConnectClient(
   /**
   * Cache the given local relation at the server, and return its key in the remote cache.
    */
-  def cacheLocalRelation(size: Int, data: ByteString, schema: String): String = {
-    val schemaBytes = schema.getBytes(StandardCharsets.UTF_8)
-    val locRelData = data.toByteArray
-    val locRel = ByteBuffer.allocate(4 + locRelData.length + schemaBytes.length)
-    locRel.putInt(size)
-    locRel.put(locRelData)
-    locRel.put(schemaBytes)
-    artifactManager.cacheArtifact(locRel.array())
+  def cacheLocalRelation(data: ByteString, schema: String): String = {
+    val localRelation = proto.Relation
+      .newBuilder()
+      .getLocalRelationBuilder
+      .setSchema(schema)
+      .setData(data)
+      .build()
+    artifactManager.cacheArtifact(localRelation.toByteArray)
   }
 }
 
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 984b7d3166c..68133f509f3 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -390,7 +390,7 @@ message CachedLocalRelation {
  // (Required) An identifier of the Spark SQL session in which the user created the local relation.
   string sessionId = 2;
 
-  // (Required) A sha-256 hash of the serialized local relation.
+  // (Required) A sha-256 hash of the serialized local relation in proto, see LocalRelation.
   string hash = 3;
 }
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 01f1e890630..b86ed866d6e 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -17,9 +17,6 @@
 
 package org.apache.spark.sql.connect.planner
 
-import java.nio.ByteBuffer
-import java.nio.charset.StandardCharsets
-
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
@@ -795,13 +792,12 @@ class SparkConnectPlanner(val session: SparkSession) {
     bytes
       .map { blockData =>
         try {
-          val blob = blockData.toByteBuffer().array()
-          val blobSize = blockData.size.toInt
-          val size = ByteBuffer.wrap(blob).getInt
-          val intSize = 4
-          val data = blob.slice(intSize, intSize + size)
-          val schema = new String(blob.slice(intSize + size, blobSize), StandardCharsets.UTF_8)
-          transformLocalRelation(Option(schema), Option(data))
+          val localRelation = proto.Relation
+            .newBuilder()
+            .getLocalRelation
+            .getParserForType
+            .parseFrom(blockData.toInputStream())
+          transformLocalRelation(localRelation)
         } finally {
           blockManager.releaseLock(blockId)
         }
@@ -944,79 +940,70 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
 
-  private def transformLocalRelation(
-      schema: Option[String],
-      data: Option[Array[Byte]]): LogicalPlan = {
-    val optStruct = schema.map { schemaStr =>
+  private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
+    var schema: StructType = null
+    if (rel.hasSchema) {
       val schemaType = DataType.parseTypeWithFallback(
-        schemaStr,
+        rel.getSchema,
         parseDatatypeString,
         fallbackParser = DataType.fromJson)
-      schemaType match {
+      schema = schemaType match {
         case s: StructType => s
         case d => StructType(Seq(StructField("value", d)))
       }
     }
-    data
-      .map { dataBytes =>
-        val (rows, structType) =
-          ArrowConverters.fromBatchWithSchemaIterator(Iterator(dataBytes), TaskContext.get())
-        if (structType == null) {
-          throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.")
-        }
-        val attributes = structType.toAttributes
-        val proj = UnsafeProjection.create(attributes, attributes)
-        val data = rows.map(proj)
-        optStruct
-          .map { struct =>
-            def normalize(dt: DataType): DataType = dt match {
-              case udt: UserDefinedType[_] => normalize(udt.sqlType)
-              case StructType(fields) =>
-                val newFields = fields.zipWithIndex.map {
-                  case (StructField(_, dataType, nullable, metadata), i) =>
-                    StructField(s"col_$i", normalize(dataType), nullable, metadata)
-                }
-                StructType(newFields)
-              case ArrayType(elementType, containsNull) =>
-                ArrayType(normalize(elementType), containsNull)
-              case MapType(keyType, valueType, valueContainsNull) =>
-                MapType(normalize(keyType), normalize(valueType), valueContainsNull)
-              case _ => dt
+
+    if (rel.hasData) {
+      val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
+        Iterator(rel.getData.toByteArray),
+        TaskContext.get())
+      if (structType == null) {
+        throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.")
+      }
+      val attributes = structType.toAttributes
+      val proj = UnsafeProjection.create(attributes, attributes)
+      val data = rows.map(proj)
+
+      if (schema == null) {
+        logical.LocalRelation(attributes, data.map(_.copy()).toSeq)
+      } else {
+        def normalize(dt: DataType): DataType = dt match {
+          case udt: UserDefinedType[_] => normalize(udt.sqlType)
+          case StructType(fields) =>
+            val newFields = fields.zipWithIndex.map {
+              case (StructField(_, dataType, nullable, metadata), i) =>
+                StructField(s"col_$i", normalize(dataType), nullable, metadata)
             }
-            val normalized = normalize(struct).asInstanceOf[StructType]
-            val project = Dataset
-              .ofRows(
-                session,
-                logicalPlan = logical.LocalRelation(
-                  normalize(structType).asInstanceOf[StructType].toAttributes))
-              .toDF(normalized.names: _*)
-              .to(normalized)
-              .logicalPlan
-              .asInstanceOf[Project]
-
-            val proj = UnsafeProjection.create(project.projectList, project.child.output)
-            logical.LocalRelation(struct.toAttributes, data.map(proj).map(_.copy()).toSeq)
-          }
-          .getOrElse {
-            logical.LocalRelation(attributes, data.map(_.copy()).toSeq)
-          }
+            StructType(newFields)
+          case ArrayType(elementType, containsNull) =>
+            ArrayType(normalize(elementType), containsNull)
+          case MapType(keyType, valueType, valueContainsNull) =>
+            MapType(normalize(keyType), normalize(valueType), valueContainsNull)
+          case _ => dt
+        }
+
+        val normalized = normalize(schema).asInstanceOf[StructType]
+
+        val project = Dataset
+          .ofRows(
+            session,
+            logicalPlan =
+              logical.LocalRelation(normalize(structType).asInstanceOf[StructType].toAttributes))
+          .toDF(normalized.names: _*)
+          .to(normalized)
+          .logicalPlan
+          .asInstanceOf[Project]
+
+        val proj = UnsafeProjection.create(project.projectList, project.child.output)
+        logical.LocalRelation(schema.toAttributes, data.map(proj).map(_.copy()).toSeq)
       }
-      .getOrElse {
-        optStruct
-          .map { struct =>
-            LocalRelation(struct.toAttributes, data = Seq.empty)
-          }
-          .getOrElse {
-            throw InvalidPlanInput(
-              s"Schema for LocalRelation is required when the input data is not provided.")
-          }
+    } else {
+      if (schema == null) {
+        throw InvalidPlanInput(
+          s"Schema for LocalRelation is required when the input data is not provided.")
       }
-  }
-
-  private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
-    transformLocalRelation(
-      if (rel.hasSchema) Some(rel.getSchema) else None,
-      if (rel.hasData) Some(rel.getData.toByteArray) else None)
+      LocalRelation(schema.toAttributes, data = Seq.empty)
+    }
   }
 
   /** Parse as DDL, with a fallback to JSON. Throws an exception if if fails to parse. */
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 7898645dca5..69a4d6b9ccc 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -1576,7 +1576,7 @@ class CachedLocalRelation(google.protobuf.message.Message):
     sessionId: builtins.str
     """(Required) An identifier of the Spark SQL session in which the user 
created the local relation."""
     hash: builtins.str
-    """(Required) A sha-256 hash of the serialized local relation."""
+    """(Required) A sha-256 hash of the serialized local relation in proto, 
see LocalRelation."""
     def __init__(
         self,
         *,

