WeichenXu123 commented on a change in pull request #32399:
URL: https://github.com/apache/spark/pull/32399#discussion_r633177750
##########
File path:
mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
##########
@@ -161,11 +174,28 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
       }
       // Wait for all metrics to be calculated
-      val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
-
-      // Unpersist training & validation set once all metrics have been produced
-      trainingDataset.unpersist()
-      validationDataset.unpersist()
+      val metrics = try {
+        metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
+      }
+      catch {
+        case e: Throwable =>
+          subTaskFailed = true
+          try {
+            Thread.sleep(1000)
+            val sparkContext = sparkSession.sparkContext
+            sparkContext.cancelJobGroup(tvsJobGroup)
+          } catch {
+            case _: Throwable => ()
Review comment:
I think this is the standard way of representing a `Unit` return value:
https://stackoverflow.com/questions/53415664/how-do-i-explicitly-return-unit
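For reference, a minimal self-contained sketch of that pattern (the object and method names here are illustrative, not taken from the PR):

```scala
object UnitReturnExample {
  // Run a best-effort cleanup action, swallowing any failure it throws.
  def bestEffortCleanup(cleanup: () => Unit): Unit = {
    try {
      cleanup()
    } catch {
      // `()` is the Unit value in Scala, so this catch-all case explicitly
      // evaluates to Unit while ignoring the secondary failure.
      case _: Throwable => ()
    }
  }
}
```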
##########
File path: python/pyspark/util.py
##########
@@ -263,6 +264,69 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])
+def is_pinned_thread_mode():
+    """
+    Return ``True`` when Spark runs under pinned thread mode.
+    """
+    from pyspark import SparkContext
+    return isinstance(SparkContext._gateway, ClientServer)
+
+
+def inheritable_thread_target(f):
+    """
+    Return a thread target wrapper which is recommended to be used in PySpark when the
+    pinned thread mode is enabled. Before calling the original thread target, the
+    wrapper function inherits the inheritable properties specific
+    to the JVM thread, such as ``InheritableThreadLocal``.
+
+    Also, note that pinned thread mode does not close the connection from Python
+    to the JVM when the thread finishes on the Python side. With this wrapper, Python
+    garbage-collects the Python thread instance and also closes the connection,
+    which finishes the JVM thread correctly.
+
+    When the pinned thread mode is off, it returns the original ``f``.
+
+    :param f: the original thread target.
+
+    .. versionadded:: 3.1.0
Review comment:
OK, we don't need to backport it to Apache/spark.
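For context, here is a small Scala sketch of the JVM-side behavior the wrapper preserves: `InheritableThreadLocal` values are copied from a parent thread to threads it creates (illustrative only, not part of this PR):

```scala
object InheritableLocalDemo {
  // A child thread inherits the parent's value at construction time.
  val jobGroup = new InheritableThreadLocal[String]

  def main(args: Array[String]): Unit = {
    jobGroup.set("parent-group")
    val child = new Thread(() => {
      // Prints "parent-group": the value was inherited from the parent thread.
      println(jobGroup.get())
    })
    child.start()
    child.join()
  }
}
```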
##########
File path: mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
##########
@@ -168,9 +169,19 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
       val validationDataset = sparkSession.createDataFrame(validation, schemaWithoutFold).cache()
       instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.")
+      val sparkContext = sparkSession.sparkContext
+      val oldJobGroup = sparkContext.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)
+      val cvJobGroup = s"${this.uid}_job_group"
+      sparkContext.setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, cvJobGroup)
+      @volatile var subTaskFailed = false
       // Fit models in a Future for training in parallel
       val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
         Future[Double] {
+          if (subTaskFailed) {
+            throw new RuntimeException(
Review comment:
Yes, it's just for killing the future thread.
I already added a message to the `RuntimeException`, so I think it is fine.
`IllegalStateException` means "Signals that a method has been invoked at an
illegal or inappropriate time", so it may not be the proper exception class here.
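As an aside, here is a minimal sketch of the fail-fast pattern under discussion; the names and the task body are hypothetical, not the PR's actual code:

```scala
import scala.concurrent.{ExecutionContext, Future}

object FailFastExample {
  // Flipped by the first failing task; @volatile makes the write visible
  // to the threads running the other futures.
  @volatile var subTaskFailed = false

  def fitOne(paramIndex: Int)(implicit ec: ExecutionContext): Future[Double] =
    Future {
      if (subTaskFailed) {
        // The RuntimeException exists only to terminate this future early;
        // its message records why the task was aborted.
        throw new RuntimeException(s"Task $paramIndex aborted: a sibling task failed")
      }
      0.0 // placeholder for the expensive model fitting and evaluation
    }
}
```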
##########
File path: mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
##########
@@ -168,9 +169,19 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
       val validationDataset = sparkSession.createDataFrame(validation, schemaWithoutFold).cache()
       instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.")
+      val sparkContext = sparkSession.sparkContext
+      val oldJobGroup = sparkContext.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)
+      val cvJobGroup = s"${this.uid}_job_group"
+      sparkContext.setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, cvJobGroup)
Review comment:
Yes, this is a corner case; I don't have a good idea.
If we supported cancelling by job group prefix, that could address this issue.
CC @srowen Any thoughts?
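To make the nesting problem concrete, a hypothetical sketch of the save-and-restore pattern around the job group property (this assumes code inside Spark's own codebase, where `SparkContext.SPARK_JOB_GROUP_ID`, i.e. `"spark.jobGroup.id"`, is accessible):

```scala
import org.apache.spark.SparkContext

object JobGroupSketch {
  def withJobGroup[T](sc: SparkContext, jobGroup: String)(body: => T): T = {
    val oldJobGroup = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)
    sc.setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, jobGroup)
    try {
      // Jobs submitted here can be cancelled via sc.cancelJobGroup(jobGroup).
      body
    } finally {
      // Restoring the caller's group is exactly what makes nesting tricky:
      // cancelling the outer group no longer reaches jobs already submitted
      // under the inner group.
      sc.setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, oldJobGroup)
    }
  }
}
```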
##########
File path: mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
##########
@@ -168,9 +169,19 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
       val validationDataset = sparkSession.createDataFrame(validation, schemaWithoutFold).cache()
       instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.")
+      val sparkContext = sparkSession.sparkContext
+      val oldJobGroup = sparkContext.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)
+      val cvJobGroup = s"${this.uid}_job_group"
+      sparkContext.setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, cvJobGroup)
Review comment:
Yes, this is a corner case; I don't have a good idea.
CC @srowen Any thoughts?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]