This is an automated email from the ASF dual-hosted git repository.

benjobs pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/incubator-streampark.git


The following commit(s) were added to refs/heads/dev by this push:
     new d1476823e [Bug] Correct spark app submitting process (#3908)
d1476823e is described below

commit d1476823e03789cf34d3b147b5c4fed5e55deee3
Author: lenoxzhao <[email protected]>
AuthorDate: Tue Jul 23 15:01:29 2024 +0800

    [Bug] Correct spark app submitting process (#3908)
    
    * fix: correct spark app submitting process
    
    * improve: throw exception when scala version is not found
    
    * improve: standardize acquirement of spark and scala version and improve stop operation
    
    * improve: spark job kill improvement
---
 .../streampark/common/conf/SparkVersion.scala      | 101 +++++++++------------
 .../streampark/spark/client/impl/YarnClient.scala  |  32 ++++++-
 2 files changed, 73 insertions(+), 60 deletions(-)

diff --git a/streampark-common/src/main/scala/org/apache/streampark/common/conf/SparkVersion.scala b/streampark-common/src/main/scala/org/apache/streampark/common/conf/SparkVersion.scala
index c6ba76888..8b52adfb3 100644
--- a/streampark-common/src/main/scala/org/apache/streampark/common/conf/SparkVersion.scala
+++ b/streampark-common/src/main/scala/org/apache/streampark/common/conf/SparkVersion.scala
@@ -21,7 +21,6 @@ import org.apache.streampark.common.util.{CommandUtils, Logger}
 import org.apache.streampark.common.util.Implicits._
 
 import java.io.File
-import java.net.URL
 import java.util.function.Consumer
 import java.util.regex.Pattern
 
@@ -32,34 +31,49 @@ class SparkVersion(val sparkHome: String) extends Serializable with Logger {
 
   private[this] lazy val SPARK_VER_PATTERN = Pattern.compile("^(\\d+\\.\\d+)(\\.)?.*$")
 
-  private[this] lazy val SPARK_VERSION_PATTERN = Pattern.compile("(version) (\\d+\\.\\d+\\.\\d+)")
+  private[this] lazy val SPARK_VERSION_PATTERN = Pattern.compile("\\s{2}version\\s(\\d+\\.\\d+\\.\\d+)")
 
-  private[this] lazy val SPARK_SCALA_VERSION_PATTERN = Pattern.compile("^spark-core_(.*)-[0-9].*.jar$")
+  private[this] lazy val SPARK_SCALA_VERSION_PATTERN = Pattern.compile("Using\\sScala\\sversion\\s(\\d+\\.\\d+)")
 
-  lazy val scalaVersion: String = SPARK_SCALA_VERSION_PATTERN.matcher(sparkCoreJar.getName).group(1)
+  val (version, scalaVersion) = {
+    var sparkVersion: String = null
+    var scalaVersion: String = null
+    val cmd = List(s"$sparkHome/bin/spark-submit --version")
+    val buffer = new mutable.StringBuilder
 
-  lazy val sparkCoreJar: File = {
-    val distJar = sparkLib.listFiles().filter(_.getName.matches("spark-core.*\\.jar"))
-    distJar match {
-      case x if x.isEmpty =>
-        throw new IllegalArgumentException(s"[StreamPark] can no found spark-core jar in $sparkLib")
-      case x if x.length > 1 =>
-        throw new IllegalArgumentException(
-          s"[StreamPark] found multiple spark-core jar in $sparkLib")
-      case _ =>
+    CommandUtils.execute(
+      sparkHome,
+      cmd,
+      new Consumer[String]() {
+        override def accept(out: String): Unit = {
+          buffer.append(out).append("\n")
+          val matcher = SPARK_VERSION_PATTERN.matcher(out)
+          if (matcher.find) {
+            sparkVersion = matcher.group(1)
+          } else {
+            val matcher1 = SPARK_SCALA_VERSION_PATTERN.matcher(out)
+            if (matcher1.find) {
+              scalaVersion = matcher1.group(1)
+            }
+          }
+        }
+      })
+
+    logInfo(buffer.toString())
+    if (sparkVersion == null || scalaVersion == null) {
+      throw new IllegalStateException(s"[StreamPark] parse spark version failed. $buffer")
     }
-    distJar.head
+    buffer.clear()
+    (sparkVersion, scalaVersion)
   }
 
-  def checkVersion(throwException: Boolean = true): Boolean = {
-    version.split("\\.").map(_.trim.toInt) match {
-      case Array(3, v, _) if v >= 1 && v <= 3 => true
-      case _ =>
-        if (throwException) {
-          throw new UnsupportedOperationException(s"Unsupported spark version: $version")
-        } else {
-          false
-        }
+  lazy val majorVersion: String = {
+    if (version == null) {
+      null
+    } else {
+      val matcher = SPARK_VER_PATTERN.matcher(version)
+      matcher.matches()
+      matcher.group(1)
     }
   }
 
@@ -75,41 +89,16 @@ class SparkVersion(val sparkHome: String) extends Serializable with Logger {
     lib
   }
 
-  lazy val sparkLibs: List[URL] = sparkLib.listFiles().map(_.toURI.toURL).toList
-
-  lazy val majorVersion: String = {
-    if (version == null) {
-      null
-    } else {
-      val matcher = SPARK_VER_PATTERN.matcher(version)
-      matcher.matches()
-      matcher.group(1)
-    }
-  }
-
-  lazy val version: String = {
-    var sparkVersion: String = null
-    val cmd = List(s"$sparkHome/bin/spark-submit --version")
-    val buffer = new mutable.StringBuilder
-    CommandUtils.execute(
-      sparkHome,
-      cmd,
-      new Consumer[String]() {
-        override def accept(out: String): Unit = {
-          buffer.append(out).append("\n")
-          val matcher = SPARK_VERSION_PATTERN.matcher(out)
-          if (matcher.find) {
-            sparkVersion = matcher.group(2)
-          }
+  def checkVersion(throwException: Boolean = true): Boolean = {
+    version.split("\\.").map(_.trim.toInt) match {
+      case Array(3, v, _) if v >= 1 && v <= 3 => true
+      case _ =>
+        if (throwException) {
+          throw new UnsupportedOperationException(s"Unsupported spark version: $version")
+        } else {
+          false
         }
-      })
-
-    logInfo(buffer.toString())
-    if (sparkVersion == null) {
-      throw new IllegalStateException(s"[StreamPark] parse spark version failed. $buffer")
     }
-    buffer.clear()
-    sparkVersion
   }
 
   override def toString: String =
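
For readers who want to exercise the new detection logic outside StreamPark,
here is a minimal standalone sketch. It mirrors the patch: shell out to
spark-submit --version and scan the banner with the same two regexes. The
object and method names (SparkVersionProbe, parseVersions) are illustrative
only, not part of the patch:

    import java.util.regex.Pattern

    import scala.collection.mutable
    import scala.sys.process._

    object SparkVersionProbe {
      // Same patterns as the patch: "   version 3.2.1" / "Using Scala version 2.12"
      private val VersionPattern = Pattern.compile("\\s{2}version\\s(\\d+\\.\\d+\\.\\d+)")
      private val ScalaPattern = Pattern.compile("Using\\sScala\\sversion\\s(\\d+\\.\\d+)")

      /** Runs spark-submit --version and extracts (sparkVersion, scalaVersion). */
      def parseVersions(sparkHome: String): (String, String) = {
        val buffer = new mutable.StringBuilder
        var sparkVer: String = null
        var scalaVer: String = null
        // ProcessLogger receives stdout and stderr; spark-submit prints its
        // version banner on stderr, so both streams must be scanned.
        val collector = ProcessLogger { line =>
          buffer.append(line).append('\n')
          val m = VersionPattern.matcher(line)
          if (m.find()) sparkVer = m.group(1)
          else {
            val m2 = ScalaPattern.matcher(line)
            if (m2.find()) scalaVer = m2.group(1)
          }
        }
        Process(Seq(s"$sparkHome/bin/spark-submit", "--version")).!(collector)
        // Fail loudly when either version is missing, as the patch now does.
        if (sparkVer == null || scalaVer == null) {
          throw new IllegalStateException(s"parse spark version failed. $buffer")
        }
        (sparkVer, scalaVer)
      }
    }

Against a stock Spark 3.2.x distribution this would return something like
("3.2.1", "2.12"); the stricter behavior is that a banner without a
"Using Scala version" line now fails fast instead of leaving scalaVersion
unresolved.
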
diff --git a/streampark-spark/streampark-spark-client/streampark-spark-client-core/src/main/scala/org/apache/streampark/spark/client/impl/YarnClient.scala b/streampark-spark/streampark-spark-client/streampark-spark-client-core/src/main/scala/org/apache/streampark/spark/client/impl/YarnClient.scala
index 19e6beaf7..60648c15d 100644
--- a/streampark-spark/streampark-spark-client/streampark-spark-client-core/src/main/scala/org/apache/streampark/spark/client/impl/YarnClient.scala
+++ b/streampark-spark/streampark-spark-client/streampark-spark-client-core/src/main/scala/org/apache/streampark/spark/client/impl/YarnClient.scala
@@ -26,16 +26,39 @@ import org.apache.streampark.spark.client.bean._
 import org.apache.hadoop.yarn.api.records.ApplicationId
 import org.apache.spark.launcher.{SparkAppHandle, SparkLauncher}
 
-import java.util.concurrent.CountDownLatch
+import java.util.concurrent.{ConcurrentHashMap, CountDownLatch}
 
 import scala.util.{Failure, Success, Try}
 
 /** yarn application mode submit */
 object YarnClient extends SparkClientTrait {
 
+  private lazy val sparkHandles = new ConcurrentHashMap[String, SparkAppHandle]()
+
   override def doStop(stopRequest: StopRequest): StopResponse = {
-    HadoopUtils.yarnClient.killApplication(ApplicationId.fromString(stopRequest.jobId))
-    null
+    val sparkAppHandle = sparkHandles.remove(stopRequest.jobId)
+    if (sparkAppHandle != null) {
+      Try(sparkAppHandle.kill()) match {
+        case Success(_) =>
+          logger.info(s"[StreamPark][Spark][YarnClient] spark job: 
${stopRequest.jobId} is stopped successfully.")
+          StopResponse(null)
+        case Failure(e) =>
+          logger.error("[StreamPark][Spark][YarnClient] sparkAppHandle kill 
failed. Try kill by yarn", e)
+          yarnKill(stopRequest.jobId)
+          StopResponse(null)
+      }
+    } else {
+      logger.warn(s"[StreamPark][Spark][YarnClient] spark job: ${stopRequest.jobId} is not existed. Try kill by yarn")
+      yarnKill(stopRequest.jobId)
+      StopResponse(null)
+    }
+  }
+
+  private def yarnKill(appId: String): Unit = {
+    Try(HadoopUtils.yarnClient.killApplication(ApplicationId.fromString(appId))) match {
+      case Success(_) => logger.info(s"[StreamPark][Spark][YarnClient] spark job: $appId is killed by yarn successfully.")
+      case Failure(e) => throw e
+    }
   }
 
   override def setConfig(submitRequest: SubmitRequest): Unit = {}
@@ -53,6 +76,7 @@ object YarnClient extends SparkClientTrait {
         logger.info(s"[StreamPark][Spark][YarnClient] spark job: 
${submitRequest.effectiveAppName} is submit successful, " +
           s"appid: ${handle.getAppId}, " +
           s"state: ${handle.getState}")
+        sparkHandles += handle.getAppId -> handle
         SubmitResponse(handle.getAppId, submitRequest.properties)
       case Failure(e) => throw e
     }
@@ -72,7 +96,7 @@ object YarnClient extends SparkClientTrait {
         if (SparkAppHandle.State.FAILED == handle.getState) {
           logger.error("Task run failure stateChanged :{}", 
handle.getState.toString)
         }
-        if (handle.getState.isFinal) {
+        if (handle.getAppId != null && submitFinished.getCount != 0) {
           submitFinished.countDown()
         }
       }
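
The stop path above is a prefer-handle, fall-back-to-YARN pattern: kill
through the launcher's SparkAppHandle when one is registered, and escalate
to a YARN-level kill when the handle is missing (for example after a
restart) or its kill fails. A minimal sketch of that pattern, with
illustrative names (HandleRegistry, register, stop) that are not part of
the patch:

    import java.util.concurrent.ConcurrentHashMap

    import org.apache.spark.launcher.SparkAppHandle

    import scala.util.{Failure, Success, Try}

    object HandleRegistry {
      // appId -> live launcher handle, registered at submit time.
      private val handles = new ConcurrentHashMap[String, SparkAppHandle]()

      def register(appId: String, handle: SparkAppHandle): Unit =
        handles.put(appId, handle)

      /** Stops appId via its handle when possible, else via the given YARN kill. */
      def stop(appId: String)(yarnKill: String => Unit): Unit =
        Option(handles.remove(appId)) match {
          case Some(handle) =>
            Try(handle.kill()) match {
              case Success(_) => // stopped through the launcher handle
              case Failure(_) => yarnKill(appId) // handle kill failed, escalate
            }
          case None =>
            yarnKill(appId) // handle unknown, go straight to YARN
        }
    }

Note that doStop removes the handle from the map rather than reading it,
so the registry cannot leak entries for jobs that have already been
stopped or killed.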
