This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0c3f4cf1632 [SPARK-42144][CORE][SQL] Handle null string values in StageDataWrapper/StreamBlockData/StreamingQueryData
0c3f4cf1632 is described below

commit 0c3f4cf1632e48e52351d1b0664bbe6d0ae4e882
Author: yangjie01 <[email protected]>
AuthorDate: Sun Jan 22 13:49:25 2023 -0800

    [SPARK-42144][CORE][SQL] Handle null string values in StageDataWrapper/StreamBlockData/StreamingQueryData
    
    ### What changes were proposed in this pull request?
    Similar to #39666, this PR handles null string values in StageDataWrapper/StreamBlockData/StreamingQueryData
    
    ### Why are the changes needed?
    Properly handle null string values in the protobuf serializers: the generated protobuf builders reject `null` string inputs, so null fields must be skipped on serialization and restored as `null` on deserialization.
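    
    For context, here is a minimal sketch of what the `Utils.setStringField` / `Utils.getStringField` helpers used below are expected to do. The signatures are inferred from the call sites in this diff; the actual implementation in `org.apache.spark.status.protobuf.Utils` may differ.
    
    ```scala
    // Sketch only: inferred from the call sites in this patch, not the actual Utils code.
    // Skip the generated builder setter when the value is null, since protobuf-java
    // string setters do not accept null.
    def setStringField(input: String, f: String => Any): Unit = {
      if (input != null) {
        f(input)
      }
    }
    
    // Return null when the optional field was never set, otherwise read the value back.
    def getStringField(condition: Boolean, result: () => String): String = {
      if (condition) {
        result()
      } else {
        null
      }
    }
    ```
    
    Because the affected proto fields are now declared `optional`, the generated builders expose `hasXxx()` accessors, which the deserializers pass as the `condition` above.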
    
    ### Does this PR introduce any user-facing change?
    No
    
    ### How was this patch tested?
    New UTs
    
    Closes #39683 from LuciferYang/SPARK-42144.
    
    Lead-authored-by: yangjie01 <[email protected]>
    Co-authored-by: YangJie <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../apache/spark/status/protobuf/store_types.proto | 28 +++++------
 .../protobuf/StageDataWrapperSerializer.scala      | 31 ++++++------
 .../protobuf/StreamBlockDataSerializer.scala       | 20 ++++----
 .../protobuf/KVStoreProtobufSerializerSuite.scala  | 58 +++++++++++++++-------
 .../sql/StreamingQueryDataSerializer.scala         | 21 +++++---
 .../sql/KVStoreProtobufSerializerSuite.scala       | 31 ++++++++----
 6 files changed, 114 insertions(+), 75 deletions(-)

diff --git a/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto b/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto
index ab6861057c9..aacf49bd401 100644
--- a/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto
+++ b/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto
@@ -229,10 +229,10 @@ message ApplicationInfoWrapper {
 }
 
 message StreamBlockData {
-  string name = 1;
-  string executor_id = 2;
-  string host_port = 3;
-  string storage_level = 4;
+  optional string name = 1;
+  optional string executor_id = 2;
+  optional string host_port = 3;
+  optional string storage_level = 4;
   bool use_memory = 5;
   bool use_disk = 6;
   bool deserialized = 7;
@@ -495,9 +495,9 @@ message RDDOperationGraphWrapper {
 }
 
 message StreamingQueryData {
-  string name = 1;
-  string id = 2;
-  string run_id = 3;
+  optional string name = 1;
+  optional string id = 2;
+  optional string run_id = 3;
   bool is_active = 4;
   optional string exception = 5;
   int64 start_timestamp = 6;
@@ -518,10 +518,10 @@ message TaskData {
   int64 launch_time = 5;
   optional int64 result_fetch_start = 6;
   optional int64 duration = 7;
-  string executor_id = 8;
-  string host = 9;
-  string status = 10;
-  string task_locality = 11;
+  optional string executor_id = 8;
+  optional string host = 9;
+  optional string status = 10;
+  optional string task_locality = 11;
   bool speculative = 12;
   repeated AccumulableInfo accumulator_updates = 13;
   optional string error_message = 14;
@@ -582,10 +582,10 @@ message StageData {
   int64 shuffle_write_time = 37;
   int64 shuffle_write_records = 38;
 
-  string name = 39;
+  optional string name = 39;
   optional string description = 40;
-  string details = 41;
-  string scheduling_pool = 42;
+  optional string details = 41;
+  optional string scheduling_pool = 42;
 
   repeated int64 rdd_ids = 43;
   repeated AccumulableInfo accumulator_updates = 44;
diff --git a/core/src/main/scala/org/apache/spark/status/protobuf/StageDataWrapperSerializer.scala b/core/src/main/scala/org/apache/spark/status/protobuf/StageDataWrapperSerializer.scala
index dc72c3ed467..25394c1a719 100644
--- a/core/src/main/scala/org/apache/spark/status/protobuf/StageDataWrapperSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/status/protobuf/StageDataWrapperSerializer.scala
@@ -24,7 +24,7 @@ import org.apache.commons.collections4.MapUtils
 
 import org.apache.spark.status.StageDataWrapper
 import org.apache.spark.status.api.v1.{ExecutorMetricsDistributions, ExecutorPeakMetricsDistributions, InputMetricDistributions, InputMetrics, OutputMetricDistributions, OutputMetrics, ShufflePushReadMetricDistributions, ShufflePushReadMetrics, ShuffleReadMetricDistributions, ShuffleReadMetrics, ShuffleWriteMetricDistributions, ShuffleWriteMetrics, SpeculationStageSummary, StageData, TaskData, TaskMetricDistributions, TaskMetrics}
-import org.apache.spark.status.protobuf.Utils.getOptional
+import org.apache.spark.status.protobuf.Utils._
 import org.apache.spark.util.Utils.weakIntern
 
 class StageDataWrapperSerializer extends ProtobufSerDe[StageDataWrapper] {
@@ -86,12 +86,12 @@ class StageDataWrapperSerializer extends ProtobufSerDe[StageDataWrapper] {
       .setShuffleWriteBytes(stageData.shuffleWriteBytes)
       .setShuffleWriteTime(stageData.shuffleWriteTime)
       .setShuffleWriteRecords(stageData.shuffleWriteRecords)
-      .setName(stageData.name)
-      .setDetails(stageData.details)
-      .setSchedulingPool(stageData.schedulingPool)
       .setResourceProfileId(stageData.resourceProfileId)
       .setIsShufflePushEnabled(stageData.isShufflePushEnabled)
       .setShuffleMergersCount(stageData.shuffleMergersCount)
+    setStringField(stageData.name, stageDataBuilder.setName)
+    setStringField(stageData.details, stageDataBuilder.setDetails)
+    setStringField(stageData.schedulingPool, stageDataBuilder.setSchedulingPool)
     stageData.submissionTime.foreach { d =>
       stageDataBuilder.setSubmissionTime(d.getTime)
     }
@@ -149,13 +149,13 @@ class StageDataWrapperSerializer extends ProtobufSerDe[StageDataWrapper] {
       .setAttempt(t.attempt)
       .setPartitionId(t.partitionId)
       .setLaunchTime(t.launchTime.getTime)
-      .setExecutorId(t.executorId)
-      .setHost(t.host)
-      .setStatus(t.status)
-      .setTaskLocality(t.taskLocality)
       .setSpeculative(t.speculative)
       .setSchedulerDelay(t.schedulerDelay)
       .setGettingResultTime(t.gettingResultTime)
+    setStringField(t.executorId, taskDataBuilder.setExecutorId)
+    setStringField(t.host, taskDataBuilder.setHost)
+    setStringField(t.status, taskDataBuilder.setStatus)
+    setStringField(t.taskLocality, taskDataBuilder.setTaskLocality)
     t.resultFetchStart.foreach { rfs =>
       taskDataBuilder.setResultFetchStart(rfs.getTime)
     }
@@ -465,10 +465,10 @@ class StageDataWrapperSerializer extends ProtobufSerDe[StageDataWrapper] {
       shuffleWriteBytes = binary.getShuffleWriteBytes,
       shuffleWriteTime = binary.getShuffleWriteTime,
       shuffleWriteRecords = binary.getShuffleWriteRecords,
-      name = binary.getName,
+      name = getStringField(binary.hasName, () => binary.getName),
       description = description,
-      details = binary.getDetails,
-      schedulingPool = weakIntern(binary.getSchedulingPool),
+      details = getStringField(binary.hasDetails, () => binary.getDetails),
+      schedulingPool = getStringField(binary.hasSchedulingPool, () => binary.getSchedulingPool),
       rddIds = binary.getRddIdsList.asScala.map(_.toInt),
       accumulatorUpdates = accumulatorUpdates,
       tasks = tasks,
@@ -636,10 +636,11 @@ class StageDataWrapperSerializer extends ProtobufSerDe[StageDataWrapper] {
       launchTime = new Date(binary.getLaunchTime),
       resultFetchStart = resultFetchStart,
       duration = duration,
-      executorId = weakIntern(binary.getExecutorId),
-      host = weakIntern(binary.getHost),
-      status = weakIntern(binary.getStatus),
-      taskLocality = weakIntern(binary.getTaskLocality),
+      executorId = getStringField(binary.hasExecutorId, () => weakIntern(binary.getExecutorId)),
+      host = getStringField(binary.hasHost, () => weakIntern(binary.getHost)),
+      status = getStringField(binary.hasStatus, () => weakIntern(binary.getStatus)),
+      taskLocality =
+        getStringField(binary.hasTaskLocality, () => weakIntern(binary.getTaskLocality)),
       speculative = binary.getSpeculative,
       accumulatorUpdates = accumulatorUpdates,
       errorMessage = getOptional(binary.hasErrorMessage, binary.getErrorMessage),
diff --git a/core/src/main/scala/org/apache/spark/status/protobuf/StreamBlockDataSerializer.scala b/core/src/main/scala/org/apache/spark/status/protobuf/StreamBlockDataSerializer.scala
index 5dac03bb337..fff7cf8ffc4 100644
--- a/core/src/main/scala/org/apache/spark/status/protobuf/StreamBlockDataSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/status/protobuf/StreamBlockDataSerializer.scala
@@ -18,17 +18,18 @@
 package org.apache.spark.status.protobuf
 
 import org.apache.spark.status.StreamBlockData
+import org.apache.spark.status.protobuf.Utils.{getStringField, setStringField}
 import org.apache.spark.util.Utils.weakIntern
 
 class StreamBlockDataSerializer extends ProtobufSerDe[StreamBlockData] {
 
   override def serialize(data: StreamBlockData): Array[Byte] = {
     val builder = StoreTypes.StreamBlockData.newBuilder()
-      .setName(data.name)
-      .setExecutorId(data.executorId)
-      .setHostPort(data.hostPort)
-      .setStorageLevel(data.storageLevel)
-      .setUseMemory(data.useMemory)
+    setStringField(data.name, builder.setName)
+    setStringField(data.executorId, builder.setExecutorId)
+    setStringField(data.hostPort, builder.setHostPort)
+    setStringField(data.storageLevel, builder.setStorageLevel)
+    builder.setUseMemory(data.useMemory)
       .setUseDisk(data.useDisk)
       .setDeserialized(data.deserialized)
       .setMemSize(data.memSize)
@@ -39,10 +40,11 @@ class StreamBlockDataSerializer extends ProtobufSerDe[StreamBlockData] {
   override def deserialize(bytes: Array[Byte]): StreamBlockData = {
     val binary = StoreTypes.StreamBlockData.parseFrom(bytes)
     new StreamBlockData(
-      name = binary.getName,
-      executorId = weakIntern(binary.getExecutorId),
-      hostPort = weakIntern(binary.getHostPort),
-      storageLevel = weakIntern(binary.getStorageLevel),
+      name = getStringField(binary.hasName, () => binary.getName),
+      executorId = getStringField(binary.hasExecutorId, () => weakIntern(binary.getExecutorId)),
+      hostPort = getStringField(binary.hasHostPort, () => weakIntern(binary.getHostPort)),
+      storageLevel =
+        getStringField(binary.hasStorageLevel, () => weakIntern(binary.getStorageLevel)),
       useMemory = binary.getUseMemory,
       useDisk = binary.getUseDisk,
       deserialized = binary.getDeserialized,
diff --git a/core/src/test/scala/org/apache/spark/status/protobuf/KVStoreProtobufSerializerSuite.scala b/core/src/test/scala/org/apache/spark/status/protobuf/KVStoreProtobufSerializerSuite.scala
index d4c79adf2ec..de2021fb60e 100644
--- a/core/src/test/scala/org/apache/spark/status/protobuf/KVStoreProtobufSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/status/protobuf/KVStoreProtobufSerializerSuite.scala
@@ -516,7 +516,7 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
   }
 
   test("Stream Block Data") {
-    val input = new StreamBlockData(
+    val normal = new StreamBlockData(
       name = "a",
       executorId = "executor-1",
       hostPort = "123",
@@ -526,17 +526,29 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
       deserialized = true,
       memSize = 1L,
       diskSize = 2L)
-    val bytes = serializer.serialize(input)
-    val result = serializer.deserialize(bytes, classOf[StreamBlockData])
-    assert(result.name == input.name)
-    assert(result.executorId == input.executorId)
-    assert(result.hostPort == input.hostPort)
-    assert(result.storageLevel == input.storageLevel)
-    assert(result.useMemory == input.useMemory)
-    assert(result.useDisk == input.useDisk)
-    assert(result.deserialized == input.deserialized)
-    assert(result.memSize == input.memSize)
-    assert(result.diskSize == input.diskSize)
+    val withNull = new StreamBlockData(
+      name = null,
+      executorId = null,
+      hostPort = null,
+      storageLevel = null,
+      useMemory = true,
+      useDisk = false,
+      deserialized = true,
+      memSize = 1L,
+      diskSize = 2L)
+    Seq(normal, withNull).foreach { input =>
+      val bytes = serializer.serialize(input)
+      val result = serializer.deserialize(bytes, classOf[StreamBlockData])
+      assert(result.name == input.name)
+      assert(result.executorId == input.executorId)
+      assert(result.hostPort == input.hostPort)
+      assert(result.storageLevel == input.storageLevel)
+      assert(result.useMemory == input.useMemory)
+      assert(result.useDisk == input.useDisk)
+      assert(result.deserialized == input.deserialized)
+      assert(result.memSize == input.memSize)
+      assert(result.diskSize == input.diskSize)
+    }
   }
 
   test("Resource Profile") {
@@ -956,6 +968,14 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
   }
 
   test("Stage Data") {
+    testStageDataSerDe("name", "test details", "test scheduling pool")
+  }
+
+  test("Stage Data with null strings") {
+    testStageDataSerDe(null, null, null)
+  }
+
+  private def testStageDataSerDe(name: String, details: String, schedulingPool: String): Unit = {
     val accumulatorUpdates = Seq(
       new AccumulableInfo(1L, "duration", Some("update"), "value1"),
       new AccumulableInfo(2L, "duration2", None, "value2")
@@ -1038,10 +1058,10 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
       launchTime = new Date(1123456L),
       resultFetchStart = Some(new Date(1223456L)),
       duration = Some(110000L),
-      executorId = "executor_id_2",
-      host = "host_name_2",
-      status = "SUCCESS",
-      taskLocality = "LOCAL",
+      executorId = null,
+      host = null,
+      status = null,
+      taskLocality = null,
       speculative = false,
       accumulatorUpdates = accumulatorUpdates,
       errorMessage = Some("error_2"),
@@ -1229,10 +1249,10 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
       shuffleWriteBytes = 41L,
       shuffleWriteTime = 42L,
       shuffleWriteRecords = 43L,
-      name = "name",
+      name = name,
       description = Some("test description"),
-      details = "test details",
-      schedulingPool = "test scheduling pool",
+      details = details,
+      schedulingPool = schedulingPool,
       rddIds = Seq(1, 2, 3, 4, 5, 6),
       accumulatorUpdates = accumulatorUpdates,
       tasks = tasks,
diff --git a/sql/core/src/main/scala/org/apache/spark/status/protobuf/sql/StreamingQueryDataSerializer.scala b/sql/core/src/main/scala/org/apache/spark/status/protobuf/sql/StreamingQueryDataSerializer.scala
index 70f8bedf91b..65758594c40 100644
--- a/sql/core/src/main/scala/org/apache/spark/status/protobuf/sql/StreamingQueryDataSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/status/protobuf/sql/StreamingQueryDataSerializer.scala
@@ -21,16 +21,18 @@ import java.util.UUID
 
 import org.apache.spark.sql.streaming.ui.StreamingQueryData
 import org.apache.spark.status.protobuf.{ProtobufSerDe, StoreTypes}
-import org.apache.spark.status.protobuf.Utils.getOptional
+import org.apache.spark.status.protobuf.Utils._
 
 class StreamingQueryDataSerializer extends ProtobufSerDe[StreamingQueryData] {
 
   override def serialize(data: StreamingQueryData): Array[Byte] = {
     val builder = StoreTypes.StreamingQueryData.newBuilder()
-      .setId(data.id.toString)
-      .setRunId(data.runId)
-      .setIsActive(data.isActive)
-    Option(data.name).foreach(builder.setName)
+    setStringField(data.name, builder.setName)
+    if (data.id != null) {
+      builder.setId(data.id.toString)
+    }
+    setStringField(data.runId, builder.setRunId)
+    builder.setIsActive(data.isActive)
     data.exception.foreach(builder.setException)
     builder.setStartTimestamp(data.startTimestamp)
     data.endTimestamp.foreach(builder.setEndTimestamp)
@@ -43,10 +45,13 @@ class StreamingQueryDataSerializer extends ProtobufSerDe[StreamingQueryData] {
       getOptional(data.hasException, () => data.getException)
     val endTimestamp =
       getOptional(data.hasEndTimestamp, () => data.getEndTimestamp)
+    val id = if (data.hasId) {
+      UUID.fromString(data.getId)
+    } else null
     new StreamingQueryData(
-      name = data.getName,
-      id = UUID.fromString(data.getId),
-      runId = data.getRunId,
+      name = getStringField(data.hasName, () => data.getName),
+      id = id,
+      runId = getStringField(data.hasRunId, () => data.getRunId),
       isActive = data.getIsActive,
       exception = exception,
       startTimestamp = data.getStartTimestamp,
diff --git a/sql/core/src/test/scala/org/apache/spark/status/protobuf/sql/KVStoreProtobufSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/protobuf/sql/KVStoreProtobufSerializerSuite.scala
index c220ca1c96f..3c2d2533275 100644
--- a/sql/core/src/test/scala/org/apache/spark/status/protobuf/sql/KVStoreProtobufSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/status/protobuf/sql/KVStoreProtobufSerializerSuite.scala
@@ -241,7 +241,7 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
 
   test("StreamingQueryData") {
     val id = UUID.randomUUID()
-    val input = new StreamingQueryData(
+    val normal = new StreamingQueryData(
       name = "some-query",
       id = id,
       runId = s"run-id-$id",
@@ -250,14 +250,25 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
       startTimestamp = 1L,
       endTimestamp = Some(2L)
     )
-    val bytes = serializer.serialize(input)
-    val result = serializer.deserialize(bytes, classOf[StreamingQueryData])
-    assert(result.name == input.name)
-    assert(result.id == input.id)
-    assert(result.runId == input.runId)
-    assert(result.isActive == input.isActive)
-    assert(result.exception == input.exception)
-    assert(result.startTimestamp == input.startTimestamp)
-    assert(result.endTimestamp == input.endTimestamp)
+    val withNull = new StreamingQueryData(
+      name = null,
+      id = null,
+      runId = null,
+      isActive = false,
+      exception = None,
+      startTimestamp = 1L,
+      endTimestamp = None
+    )
+    Seq(normal, withNull).foreach { input =>
+      val bytes = serializer.serialize(input)
+      val result = serializer.deserialize(bytes, classOf[StreamingQueryData])
+      assert(result.name == input.name)
+      assert(result.id == input.id)
+      assert(result.runId == input.runId)
+      assert(result.isActive == input.isActive)
+      assert(result.exception == input.exception)
+      assert(result.startTimestamp == input.startTimestamp)
+      assert(result.endTimestamp == input.endTimestamp)
+    }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
