WweiL commented on code in PR #42340:
URL: https://github.com/apache/spark/pull/42340#discussion_r1284789131


##########
core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala:
##########
@@ -29,39 +29,56 @@ import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, PYTH
 
 
 private[spark] object StreamingPythonRunner {
-  def apply(func: PythonFunction, connectUrl: String): StreamingPythonRunner = {
-    new StreamingPythonRunner(func, connectUrl)
+  def apply(
+      func: PythonFunction,
+      connectUrl: String,
+      sessionId: String,
+      workerModule: String
+  ): StreamingPythonRunner = {
+    new StreamingPythonRunner(func, connectUrl, sessionId, workerModule)
   }
 }
 
-private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: String)
-  extends Logging {
+private[spark] class StreamingPythonRunner(
+    func: PythonFunction,
+    connectUrl: String,
+    sessionId: String,
+    workerModule: String) extends Logging {
   private val conf = SparkEnv.get.conf
   protected val bufferSize: Int = conf.get(BUFFER_SIZE)
   protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
 
   private val envVars: java.util.Map[String, String] = func.envVars
   private val pythonExec: String = func.pythonExec
+  private var pythonWorker: Option[Socket] = None
+  private var pythonWorkerFactory: Option[PythonWorkerFactory] = None
   protected val pythonVer: String = func.pythonVer
 
   /**
    * Initializes the Python worker for streaming functions. Sets up Spark Connect session
    * to be used with the functions.
    */
-  def init(sessionId: String, workerModule: String): (DataOutputStream, DataInputStream) = {
-    logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec")
+  def init(): (DataOutputStream, DataInputStream) = {
+    logInfo(s"Initializing Python runner (session: $sessionId, pythonExec: $pythonExec")
     val env = SparkEnv.get
 
     val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
     envVars.put("SPARK_LOCAL_DIRS", localdir)
 
     envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
     envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
-    conf.set(PYTHON_USE_DAEMON, false)
     envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl)
 
-    val pythonWorkerFactory = new PythonWorkerFactory(pythonExec, envVars.asScala.toMap)
-    val (worker: Socket, _) = pythonWorkerFactory.createStreamingWorker(workerModule)
+    val prevConf = conf.get(PYTHON_USE_DAEMON)
+    conf.set(PYTHON_USE_DAEMON, false)
+    try {
+      val workerFactory = new PythonWorkerFactory(pythonExec, envVars.asScala.toMap)
+      val (worker: Socket, _) = pythonWorkerFactory.createStreamingWorker(workerModule)
+      pythonWorker = Some(worker)
+      pythonWorkerFactory = Some(workerFactory)
+    } finally {
+      conf.set(PYTHON_USE_DAEMON, prevConf)
+    }
 
     val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)

Review Comment:
   Ah, thank you! Will add.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to