Torrent-ish broadcast based on BlockManager.

Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/4602e2bf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/4602e2bf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/4602e2bf

Branch: refs/heads/master
Commit: 4602e2bf6e2ae9202d71b4e28222a0b525380e7e
Parents: f9973ca
Author: Mosharaf Chowdhury <mosha...@cs.berkeley.edu>
Authored: Sun Oct 13 18:46:03 2013 -0700
Committer: Mosharaf Chowdhury <mosha...@cs.berkeley.edu>
Committed: Wed Oct 16 21:33:33 2013 -0700

----------------------------------------------------------------------
 .../spark/broadcast/TorrentBroadcast.scala      | 245 +++++++++++++++++++
 .../org/apache/spark/storage/BlockManager.scala |   5 +-
 .../spark/storage/BlockManagerMasterActor.scala |   5 +-
 3 files changed, 251 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/4602e2bf/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
new file mode 100644
index 0000000..ad1d29a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io._
+
+import scala.math
+import scala.util.Random
+
+import org.apache.spark._
+import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.util.Utils
+
+
+/**
+ * Broadcast variable whose serialized value is split into pieces and distributed through the
+ * BlockManager.
+ *
+ * On construction the whole value is cached in the local BlockManager. Unless `isLocal` is set,
+ * the value is also cut into pieces of `TorrentBroadcast.BlockSize` bytes, and a meta-info block
+ * plus every piece are stored in the BlockManager. When an instance is deserialized, the local
+ * BlockManager is checked for the whole value first. Otherwise the meta-info and then every
+ * piece are fetched, the value is reassembled, and it is cached locally.
+ *
+ * All BlockManager access here is done while holding the `TorrentBroadcast` companion's lock.
+ *
+ * @param value_  the value to broadcast; transient, so it is rebuilt on the receiving side
+ * @param isLocal when true, no meta-info or pieces are published
+ * @param id      broadcast id from which all block ids are derived
+ */
+private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+extends Broadcast[T](id) with Logging with Serializable {
+
+  def value = value_
+
+  def broadcastId = BlockManager.toBroadcastId(id)
+
+  // Cache the whole value locally under the broadcast id.
+  TorrentBroadcast.synchronized {
+    SparkEnv.get.blockManager.putSingle(
+      broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+  }
+
+  // Reassembly state. These are transient, so after deserialization they hold JVM defaults
+  // (null / 0), not the initial values below; resetWorkerVariables() restores the sentinels.
+  @transient var arrayOfBlocks: Array[TorrentBlock] = null
+  @transient var totalBlocks = -1
+  @transient var totalBytes = -1
+  @transient var hasBlocks = 0
+
+  if (!isLocal) {
+    sendBroadcast()
+  }
+
+  /** Splits `value_` into pieces and stores the meta-info and every piece in the BlockManager. */
+  def sendBroadcast() {
+    val tInfo = TorrentBroadcast.blockifyObject(value_)
+
+    totalBlocks = tInfo.totalBlocks
+    totalBytes = tInfo.totalBytes
+    hasBlocks = tInfo.totalBlocks
+
+    // Store meta-info: only the counts are needed, so no pieces are attached.
+    val metaId = broadcastId + "_meta"
+    val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
+    TorrentBroadcast.synchronized {
+      SparkEnv.get.blockManager.putSingle(
+        metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
+    }
+
+    // Store individual pieces
+    for (i <- 0 until totalBlocks) {
+      val pieceId = broadcastId + "_piece_" + i
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.putSingle(
+          pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
+      }
+    }
+  }
+
+  // Called by JVM when deserializing an object. Uses the locally cached value if present;
+  // otherwise fetches and reassembles the pieces and caches the result locally.
+  private def readObject(in: ObjectInputStream) {
+    in.defaultReadObject()
+    TorrentBroadcast.synchronized {
+      SparkEnv.get.blockManager.getSingle(broadcastId) match {
+        case Some(x) =>
+          value_ = x.asInstanceOf[T]
+
+        case None =>
+          val start = System.nanoTime
+          logInfo("Started reading broadcast variable " + id)
+
+          // Transient fields were not re-initialized by deserialization; reset them to the
+          // sentinels receiveBroadcast() relies on (e.g. totalBlocks == -1).
+          resetWorkerVariables()
+
+          if (receiveBroadcast(id)) {
+            value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+            SparkEnv.get.blockManager.putSingle(
+              broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+
+            // Remove arrayOfBlocks from memory once value_ is on local cache
+            resetWorkerVariables()
+          } else {
+            // NOTE(review): value_ is left unset here, so a later use will presumably fail.
+            logError("Reading broadcast variable " + id + " failed")
+          }
+
+          val time = (System.nanoTime - start) / 1e9
+          logInfo("Reading broadcast variable " + id + " took " + time + " s")
+      }
+    }
+  }
+
+  /** Clears reassembly state back to the "nothing received yet" sentinels. */
+  private def resetWorkerVariables() {
+    arrayOfBlocks = null
+    totalBytes = -1
+    totalBlocks = -1
+    hasBlocks = 0
+  }
+
+  /**
+   * Fetches this broadcast's meta-info and then every piece into `arrayOfBlocks`.
+   *
+   * The meta-info lookup is attempted up to 10 times, sleeping 500 ms between failed attempts.
+   * Pieces are requested in random order. Each fetched piece is stored back into the local
+   * BlockManager with the same arguments the sender uses when publishing pieces.
+   *
+   * @param variableID not used; the block ids are derived from this instance's `id`
+   * @return true once every piece is held; false if the meta-info never became available
+   * @throws SparkException if a piece cannot be fetched
+   */
+  def receiveBroadcast(variableID: Long): Boolean = {
+    // Already fully received.
+    if (totalBlocks > 0 && totalBlocks == hasBlocks)
+      return true
+
+    // Receive meta-info
+    val metaId = broadcastId + "_meta"
+    var attemptsLeft = 10
+    while (attemptsLeft > 0 && totalBlocks == -1) {
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.getSingle(metaId) match {
+          case Some(x) =>
+            val tInfo = x.asInstanceOf[TorrentInfo]
+            totalBlocks = tInfo.totalBlocks
+            totalBytes = tInfo.totalBytes
+            arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
+            hasBlocks = 0
+
+          case None =>
+            // Back off before retrying; sleeping after the final attempt would only delay
+            // the failure.
+            if (attemptsLeft > 1) {
+              Thread.sleep(500)
+            }
+        }
+      }
+      attemptsLeft -= 1
+    }
+    if (totalBlocks == -1)
+      return false
+
+    // Receive actual blocks, in a random order (presumably to spread load across holders).
+    val recvOrder = new Random().shuffle((0 until totalBlocks).toList)
+    for (pid <- recvOrder) {
+      val pieceId = broadcastId + "_piece_" + pid
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.getSingle(pieceId) match {
+          case Some(x) =>
+            arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
+            hasBlocks += 1
+            SparkEnv.get.blockManager.putSingle(
+              pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
+
+          case None =>
+            throw new SparkException(
+              "Failed to get " + pieceId + " of " + broadcastId)
+        }
+      }
+    }
+
+    hasBlocks == totalBlocks
+  }
+
+}
+
+/**
+ * Companion of TorrentBroadcast. It provides the shared lock used around BlockManager access,
+ * the piece size, and helpers that split a value into pieces and rebuild it.
+ */
+private object TorrentBroadcast
+extends Logging {
+
+  private var initialized = false
+
+  /** Sets the initialized flag; `_isDriver` is currently unused. */
+  def initialize(_isDriver: Boolean) {
+    synchronized {
+      if (!initialized) {
+        initialized = true
+      }
+    }
+  }
+
+  /** Clears the initialized flag under the same lock `initialize` uses. */
+  def stop() {
+    synchronized {
+      initialized = false
+    }
+  }
+
+  // Piece size in bytes. The `spark.broadcast.blockSize` property is multiplied by 1024,
+  // i.e. it is given in KB (default 2048). It is read independently in every JVM.
+  val BlockSize = System.getProperty("spark.broadcast.blockSize", "2048").toInt * 1024
+
+  /**
+   * Serializes `obj` and cuts the bytes into consecutive pieces of at most `BlockSize` bytes.
+   *
+   * @return a TorrentInfo holding the pieces, the piece count and the total serialized length,
+   *         with `hasBlocks` set to the piece count
+   */
+  def blockifyObject[IN](obj: IN): TorrentInfo = {
+    val byteArray = Utils.serialize[IN](obj)
+
+    var blockNum = byteArray.length / BlockSize
+    if (byteArray.length % BlockSize != 0) {
+      blockNum += 1
+    }
+
+    val retVal = new Array[TorrentBlock](blockNum)
+    for (blockID <- 0 until blockNum) {
+      val start = blockID * BlockSize
+      // Computed as a remaining-length min so start + len never exceeds byteArray.length.
+      val len = math.min(BlockSize, byteArray.length - start)
+      retVal(blockID) = new TorrentBlock(blockID, byteArray.slice(start, start + len))
+    }
+
+    val tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
+    tInfo.hasBlocks = blockNum
+    tInfo
+  }
+
+  /**
+   * Concatenates the pieces in index order and deserializes the result using the current
+   * thread's context class loader.
+   *
+   * Each piece is placed right after the previous one, using the piece's own length rather than
+   * this JVM's `BlockSize`. Reassembly therefore stays correct even if the sender was configured
+   * with a different `spark.broadcast.blockSize`.
+   *
+   * @param arrayOfBlocks pieces, indexed by their block id
+   * @param totalBytes    total serialized length (sum of all piece lengths)
+   * @param totalBlocks   number of pieces to read from `arrayOfBlocks`
+   */
+  def unBlockifyObject[OUT](arrayOfBlocks: Array[TorrentBlock],
+                            totalBytes: Int,
+                            totalBlocks: Int): OUT = {
+    val retByteArray = new Array[Byte](totalBytes)
+    var offset = 0
+    for (i <- 0 until totalBlocks) {
+      val piece = arrayOfBlocks(i).byteArray
+      System.arraycopy(piece, 0, retByteArray, offset, piece.length)
+      offset += piece.length
+    }
+    Utils.deserialize[OUT](retByteArray, Thread.currentThread.getContextClassLoader)
+  }
+
+}
+
+/**
+ * A single piece of a broadcast variable's serialized bytes.
+ *
+ * @param blockID   zero-based index of this piece
+ * @param byteArray the bytes of this piece
+ */
+private[spark] case class TorrentBlock(blockID: Int, byteArray: Array[Byte])
+  extends Serializable
+
+/**
+ * Summary of a blockified broadcast value.
+ *
+ * @param arrayOfBlocks the pieces; transient, so they are not included when this is serialized
+ * @param totalBlocks   number of pieces
+ * @param totalBytes    length of the full serialized value in bytes
+ */
+private[spark] case class TorrentInfo(
+    @transient arrayOfBlocks: Array[TorrentBlock],
+    totalBlocks: Int,
+    totalBytes: Int
+) extends Serializable {
+
+  // Count of pieces held; transient, so it is not serialized either.
+  @transient var hasBlocks = 0
+}
+
+/** BroadcastFactory that produces TorrentBroadcast instances, delegating lifecycle calls. */
+private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
+
+  def initialize(isDriver: Boolean) {
+    TorrentBroadcast.initialize(isDriver)
+  }
+
+  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+    new TorrentBroadcast[T](value_, isLocal, id)
+
+  def stop() {
+    TorrentBroadcast.stop()
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/4602e2bf/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 801f88a..c67a615 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -21,6 +21,7 @@ import java.io.{InputStream, OutputStream}
 import java.nio.{ByteBuffer, MappedByteBuffer}
 
 import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
+import scala.util.Random
 
 import akka.actor.{ActorSystem, Cancellable, Props}
 import akka.dispatch.{Await, Future}
@@ -269,7 +270,7 @@ private[spark] class BlockManager(
   }
 
   /**
-   * Actually send a UpdateBlockInfo message. Returns the mater's response,
+   * Actually send a UpdateBlockInfo message. Returns the master's response,
    * which will be true if the block was successfully recorded and false if
    * the slave needs to re-register.
    */
@@ -478,7 +479,7 @@ private[spark] class BlockManager(
     }
     logDebug("Getting remote block " + blockId)
     // Get locations of block
-    val locations = master.getLocations(blockId)
+    val locations = Random.shuffle(master.getLocations(blockId))
 
     // Get block from remote locations
     for (loc <- locations) {

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/4602e2bf/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 633230c..8b2a812 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -227,9 +227,10 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
   }
 
   private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: 
ActorRef) {
-    if (id.executorId == "<driver>" && !isLocal) {
+    /* if (id.executorId == "<driver>" && !isLocal) {
       // Got a register message from the master node; don't register it
-    } else if (!blockManagerInfo.contains(id)) {
+    } else */
+    if (!blockManagerInfo.contains(id)) {
       blockManagerIdByExecutor.get(id.executorId) match {
         case Some(manager) =>
           // A block manager of the same executor already exists.

Reply via email to