HyukjinKwon commented on a change in pull request #32399:
URL: https://github.com/apache/spark/pull/32399#discussion_r629086511
##########
File path: mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
##########
@@ -161,11 +169,26 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
    }
    // Wait for all metrics to be calculated
-    val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
-
-    // Unpersist training & validation set once all metrics have been produced
-    trainingDataset.unpersist()
-    validationDataset.unpersist()
+    val metrics = try {
+      metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
+    }
+    catch {
+      case e: Throwable =>
+        subTaskFailed = true
+        throw e
+    }
+    finally {
+      if (subTaskFailed) {
+        Thread.sleep(1000)
+        val sparkContext = dataset.sparkSession.sparkContext
+        sparkContext.cancelJobGroup(
+          sparkContext.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)
+        )
+      }
Review comment:
I think we could do this as below, without `subTaskFailed`?
```scala
catch {
  case e: Throwable =>
    val sparkContext = dataset.sparkSession.sparkContext
    sparkContext.cancelJobGroup(
      sparkContext.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)
    )
    throw e
}
```
##########
File path: python/pyspark/ml/tuning.py
##########
@@ -730,13 +733,40 @@ def _fit(self, dataset):
            train = datasets[i][0].cache()
            tasks = _parallelFitTasks(est, train, eva, validation, epm,
                                      collectSubModelsParam)
-            for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
-                metrics[j] += (metric / nFolds)
-                if collectSubModelsParam:
-                    subModels[i][j] = subModel
-            validation.unpersist()
-            train.unpersist()
+            sub_task_failed = False
+
+            @inheritable_thread_target
+            def run_task(task):
+                if sub_task_failed:
+                    raise RuntimeError("Terminate this task because one of other task failed.")
+                return task()
+
+            try:
+                for j, metric, subModel in pool.imap_unordered(run_task, tasks):
+                    metrics[j] += (metric / nFolds)
+                    if collectSubModelsParam:
+                        subModels[i][j] = subModel
+            except:
+                sub_task_failed = True
+                raise
+            finally:
+                if sub_task_failed:
+                    if is_pinned_thread_mode():
+                        try:
+                            time.sleep(1)
+                            sc = dataset._sc
+                            sc.cancelJobGroup(sc.getLocalProperty("spark.jobGroup.id"))
+                        except:
+                            pass
+                    else:
+                        warnings.warn("CrossValidator {} fit call failed but some spark jobs "
+                                      "may still running for unfinished trials. Enable pyspark "
Review comment:
Hm, why is this inconsistent with `TrainValidationSplit`? It seems
`TrainValidationSplit` always cancels, but here we only cancel when pinned
thread mode is on.
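To make that concrete, here is a rough sketch of the always-cancel variant on the Python side (this just rearranges the identifiers from the diff above, with `except Exception` as my own choice; not a tested patch):
```python
try:
    for j, metric, subModel in pool.imap_unordered(run_task, tasks):
        metrics[j] += (metric / nFolds)
        if collectSubModelsParam:
            subModels[i][j] = subModel
except Exception:
    # Cancel the job group unconditionally, as TrainValidationSplit
    # does on the Scala side, then rethrow the original error.
    sc = dataset._sc
    sc.cancelJobGroup(sc.getLocalProperty("spark.jobGroup.id"))
    raise
```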
##########
File path: mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
##########
@@ -161,11 +169,26 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
    }
    // Wait for all metrics to be calculated
-    val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
-
-    // Unpersist training & validation set once all metrics have been produced
-    trainingDataset.unpersist()
-    validationDataset.unpersist()
+    val metrics = try {
+      metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
+    }
+    catch {
+      case e: Throwable =>
+        subTaskFailed = true
+        throw e
+    }
+    finally {
+      if (subTaskFailed) {
+        Thread.sleep(1000)
Review comment:
Would you mind elaborating on why we should sleep here? We should avoid
relying on sleep in the main code or test code whenever possible.
##########
File path: python/pyspark/util.py
##########
@@ -263,6 +264,69 @@ def _parse_memory(s):
    return int(float(s[:-1]) * units[s[-1].lower()])
+def is_pinned_thread_mode():
+    """
+    Return ``True`` when spark run under pinned thread mode.
+    """
+    from pyspark import SparkContext
+    return isinstance(SparkContext._gateway, ClientServer)
+
+
+def inheritable_thread_target(f):
+    """
+    Return thread target wrapper which is recommended to be used in PySpark when the
+    pinned thread mode is enabled. The wrapper function, before calling original
+    thread target, it inherits the inheritable properties specific
+    to JVM thread such as ``InheritableThreadLocal``.
+
+    Also, note that pinned thread mode does not close the connection from Python
+    to JVM when the thread is finished in the Python side. With this wrapper, Python
+    garbage-collects the Python thread instance and also closes the connection
+    which finishes JVM thread correctly.
+
+    When the pinned thread mode is off, it return the original ``f``.
+    :param f: the original thread target.
+
+    .. versionadded:: 3.1.0
Review comment:
I think you meant: 3.2.0
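One more thought while here: a minimal usage sketch of this helper as the diff defines it (my own example, not from this PR; the `"group-a"` property value is made up):
```python
import threading

from pyspark import SparkContext
from pyspark.util import inheritable_thread_target

sc = SparkContext.getOrCreate()
# Hypothetical job group id, set on the parent (caller) thread.
sc.setLocalProperty("spark.jobGroup.id", "group-a")

@inheritable_thread_target
def target():
    # Under pinned thread mode, the wrapper copies the parent's
    # inheritable JVM thread-local properties, so the child thread
    # sees the job group set above and can be cancelled through it.
    print(sc.getLocalProperty("spark.jobGroup.id"))

t = threading.Thread(target=target)
t.start()
t.join()
```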