mridulm commented on PR #40629:
URL: https://github.com/apache/spark/pull/40629#issuecomment-1498460524

   A strawman proposal:
   ```
   diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
   index 7b430766851..d9632964e3d 100644
   --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
   +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
   @@ -29,6 +29,7 @@ import scala.util.Random
    import org.apache.spark._
    import org.apache.spark.internal.{config, Logging}
    import org.apache.spark.io.CompressionCodec
   +import org.apache.spark.network.util.JavaUtils
    import org.apache.spark.serializer.Serializer
    import org.apache.spark.storage._
    import org.apache.spark.util.{KeyLock, Utils}
   @@ -95,12 +96,16 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedO
    
      private val broadcastId = BroadcastBlockId(id)
    
   -  /** Total number of blocks this broadcast variable contains. */
   -  private val numBlocks: Int = writeBlocks(obj)
   -
      /** The checksum for all the blocks. */
      private var checksums: Array[Int] = _
    
   +  /** Total number of blocks this broadcast variable contains. */
   +  private val (singleBlockData: Array[Byte], numBlocks: Int) = writeBlocks(obj)
   +  assert(1 != numBlocks || null != singleBlockData)
   +  assert(1 == numBlocks || null == singleBlockData)
   +  assert(null != checksums || null != singleBlockData)
   +
   +
      override protected def getValue() = synchronized {
        val memoized: T = if (_value == null) null.asInstanceOf[T] else _value.get
        if (memoized != null) {
   @@ -135,7 +140,23 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedO
       * @param value the object to divide
       * @return number of blocks this broadcast variable is divided into
       */
   -  private def writeBlocks(value: T): Int = {
   +  private def writeBlocks(value: T): (Array[Byte], Int) = {
   +
   +    val blocks = {
   +      try {
   +        TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
   +      } catch {
   +        case t: Throwable =>
   +          logError(s"Store broadcast $broadcastId failed, cannot serialize object")
   +          throw t
   +      }
   +    }
   +
   +    if (1 == blocks.length) {
   +      // no checksum
   +      return (JavaUtils.bufferToArray(blocks(0)), 1)
   +    }
   +
        import StorageLevel._
        val blockManager = SparkEnv.get.blockManager
        if (serializedOnly && !isLocalMaster) {
   @@ -156,8 +177,6 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedO
          }
        }
        try {
   -      val blocks =
   -        TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
          if (checksumEnabled) {
            checksums = new Array[Int](blocks.length)
          }
   @@ -172,7 +191,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedO
                s"in local BlockManager")
            }
          }
   -      blocks.length
   +      (null, blocks.length)
        } catch {
          case t: Throwable =>
        logError(s"Store broadcast $broadcastId fail, remove all pieces of the broadcast")
   @@ -186,6 +205,14 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedO
    // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
    // to the driver, so other executors can pull these chunks from this executor as well.
        val blocks = new Array[BlockData](numBlocks)
   +    if (null != singleBlockData) {
   +      assert(1 == numBlocks)
   +      blocks(0) = new ByteBufferBlockData(
   +        new ChunkedByteBuffer(ByteBuffer.wrap(singleBlockData)), false)
   +      return blocks
   +    }
   +
   +    assert(1 != numBlocks)
        val bm = SparkEnv.get.blockManager
    
        for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
   
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to