Github user srowen commented on a diff in the pull request:
https://github.com/apache/spark/pull/22404#discussion_r217154410
--- Diff: core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
---
@@ -692,5 +681,238 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
}
super.finalize()
}
+
+ def setupEncryptionServer(): Array[Any] = {
+ encryptionServer = new PythonServer[Unit]("broadcast-encrypt-server") {
+ override def handleConnection(sock: Socket): Unit = {
+ val env = SparkEnv.get
+ val in = sock.getInputStream()
+ val dir = new File(Utils.getLocalDir(env.conf))
+ val file = File.createTempFile("broadcast", "", dir)
+ path = file.getAbsolutePath
+ val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path))
+ DechunkedInputStream.dechunkAndCopyToOutput(in, out)
+ }
+ }
+ Array(encryptionServer.port, encryptionServer.secret)
+ }
+
+ def waitTillDataReceived(): Unit = encryptionServer.getResult()
}
// scalastyle:on no.finalize
+
+/**
+ * The inverse of pyspark's ChunkedStream for sending broadcast data.
+ * Tested from python tests.
+ */
+private[spark] class DechunkedInputStream(wrapped: InputStream) extends InputStream with Logging {
+ private val din = new DataInputStream(wrapped)
+ private var remainingInChunk = din.readInt()
+
+ override def read(): Int = {
+ val into = new Array[Byte](1)
+ val n = read(into, 0, 1)
+ if (n == -1) {
+ -1
+ } else {
+ // if you just cast a byte to an int, then anything > 127 is negative, which is interpreted
+ // as an EOF
+ val b = into(0)
+ if (b < 0) {
--- End diff --
Pardon, is this just trying to treat it as an unsigned byte? Then just `b & 0xFF`?
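Just to illustrate the distinction (the byte value here is an arbitrary example):

```scala
// Sketch only: a plain Byte-to-Int cast sign-extends, while `& 0xFF` gives 0..255.
val b: Byte = 0xAB.toByte  // bit pattern 10101011
val signed = b.toInt       // -85: sign-extended, so anything > 127 comes out negative
val unsigned = b & 0xFF    // 171: the unsigned value, never negative
```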
---