Github user zsxwing commented on a diff in the pull request:
https://github.com/apache/spark/pull/9947#discussion_r45805343
--- Diff: core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
---
@@ -300,6 +323,132 @@ private[netty] class NettyRpcEnv(
}
}
+ override def fileServer: RpcEnvFileServer = _fileServer
+
+ override def openChannel(uri: String): ReadableByteChannel = {
+ val parsedUri = new URI(uri)
+ require(parsedUri.getHost() != null, "Host name must be defined.")
+ require(parsedUri.getPort() > 0, "Port must be defined.")
+ require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.")
+
+ val pipe = Pipe.open()
+ val source = new FileDownloadChannel(pipe.source())
+ try {
+ val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
+ val callback = new FileDownloadCallback(pipe.sink(), source, client)
+ client.stream(parsedUri.getPath(), callback)
+ } catch {
+ case e: Exception =>
+ pipe.sink().close()
+ source.close()
+ throw e
+ }
+
+ source
+ }
+
+ private def downloadClient(host: String, port: Int): TransportClient = {
+ if (fileDownloadFactory == null) synchronized {
+ if (fileDownloadFactory == null) {
+ val module = "files"
+ val prefix = "spark.rpc.io."
+ val clone = conf.clone()
+
+ // Copy any RPC configuration that is not overridden in the spark.files namespace.
+ conf.getAll.foreach { case (key, value) =>
+ if (key.startsWith(prefix)) {
+ val opt = key.substring(prefix.length())
+ clone.setIfMissing(s"spark.$module.io.$opt", value)
+ }
+ }
+
+ val ioThreads = clone.getInt("spark.files.io.threads", 1)
+ val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads)
+ val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true)
+ fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps())
+ }
+ }
+ fileDownloadFactory.createClient(host, port)
+ }
+
+ private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel {
+
+ @volatile private var error: Throwable = _
+
+ def setError(e: Throwable): Unit = error = e
+
+ override def read(dst: ByteBuffer): Int = {
+ if (error != null) {
+ throw error
+ }
+ source.read(dst)
+ }
+
+ override def close(): Unit = source.close()
+
+ override def isOpen(): Boolean = source.isOpen()
+
+ }
+
+ private class FileDownloadCallback(
+ sink: WritableByteChannel,
+ source: FileDownloadChannel,
+ client: TransportClient) extends StreamCallback {
+
+ override def onData(streamId: String, buf: ByteBuffer): Unit = {
+ while (buf.remaining() > 0) {
+ sink.write(buf)
+ }
+ }
+
+ override def onComplete(streamId: String): Unit = {
+ sink.close()
+ }
+
+ override def onFailure(streamId: String, cause: Throwable): Unit = {
+ logError(s"Error downloading stream $streamId.", cause)
+ source.setError(cause)
+ sink.close()
+ }
+
+ }
+
+ private class HttpBasedFileServer extends RpcEnvFileServer {
--- End diff --
Could you move this one to a new file, so that we don't need to resolve
conflicts whenever new code is added at the end of NettyRpcEnv?
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]