cloud-fan commented on a change in pull request #26624:
URL: https://github.com/apache/spark/pull/26624#discussion_r417736547
##########
File path: core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
##########
@@ -17,21 +17,110 @@
package org.apache.spark.util
+import java.util
import java.util.concurrent._
import java.util.concurrent.locks.ReentrantLock
+import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
import scala.concurrent.{Awaitable, ExecutionContext,
ExecutionContextExecutor, Future}
import scala.concurrent.duration.{Duration, FiniteDuration}
import scala.language.higherKinds
import scala.util.control.NonFatal
-import com.google.common.util.concurrent.ThreadFactoryBuilder
-
import org.apache.spark.SparkException
import org.apache.spark.rpc.RpcAbortException
private[spark] object ThreadUtils {
+  /**
+   * Factories for [[MDCAwareThreadPoolExecutor]] using the same pool configurations as the
+   * like-named `java.util.concurrent.Executors` methods.
+   */
+  object MDCAwareThreadPoolExecutor {
+
+    /**
+     * Pool with no core threads and no upper bound: tasks are handed off through a
+     * [[SynchronousQueue]] and idle threads are reclaimed after 60 seconds.
+     */
+    def newCachedThreadPool(threadFactory: ThreadFactory): ThreadPoolExecutor =
+      new MDCAwareThreadPoolExecutor(
+        0, Int.MaxValue, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable], threadFactory)
+
+    /** Pool of exactly `nThreads` threads fed from an unbounded [[LinkedBlockingQueue]]. */
+    def newFixedThreadPool(nThreads: Int, threadFactory: ThreadFactory): ThreadPoolExecutor = {
+      val noKeepAlive = 0L
+      new MDCAwareThreadPoolExecutor(nThreads, nThreads, noKeepAlive, TimeUnit.MILLISECONDS,
+        new LinkedBlockingQueue[Runnable], threadFactory)
+    }
+
+    /** One-thread pool; the same configuration as `newFixedThreadPool(1, threadFactory)`. */
+    def newSingleThreadExecutor(threadFactory: ThreadFactory): ThreadPoolExecutor =
+      newFixedThreadPool(1, threadFactory)
+  }
+
+  /**
+   * A [[ScheduledThreadPoolExecutor]] that runs each task under the SLF4J MDC captured from the
+   * thread that handed the task in, then restores the worker thread's previous MDC.
+   *
+   * The `schedule*` methods are overridden in addition to `execute`: `ScheduledThreadPoolExecutor`
+   * queues tasks from `schedule`, `scheduleAtFixedRate` and `scheduleWithFixedDelay` without
+   * calling `execute`, so overriding `execute` alone would leave those tasks without the caller's
+   * MDC. (In OpenJDK, `execute` and `submit` are implemented via `schedule`; a task wrapped twice
+   * is harmless because each layer restores exactly the MDC it found.)
+   */
+  class MDCAwareScheduledThreadPoolExecutor(
+      corePoolSize: Int,
+      threadFactory: ThreadFactory)
+    extends ScheduledThreadPoolExecutor(corePoolSize, threadFactory) {
+
+    override def execute(runnable: Runnable): Unit =
+      super.execute(wrapRunnable(runnable))
+
+    override def schedule(command: Runnable, delay: Long, unit: TimeUnit): ScheduledFuture[_] =
+      super.schedule(wrapRunnable(command), delay, unit)
+
+    override def schedule[V](
+        callable: Callable[V], delay: Long, unit: TimeUnit): ScheduledFuture[V] =
+      super.schedule(wrapCallable(callable), delay, unit)
+
+    override def scheduleAtFixedRate(
+        command: Runnable, initialDelay: Long, period: Long, unit: TimeUnit): ScheduledFuture[_] =
+      super.scheduleAtFixedRate(wrapRunnable(command), initialDelay, period, unit)
+
+    override def scheduleWithFixedDelay(
+        command: Runnable, initialDelay: Long, delay: Long, unit: TimeUnit): ScheduledFuture[_] =
+      super.scheduleWithFixedDelay(wrapRunnable(command), initialDelay, delay, unit)
+
+    /** Wraps `task` so every run of it happens under the MDC of the thread calling this method. */
+    private def wrapRunnable(task: Runnable): Runnable = {
+      // Captured here, on the submitting thread, not later on the worker thread.
+      val callerMDC = getMDCMap
+      new Runnable {
+        override def run(): Unit = withMDC(callerMDC)(task.run())
+      }
+    }
+
+    /** Callable counterpart of [[wrapRunnable]]. */
+    private def wrapCallable[V](task: Callable[V]): Callable[V] = {
+      val callerMDC = getMDCMap
+      new Callable[V] {
+        override def call(): V = withMDC(callerMDC)(task.call())
+      }
+    }
+
+    /** Runs `body` with `mdc` installed, restoring this thread's previous MDC afterwards. */
+    private def withMDC[T](mdc: util.Map[String, String])(body: => T): T = {
+      val previous = getMDCMap
+      org.slf4j.MDC.setContextMap(mdc)
+      try body finally org.slf4j.MDC.setContextMap(previous)
+    }
+
+    /** A copy of the current thread's MDC, or an empty map when none is set. */
+    @inline
+    private def getMDCMap: util.Map[String, String] =
+      Option(org.slf4j.MDC.getCopyOfContextMap).getOrElse(new util.HashMap[String, String]())
+  }
+
+ class MDCAwareThreadPoolExecutor(
+ corePoolSize: Int,
+ maximumPoolSize: Int,
+ keepAliveTime: Long,
+ unit: TimeUnit,
+ workQueue: BlockingQueue[Runnable],
+ threadFactory: ThreadFactory)
+ extends ThreadPoolExecutor(corePoolSize, maximumPoolSize,
+ keepAliveTime, unit, workQueue, threadFactory) {
+
+ override def execute(runnable: Runnable) {
+ super.execute(new Runnable {
Review comment:
How about creating an `MDCAwareRunnable(proxy: Runnable) extends Runnable` class to avoid the
duplicated code?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]