WweiL commented on code in PR #42340:
URL: https://github.com/apache/spark/pull/42340#discussion_r1284077049
##########
core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala:
##########
@@ -29,39 +29,56 @@ import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, PYTH
private[spark] object StreamingPythonRunner {
-  def apply(func: PythonFunction, connectUrl: String): StreamingPythonRunner = {
-    new StreamingPythonRunner(func, connectUrl)
+  def apply(
+      func: PythonFunction,
+      connectUrl: String,
+      sessionId: String,
+      workerModule: String): StreamingPythonRunner = {
+    new StreamingPythonRunner(func, connectUrl, sessionId, workerModule)
  }
}
-private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: String)
-  extends Logging {
+private[spark] class StreamingPythonRunner(
+    func: PythonFunction,
+    connectUrl: String,
+    sessionId: String,
+    workerModule: String) extends Logging {
  private val conf = SparkEnv.get.conf
  protected val bufferSize: Int = conf.get(BUFFER_SIZE)
  protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
  private val envVars: java.util.Map[String, String] = func.envVars
  private val pythonExec: String = func.pythonExec
+  private var pythonWorker: Option[Socket] = None
+  private var pythonWorkerFactory: Option[PythonWorkerFactory] = None
  protected val pythonVer: String = func.pythonVer
  /**
   * Initializes the Python worker for streaming functions. Sets up Spark Connect session
   * to be used with the functions.
   */
-  def init(sessionId: String, workerModule: String): (DataOutputStream, DataInputStream) = {
-    logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec")
+  def init(): (DataOutputStream, DataInputStream) = {
+    logInfo(s"Initializing Python runner (session: $sessionId, pythonExec: $pythonExec)")
    val env = SparkEnv.get
    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
    envVars.put("SPARK_LOCAL_DIRS", localdir)
    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
-    conf.set(PYTHON_USE_DAEMON, false)
    envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl)
-    val pythonWorkerFactory = new PythonWorkerFactory(pythonExec, envVars.asScala.toMap)
-    val (worker: Socket, _) = pythonWorkerFactory.createStreamingWorker(workerModule)
+    val prevConf = conf.get(PYTHON_USE_DAEMON)
+    conf.set(PYTHON_USE_DAEMON, false)
+    try {
+      val workerFactory = new PythonWorkerFactory(pythonExec, envVars.asScala.toMap)
+      val (worker: Socket, _) = workerFactory.createStreamingWorker(workerModule)
+      pythonWorker = Some(worker)
+      pythonWorkerFactory = Some(workerFactory)
+    } finally {
+      conf.set(PYTHON_USE_DAEMON, prevConf)
+    }
Review Comment:
This and the stop() method differ from the master branch, since the createPythonWorker method didn't support custom modules at that time:
https://github.com/WweiL/oss-spark/blob/f8b312a22eae3ce1176da49a693182832c1f1402/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala#L72-L74
cc @ueshin to double-check this
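For context, a minimal sketch of the call-site change this refactor implies (the `func`, `connectUrl`, and `sessionId` values and the worker-module name below are placeholders for illustration, not from this PR):

```scala
// Hypothetical call site; assumes a PythonFunction, a Spark Connect URL,
// and a session id are already in scope.
// Before this change, the session and module were passed to init():
//   val runner = StreamingPythonRunner(func, connectUrl)
//   val (dataOut, dataIn) = runner.init(sessionId, "pyspark.some_streaming_worker")

// After this change, both are bound once at construction time, so init()
// takes no arguments:
val runner = StreamingPythonRunner(func, connectUrl, sessionId, "pyspark.some_streaming_worker")
val (dataOut, dataIn) = runner.init()
```

Binding these at construction also lets stop() reuse the same pythonWorker/pythonWorkerFactory pair that init() stored in the new Option fields.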