This is an automated email from the ASF dual-hosted git repository.

zsxwing pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cfd7ca9  Revert "[SPARK-21869][SS] Apply Apache Commons Pool to Kafka producer"
cfd7ca9 is described below

commit cfd7ca9a06161f7622b5179a777f965c11892afa
Author: Shixiong Zhu <zsxw...@gmail.com>
AuthorDate: Tue Dec 10 11:21:46 2019 -0800

    Revert "[SPARK-21869][SS] Apply Apache Commons Pool to Kafka producer"
    
    This reverts commit 3641c3dd69b2bd2beae028d52356450cc41f69ed.
---
 .../spark/sql/kafka010/CachedKafkaProducer.scala   | 118 +++++++-----
 .../sql/kafka010/InternalKafkaConnectorPool.scala  | 210 ---------------------
 .../sql/kafka010/InternalKafkaConsumerPool.scala   | 210 ++++++++++++++++++---
 .../sql/kafka010/InternalKafkaProducerPool.scala   |  68 -------
 .../spark/sql/kafka010/KafkaDataConsumer.scala     |   7 +-
 .../spark/sql/kafka010/KafkaDataWriter.scala       |  34 +---
 .../apache/spark/sql/kafka010/KafkaWriteTask.scala |  20 +-
 .../org/apache/spark/sql/kafka010/package.scala    |  34 +---
 .../sql/kafka010/CachedKafkaProducerSuite.scala    | 154 ++++-----------
 ....scala => InternalKafkaConsumerPoolSuite.scala} |   8 +-
 .../sql/kafka010/KafkaDataConsumerSuite.scala      |   6 +-
 .../org/apache/spark/sql/kafka010/KafkaTest.scala  |  10 +-
 .../kafka010/KafkaDataConsumerSuite.scala          |   7 +
 13 files changed, 332 insertions(+), 554 deletions(-)
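
For readers skimming the diff below: the revert swaps the Commons Pool based producer pool back to a Guava LoadingCache keyed by the sorted producer config. The following is a minimal, self-contained sketch of that caching pattern, assuming only Guava and the Kafka client jars on the classpath; object and helper names here are illustrative, not the exact Spark code.

    import java.{util => ju}
    import java.util.concurrent.TimeUnit

    import scala.collection.JavaConverters._

    import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache, RemovalListener, RemovalNotification}
    import org.apache.kafka.clients.producer.KafkaProducer

    // Illustrative stand-in for the restored CachedKafkaProducer internals.
    object ProducerCacheSketch {
      type Producer = KafkaProducer[Array[Byte], Array[Byte]]
      type CacheKey = Seq[(String, Object)]

      private val loader = new CacheLoader[CacheKey, Producer] {
        // A producer is built lazily the first time its config is requested.
        override def load(config: CacheKey): Producer = new Producer(config.toMap.asJava)
      }

      private val onRemoval = new RemovalListener[CacheKey, Producer] {
        // Evicted or invalidated producers are closed so sockets and buffers are released.
        override def onRemoval(n: RemovalNotification[CacheKey, Producer]): Unit = n.getValue.close()
      }

      // Entries unused for 10 minutes (cf. spark.kafka.producer.cache.timeout) are evicted.
      private val cache: LoadingCache[CacheKey, Producer] =
        CacheBuilder.newBuilder()
          .expireAfterAccess(TimeUnit.MINUTES.toMillis(10), TimeUnit.MILLISECONDS)
          .removalListener(onRemoval)
          .build[CacheKey, Producer](loader)

      // Identical params sort to the same key, hence map to the same shared producer.
      def getOrCreate(params: ju.Map[String, Object]): Producer =
        cache.get(params.asScala.toSeq.sortBy(_._1))

      // Explicitly drop (and close, via the removal listener) the producer for these params.
      def close(params: ju.Map[String, Object]): Unit =
        cache.invalidate(params.asScala.toSeq.sortBy(_._1))
    }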

diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala
index 907440a..fc177cd 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala
@@ -18,68 +18,60 @@
 package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
-import java.io.Closeable
+import java.util.concurrent.{ConcurrentMap, ExecutionException, TimeUnit}
 
+import com.google.common.cache._
+import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
+import org.apache.kafka.clients.producer.KafkaProducer
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
-import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord}
-
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.kafka010.{KafkaConfigUpdater, KafkaRedactionUtil}
-import org.apache.spark.sql.kafka010.InternalKafkaProducerPool._
-import org.apache.spark.util.ShutdownHookManager
 
-private[kafka010] class CachedKafkaProducer(val kafkaParams: ju.Map[String, Object])
-  extends Closeable with Logging {
+private[kafka010] object CachedKafkaProducer extends Logging {
 
   private type Producer = KafkaProducer[Array[Byte], Array[Byte]]
 
-  private val producer = createProducer()
+  private val defaultCacheExpireTimeout = TimeUnit.MINUTES.toMillis(10)
 
-  private def createProducer(): Producer = {
-    val producer: Producer = new Producer(kafkaParams)
-    if (log.isDebugEnabled()) {
-      val redactedParamsSeq = KafkaRedactionUtil.redactParams(toCacheKey(kafkaParams))
-      logDebug(s"Created a new instance of kafka producer for $redactedParamsSeq.")
+  private lazy val cacheExpireTimeout: Long = Option(SparkEnv.get)
+    .map(_.conf.get(PRODUCER_CACHE_TIMEOUT))
+    .getOrElse(defaultCacheExpireTimeout)
+
+  private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] {
+    override def load(config: Seq[(String, Object)]): Producer = {
+      createKafkaProducer(config)
     }
-    producer
   }
 
-  override def close(): Unit = {
-    try {
-      if (log.isInfoEnabled()) {
-        val redactedParamsSeq = KafkaRedactionUtil.redactParams(toCacheKey(kafkaParams))
-        logInfo(s"Closing the KafkaProducer with params: ${redactedParamsSeq.mkString("\n")}.")
+  private val removalListener = new RemovalListener[Seq[(String, Object)], Producer]() {
+    override def onRemoval(
+        notification: RemovalNotification[Seq[(String, Object)], Producer]): Unit = {
+      val paramsSeq: Seq[(String, Object)] = notification.getKey
+      val producer: Producer = notification.getValue
+      if (log.isDebugEnabled()) {
+        val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
+        logDebug(s"Evicting kafka producer $producer params: $redactedParamsSeq, " +
+          s"due to ${notification.getCause}")
       }
-      producer.close()
-    } catch {
-      case NonFatal(e) => logWarning("Error while closing kafka producer.", e)
+      close(paramsSeq, producer)
     }
   }
 
-  def send(record: ProducerRecord[Array[Byte], Array[Byte]], callback: Callback): Unit = {
-    producer.send(record, callback)
-  }
-
-  def flush(): Unit = {
-    producer.flush()
-  }
-}
+  private lazy val guavaCache: LoadingCache[Seq[(String, Object)], Producer] =
+    CacheBuilder.newBuilder().expireAfterAccess(cacheExpireTimeout, TimeUnit.MILLISECONDS)
+      .removalListener(removalListener)
+      .build[Seq[(String, Object)], Producer](cacheLoader)
 
-private[kafka010] object CachedKafkaProducer extends Logging {
-
-  private val sparkConf = SparkEnv.get.conf
-  private val producerPool = new InternalKafkaProducerPool(sparkConf)
-
-  ShutdownHookManager.addShutdownHook { () =>
-    try {
-      producerPool.close()
-    } catch {
-      case e: Throwable =>
-        logWarning("Ignoring exception while shutting down pool from shutdown 
hook", e)
+  private def createKafkaProducer(paramsSeq: Seq[(String, Object)]): Producer = {
+    val kafkaProducer: Producer = new Producer(paramsSeq.toMap.asJava)
+    if (log.isDebugEnabled()) {
+      val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
+      logDebug(s"Created a new instance of KafkaProducer for $redactedParamsSeq.")
     }
+    kafkaProducer
   }
 
   /**
@@ -87,20 +79,50 @@ private[kafka010] object CachedKafkaProducer extends Logging {
   * exist, a new KafkaProducer will be created. KafkaProducer is thread safe, it is best to keep
    * one instance per specified kafkaParams.
    */
-  def acquire(kafkaParams: ju.Map[String, Object]): CachedKafkaProducer = {
-    val updatedKafkaParams =
+  private[kafka010] def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer = {
+    val updatedKafkaProducerConfiguration =
       KafkaConfigUpdater("executor", kafkaParams.asScala.toMap)
         .setAuthenticationConfigIfNeeded()
         .build()
-    val key = toCacheKey(updatedKafkaParams)
-    producerPool.borrowObject(key, updatedKafkaParams)
+    val paramsSeq: Seq[(String, Object)] = paramsToSeq(updatedKafkaProducerConfiguration)
+    try {
+      guavaCache.get(paramsSeq)
+    } catch {
+      case e @ (_: ExecutionException | _: UncheckedExecutionException | _: ExecutionError)
+        if e.getCause != null =>
+        throw e.getCause
+    }
+  }
+
+  private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = {
+    val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1)
+    paramsSeq
+  }
+
+  /** For explicitly closing kafka producer */
+  private[kafka010] def close(kafkaParams: ju.Map[String, Object]): Unit = {
+    val paramsSeq = paramsToSeq(kafkaParams)
+    guavaCache.invalidate(paramsSeq)
   }
 
-  def release(producer: CachedKafkaProducer): Unit = {
-    producerPool.returnObject(producer)
+  /** Auto close on cache evict */
+  private def close(paramsSeq: Seq[(String, Object)], producer: Producer): Unit = {
+    try {
+      if (log.isInfoEnabled()) {
+        val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
+        logInfo(s"Closing the KafkaProducer with params: ${redactedParamsSeq.mkString("\n")}.")
+      }
+      producer.close()
+    } catch {
+      case NonFatal(e) => logWarning("Error while closing kafka producer.", e)
+    }
   }
 
   private[kafka010] def clear(): Unit = {
-    producerPool.reset()
+    logInfo("Cleaning up guava cache.")
+    guavaCache.invalidateAll()
   }
+
+  // Intended for testing purpose only.
+  private def getAsMap: ConcurrentMap[Seq[(String, Object)], Producer] = guavaCache.asMap()
 }
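
A hedged usage sketch of the API restored above, mirroring what the rewritten suite later in this diff asserts: equal params share one cached producer, and close(params) evicts and closes it. It assumes the snippet lives inside the org.apache.spark.sql.kafka010 package (the methods are private[kafka010]) and uses a placeholder broker address; it is illustrative, not part of the commit.

    package org.apache.spark.sql.kafka010

    import java.{util => ju}

    import org.apache.kafka.common.serialization.ByteArraySerializer

    // Illustrative only: exercises the restored cache semantics without a running broker.
    object CachedProducerUsageSketch {
      def main(args: Array[String]): Unit = {
        val params = new ju.HashMap[String, Object]()
        params.put("bootstrap.servers", "127.0.0.1:9092") // placeholder, only needs to parse
        params.put("key.serializer", classOf[ByteArraySerializer].getName)
        params.put("value.serializer", classOf[ByteArraySerializer].getName)

        val p1 = CachedKafkaProducer.getOrCreate(params) // creates and caches
        val p2 = CachedKafkaProducer.getOrCreate(params) // returns the same cached instance
        assert(p1 eq p2)

        CachedKafkaProducer.close(params) // evicts the entry and closes the producer
      }
    }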
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConnectorPool.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConnectorPool.scala
deleted file mode 100644
index 0fb250e..0000000
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConnectorPool.scala
+++ /dev/null
@@ -1,210 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.kafka010
-
-import java.{util => ju}
-import java.io.Closeable
-import java.util.concurrent.ConcurrentHashMap
-
-import org.apache.commons.pool2.{BaseKeyedPooledObjectFactory, PooledObject, SwallowedExceptionListener}
-import org.apache.commons.pool2.impl.{DefaultEvictionPolicy, DefaultPooledObject, GenericKeyedObjectPool, GenericKeyedObjectPoolConfig}
-
-import org.apache.spark.internal.Logging
-
-/**
- * Provides object pool for objects which is grouped by a key.
- *
- * This class leverages [[GenericKeyedObjectPool]] internally, hence providing methods based on
- * the class, and same contract applies: after using the borrowed object, you must either call
- * returnObject() if the object is healthy to return to pool, or invalidateObject() if the object
- * should be destroyed.
- *
- * The soft capacity of pool is determined by "poolConfig.capacity" config value,
- * and the pool will have reasonable default value if the value is not provided.
- * (The instance will do its best effort to respect soft capacity but it can exceed when there's
- * a borrowing request and there's neither free space nor idle object to clear.)
- *
- * This class guarantees that no caller will get pooled object once the object is borrowed and
- * not yet returned, hence provide thread-safety usage of non-thread-safe objects unless caller
- * shares the object to multiple threads.
- */
-private[kafka010] abstract class InternalKafkaConnectorPool[K, V <: Closeable](
-    objectFactory: ObjectFactory[K, V],
-    poolConfig: PoolConfig[V],
-    swallowedExceptionListener: SwallowedExceptionListener) extends Logging {
-
-  // the class is intended to have only soft capacity
-  assert(poolConfig.getMaxTotal < 0)
-
-  private val pool = {
-    val internalPool = new GenericKeyedObjectPool[K, V](objectFactory, poolConfig)
-    internalPool.setSwallowedExceptionListener(swallowedExceptionListener)
-    internalPool
-  }
-
-  /**
-   * Borrows object from the pool. If there's no idle object for the key,
-   * the pool will create the object.
-   *
-   * If the pool doesn't have idle object for the key and also exceeds the soft capacity,
-   * pool will try to clear some of idle objects.
-   *
-   * Borrowed object must be returned by either calling returnObject or invalidateObject, otherwise
-   * the object will be kept in pool as active object.
-   */
-  def borrowObject(key: K, kafkaParams: ju.Map[String, Object]): V = {
-    updateKafkaParamForKey(key, kafkaParams)
-
-    if (size >= poolConfig.softMaxSize) {
-      logWarning("Pool exceeds its soft max size, cleaning up idle objects...")
-      pool.clearOldest()
-    }
-
-    pool.borrowObject(key)
-  }
-
-  /** Returns borrowed object to the pool. */
-  def returnObject(connector: V): Unit = {
-    pool.returnObject(createKey(connector), connector)
-  }
-
-  /** Invalidates (destroy) borrowed object to the pool. */
-  def invalidateObject(connector: V): Unit = {
-    pool.invalidateObject(createKey(connector), connector)
-  }
-
-  /** Invalidates all idle values for the key */
-  def invalidateKey(key: K): Unit = {
-    pool.clear(key)
-  }
-
-  /**
-   * Closes the keyed object pool. Once the pool is closed,
-   * borrowObject will fail with [[IllegalStateException]], but returnObject and invalidateObject
-   * will continue to work, with returned objects destroyed on return.
-   *
-   * Also destroys idle instances in the pool.
-   */
-  def close(): Unit = {
-    pool.close()
-  }
-
-  def reset(): Unit = {
-    // this is the best-effort of clearing up. otherwise we should close the pool and create again
-    // but we don't want to make it "var" only because of tests.
-    pool.clear()
-  }
-
-  def numIdle: Int = pool.getNumIdle
-
-  def numIdle(key: K): Int = pool.getNumIdle(key)
-
-  def numActive: Int = pool.getNumActive
-
-  def numActive(key: K): Int = pool.getNumActive(key)
-
-  def size: Int = numIdle + numActive
-
-  def size(key: K): Int = numIdle(key) + numActive(key)
-
-  private def updateKafkaParamForKey(key: K, kafkaParams: ju.Map[String, Object]): Unit = {
-    // We can assume that kafkaParam should not be different for same cache key,
-    // otherwise we can't reuse the cached object and cache key should contain kafkaParam.
-    // So it should be safe to put the key/value pair only when the key doesn't exist.
-    val oldKafkaParams = objectFactory.keyToKafkaParams.putIfAbsent(key, kafkaParams)
-    require(oldKafkaParams == null || kafkaParams == oldKafkaParams, "Kafka parameters for same " +
-      s"cache key should be equal. old parameters: $oldKafkaParams new parameters: $kafkaParams")
-  }
-
-  protected def createKey(connector: V): K
-}
-
-private[kafka010] abstract class PoolConfig[V] extends GenericKeyedObjectPoolConfig[V] {
-
-  init()
-
-  def softMaxSize: Int
-
-  def jmxEnabled: Boolean
-
-  def minEvictableIdleTimeMillis: Long
-
-  def evictorThreadRunIntervalMillis: Long
-
-  def jmxNamePrefix: String
-
-  def init(): Unit = {
-    // NOTE: Below lines define the behavior, so do not modify unless you know what you are
-    // doing, and update the class doc accordingly if necessary when you modify.
-
-    // 1. Set min idle objects per key to 0 to avoid creating unnecessary object.
-    // 2. Set max idle objects per key to 3 but set total objects per key to infinite
-    // which ensures borrowing per key is not restricted.
-    // 3. Set max total objects to infinite which ensures all objects are managed in this pool.
-    setMinIdlePerKey(0)
-    setMaxIdlePerKey(3)
-    setMaxTotalPerKey(-1)
-    setMaxTotal(-1)
-
-    // Set minimum evictable idle time which will be referred from evictor thread
-    setMinEvictableIdleTimeMillis(minEvictableIdleTimeMillis)
-    setSoftMinEvictableIdleTimeMillis(-1)
-
-    // evictor thread will run test with ten idle objects
-    setTimeBetweenEvictionRunsMillis(evictorThreadRunIntervalMillis)
-    setNumTestsPerEvictionRun(10)
-    setEvictionPolicy(new DefaultEvictionPolicy[V]())
-
-    // Immediately fail on exhausted pool while borrowing
-    setBlockWhenExhausted(false)
-
-    setJmxEnabled(jmxEnabled)
-    setJmxNamePrefix(jmxNamePrefix)
-  }
-}
-
-private[kafka010] abstract class ObjectFactory[K, V <: Closeable]
-  extends BaseKeyedPooledObjectFactory[K, V] {
-  val keyToKafkaParams = new ConcurrentHashMap[K, ju.Map[String, Object]]()
-
-  override def create(key: K): V = {
-    Option(keyToKafkaParams.get(key)) match {
-      case Some(kafkaParams) => createValue(key, kafkaParams)
-      case None => throw new IllegalStateException("Kafka params should be set before " +
-        "borrowing object.")
-    }
-  }
-
-  override def wrap(value: V): PooledObject[V] = {
-    new DefaultPooledObject[V](value)
-  }
-
-  override def destroyObject(key: K, p: PooledObject[V]): Unit = {
-    p.getObject.close()
-  }
-
-  protected def createValue(key: K, kafkaParams: ju.Map[String, Object]): V
-}
-
-private[kafka010] class CustomSwallowedExceptionListener(connectorType: String)
-  extends SwallowedExceptionListener with Logging {
-
-  override def onSwallowException(e: Exception): Unit = {
-    logWarning(s"Error closing Kafka $connectorType", e)
-  }
-}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPool.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPool.scala
index a8e6045..276a942 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPool.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPool.scala
@@ -18,46 +18,204 @@
 package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
+import java.util.concurrent.ConcurrentHashMap
 
-import org.apache.commons.pool2.PooledObject
+import org.apache.commons.pool2.{BaseKeyedPooledObjectFactory, PooledObject, SwallowedExceptionListener}
+import org.apache.commons.pool2.impl.{DefaultEvictionPolicy, DefaultPooledObject, GenericKeyedObjectPool, GenericKeyedObjectPoolConfig}
 
 import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.kafka010.InternalKafkaConsumerPool._
 import org.apache.spark.sql.kafka010.KafkaDataConsumer.CacheKey
 
-// TODO: revisit the relation between CacheKey and kafkaParams - for now it looks a bit weird
-//   as we force all consumers having same (groupId, topicPartition) to have same kafkaParams
-//   which might be viable in performance perspective (kafkaParams might be too huge to use
-//   as a part of key), but there might be the case kafkaParams could be different -
-//   cache key should be differentiated for both kafkaParams.
+/**
+ * Provides object pool for [[InternalKafkaConsumer]] which is grouped by [[CacheKey]].
+ *
+ * This class leverages [[GenericKeyedObjectPool]] internally, hence providing methods based on
+ * the class, and same contract applies: after using the borrowed object, you must either call
+ * returnObject() if the object is healthy to return to pool, or invalidateObject() if the object
+ * should be destroyed.
+ *
+ * The soft capacity of pool is determined by "spark.kafka.consumer.cache.capacity" config value,
+ * and the pool will have reasonable default value if the value is not provided.
+ * (The instance will do its best effort to respect soft capacity but it can exceed when there's
+ * a borrowing request and there's neither free space nor idle object to clear.)
+ *
+ * This class guarantees that no caller will get pooled object once the object is borrowed and
+ * not yet returned, hence provide thread-safety usage of non-thread-safe [[InternalKafkaConsumer]]
+ * unless caller shares the object to multiple threads.
+ */
 private[kafka010] class InternalKafkaConsumerPool(
-    objectFactory: ConsumerObjectFactory,
-    poolConfig: ConsumerPoolConfig)
-  extends InternalKafkaConnectorPool[CacheKey, InternalKafkaConsumer](
-      objectFactory,
-      poolConfig,
-      new CustomSwallowedExceptionListener("consumer")) {
+    objectFactory: ObjectFactory,
+    poolConfig: PoolConfig) extends Logging {
 
   def this(conf: SparkConf) = {
-    this(new ConsumerObjectFactory, new ConsumerPoolConfig(conf))
+    this(new ObjectFactory, new PoolConfig(conf))
+  }
+
+  // the class is intended to have only soft capacity
+  assert(poolConfig.getMaxTotal < 0)
+
+  private val pool = {
+    val internalPool = new GenericKeyedObjectPool[CacheKey, InternalKafkaConsumer](
+      objectFactory, poolConfig)
+    internalPool.setSwallowedExceptionListener(CustomSwallowedExceptionListener)
+    internalPool
+  }
+
+  /**
+   * Borrows [[InternalKafkaConsumer]] object from the pool. If there's no idle object for the key,
+   * the pool will create the [[InternalKafkaConsumer]] object.
+   *
+   * If the pool doesn't have idle object for the key and also exceeds the soft capacity,
+   * pool will try to clear some of idle objects.
+   *
+   * Borrowed object must be returned by either calling returnObject or invalidateObject, otherwise
+   * the object will be kept in pool as active object.
+   */
+  def borrowObject(key: CacheKey, kafkaParams: ju.Map[String, Object]): InternalKafkaConsumer = {
+    updateKafkaParamForKey(key, kafkaParams)
+
+    if (size >= poolConfig.softMaxSize) {
+      logWarning("Pool exceeds its soft max size, cleaning up idle objects...")
+      pool.clearOldest()
+    }
+
+    pool.borrowObject(key)
+  }
+
+  /** Returns borrowed object to the pool. */
+  def returnObject(consumer: InternalKafkaConsumer): Unit = {
+    pool.returnObject(extractCacheKey(consumer), consumer)
+  }
+
+  /** Invalidates (destroy) borrowed object to the pool. */
+  def invalidateObject(consumer: InternalKafkaConsumer): Unit = {
+    pool.invalidateObject(extractCacheKey(consumer), consumer)
+  }
+
+  /** Invalidates all idle consumers for the key */
+  def invalidateKey(key: CacheKey): Unit = {
+    pool.clear(key)
+  }
+
+  /**
+   * Closes the keyed object pool. Once the pool is closed,
+   * borrowObject will fail with [[IllegalStateException]], but returnObject and invalidateObject
+   * will continue to work, with returned objects destroyed on return.
+   *
+   * Also destroys idle instances in the pool.
+   */
+  def close(): Unit = {
+    pool.close()
+  }
+
+  def reset(): Unit = {
+    // this is the best-effort of clearing up. otherwise we should close the pool and create again
+    // but we don't want to make it "var" only because of tests.
+    pool.clear()
   }
 
-  override protected def createKey(consumer: InternalKafkaConsumer): CacheKey = {
+  def numIdle: Int = pool.getNumIdle
+
+  def numIdle(key: CacheKey): Int = pool.getNumIdle(key)
+
+  def numActive: Int = pool.getNumActive
+
+  def numActive(key: CacheKey): Int = pool.getNumActive(key)
+
+  def size: Int = numIdle + numActive
+
+  def size(key: CacheKey): Int = numIdle(key) + numActive(key)
+
+  // TODO: revisit the relation between CacheKey and kafkaParams - for now it looks a bit weird
+  //   as we force all consumers having same (groupId, topicPartition) to have same kafkaParams
+  //   which might be viable in performance perspective (kafkaParams might be too huge to use
+  //   as a part of key), but there might be the case kafkaParams could be different -
+  //   cache key should be differentiated for both kafkaParams.
+  private def updateKafkaParamForKey(key: CacheKey, kafkaParams: ju.Map[String, Object]): Unit = {
+    // We can assume that kafkaParam should not be different for same cache key,
+    // otherwise we can't reuse the cached object and cache key should contain kafkaParam.
+    // So it should be safe to put the key/value pair only when the key doesn't exist.
+    val oldKafkaParams = objectFactory.keyToKafkaParams.putIfAbsent(key, kafkaParams)
+    require(oldKafkaParams == null || kafkaParams == oldKafkaParams, "Kafka parameters for same " +
+      s"cache key should be equal. old parameters: $oldKafkaParams new parameters: $kafkaParams")
+  }
+
+  private def extractCacheKey(consumer: InternalKafkaConsumer): CacheKey = {
     new CacheKey(consumer.topicPartition, consumer.kafkaParams)
   }
 }
 
-private class ConsumerPoolConfig(conf: SparkConf) extends PoolConfig[InternalKafkaConsumer] {
-  def softMaxSize: Int = conf.get(CONSUMER_CACHE_CAPACITY)
-  def jmxEnabled: Boolean = conf.get(CONSUMER_CACHE_JMX_ENABLED)
-  def minEvictableIdleTimeMillis: Long = conf.get(CONSUMER_CACHE_TIMEOUT)
-  def evictorThreadRunIntervalMillis: Long = conf.get(CONSUMER_CACHE_EVICTOR_THREAD_RUN_INTERVAL)
-  def jmxNamePrefix: String = "kafka010-cached-simple-kafka-consumer-pool"
-}
+private[kafka010] object InternalKafkaConsumerPool {
+  object CustomSwallowedExceptionListener extends SwallowedExceptionListener with Logging {
+    override def onSwallowException(e: Exception): Unit = {
+      logError(s"Error closing Kafka consumer", e)
+    }
+  }
+
+  class PoolConfig(conf: SparkConf) extends GenericKeyedObjectPoolConfig[InternalKafkaConsumer] {
+    private var _softMaxSize = Int.MaxValue
+
+    def softMaxSize: Int = _softMaxSize
+
+    init()
+
+    def init(): Unit = {
+      _softMaxSize = conf.get(CONSUMER_CACHE_CAPACITY)
+
+      val jmxEnabled = conf.get(CONSUMER_CACHE_JMX_ENABLED)
+      val minEvictableIdleTimeMillis = conf.get(CONSUMER_CACHE_TIMEOUT)
+      val evictorThreadRunIntervalMillis = conf.get(
+        CONSUMER_CACHE_EVICTOR_THREAD_RUN_INTERVAL)
 
-private class ConsumerObjectFactory extends ObjectFactory[CacheKey, 
InternalKafkaConsumer] {
-  override protected def createValue(
-      key: CacheKey,
-      kafkaParams: ju.Map[String, Object]): InternalKafkaConsumer = {
-    new InternalKafkaConsumer(key.topicPartition, kafkaParams)
+      // NOTE: Below lines define the behavior, so do not modify unless you know what you are
+      // doing, and update the class doc accordingly if necessary when you modify.
+
+      // 1. Set min idle objects per key to 0 to avoid creating unnecessary object.
+      // 2. Set max idle objects per key to 3 but set total objects per key to infinite
+      // which ensures borrowing per key is not restricted.
+      // 3. Set max total objects to infinite which ensures all objects are managed in this pool.
+      setMinIdlePerKey(0)
+      setMaxIdlePerKey(3)
+      setMaxTotalPerKey(-1)
+      setMaxTotal(-1)
+
+      // Set minimum evictable idle time which will be referred from evictor thread
+      setMinEvictableIdleTimeMillis(minEvictableIdleTimeMillis)
+      setSoftMinEvictableIdleTimeMillis(-1)
+
+      // evictor thread will run test with ten idle objects
+      setTimeBetweenEvictionRunsMillis(evictorThreadRunIntervalMillis)
+      setNumTestsPerEvictionRun(10)
+      setEvictionPolicy(new DefaultEvictionPolicy[InternalKafkaConsumer]())
+
+      // Immediately fail on exhausted pool while borrowing
+      setBlockWhenExhausted(false)
+
+      setJmxEnabled(jmxEnabled)
+      setJmxNamePrefix("kafka010-cached-simple-kafka-consumer-pool")
+    }
+  }
+
+  class ObjectFactory extends BaseKeyedPooledObjectFactory[CacheKey, InternalKafkaConsumer] {
+    val keyToKafkaParams = new ConcurrentHashMap[CacheKey, ju.Map[String, Object]]()
+
+    override def create(key: CacheKey): InternalKafkaConsumer = {
+      Option(keyToKafkaParams.get(key)) match {
+        case Some(kafkaParams) => new InternalKafkaConsumer(key.topicPartition, kafkaParams)
+        case None => throw new IllegalStateException("Kafka params should be set before " +
+          "borrowing object.")
+      }
+    }
+
+    override def wrap(value: InternalKafkaConsumer): PooledObject[InternalKafkaConsumer] = {
+      new DefaultPooledObject[InternalKafkaConsumer](value)
+    }
+
+    override def destroyObject(key: CacheKey, p: PooledObject[InternalKafkaConsumer]): Unit = {
+      p.getObject.close()
+    }
   }
 }
+
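
The consumer side keeps the Commons Pool 2 approach shown in the file above. As a hedged aid to reading its scaladoc, here is a minimal sketch of the underlying GenericKeyedObjectPool contract it describes (borrow, then either return or invalidate); the String/StringBuilder key and value types and the factory body are placeholders rather than the Spark classes.

    import org.apache.commons.pool2.{BaseKeyedPooledObjectFactory, PooledObject}
    import org.apache.commons.pool2.impl.{DefaultPooledObject, GenericKeyedObjectPool, GenericKeyedObjectPoolConfig}

    object KeyedPoolSketch {
      // Factory that builds one pooled value per key on demand.
      private class Factory extends BaseKeyedPooledObjectFactory[String, StringBuilder] {
        override def create(key: String): StringBuilder = new StringBuilder(key)
        override def wrap(value: StringBuilder): PooledObject[StringBuilder] =
          new DefaultPooledObject[StringBuilder](value)
      }

      def main(args: Array[String]): Unit = {
        val config = new GenericKeyedObjectPoolConfig[StringBuilder]()
        config.setMaxTotal(-1)     // no hard cap: the pool is only softly bounded
        config.setMaxIdlePerKey(3) // keep at most three idle objects per key
        val pool = new GenericKeyedObjectPool[String, StringBuilder](new Factory, config)

        // Contract mirrored from the scaladoc above: every borrow must be paired
        // with returnObject (healthy object) or invalidateObject (broken object).
        val obj = pool.borrowObject("topic-0")
        try {
          obj.append("-used")
        } finally {
          pool.returnObject("topic-0", obj)
        }
        pool.close()
      }
    }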
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaProducerPool.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaProducerPool.scala
deleted file mode 100644
index 165b643..0000000
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/InternalKafkaProducerPool.scala
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.kafka010
-
-import java.{util => ju}
-
-import scala.collection.JavaConverters._
-
-import org.apache.commons.pool2.PooledObject
-
-import org.apache.spark.SparkConf
-import org.apache.spark.sql.kafka010.InternalKafkaProducerPool.CacheKey
-
-private[kafka010] class InternalKafkaProducerPool(
-    objectFactory: ProducerObjectFactory,
-    poolConfig: ProducerPoolConfig)
-  extends InternalKafkaConnectorPool[CacheKey, CachedKafkaProducer](
-      objectFactory,
-      poolConfig,
-      new CustomSwallowedExceptionListener("producer")) {
-
-  def this(conf: SparkConf) = {
-    this(new ProducerObjectFactory, new ProducerPoolConfig(conf))
-  }
-
-  override protected def createKey(producer: CachedKafkaProducer): CacheKey = {
-    InternalKafkaProducerPool.toCacheKey(producer.kafkaParams)
-  }
-}
-
-private class ProducerPoolConfig(conf: SparkConf) extends PoolConfig[CachedKafkaProducer] {
-  def softMaxSize: Int = conf.get(PRODUCER_CACHE_CAPACITY)
-  def jmxEnabled: Boolean = conf.get(PRODUCER_CACHE_JMX_ENABLED)
-  def minEvictableIdleTimeMillis: Long = conf.get(PRODUCER_CACHE_TIMEOUT)
-  def evictorThreadRunIntervalMillis: Long = conf.get(PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL)
-  def jmxNamePrefix: String = "kafka010-cached-simple-kafka-producer-pool"
-}
-
-private class ProducerObjectFactory extends ObjectFactory[CacheKey, CachedKafkaProducer] {
-  override protected def createValue(
-      key: CacheKey,
-      kafkaParams: ju.Map[String, Object]): CachedKafkaProducer = {
-    new CachedKafkaProducer(kafkaParams)
-  }
-}
-
-private[kafka010] object InternalKafkaProducerPool {
-  type CacheKey = Seq[(String, Object)]
-
-  def toCacheKey(params: ju.Map[String, Object]): CacheKey = {
-    params.asScala.toSeq.sortBy(x => x._1)
-  }
-}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
index f2dad94..ca82c90 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
@@ -613,7 +613,7 @@ private[kafka010] object KafkaDataConsumer extends Logging {
       consumerPool.close()
     } catch {
       case e: Throwable =>
-        logWarning("Ignoring exception while shutting down pools from shutdown 
hook", e)
+        logWarning("Ignoring Exception while shutting down pools from shutdown 
hook", e)
     }
   }
 
@@ -639,11 +639,6 @@ private[kafka010] object KafkaDataConsumer extends Logging {
     new KafkaDataConsumer(topicPartition, kafkaParams, consumerPool, fetchedDataPool)
   }
 
-  private[kafka010] def clear(): Unit = {
-    consumerPool.reset()
-    fetchedDataPool.reset()
-  }
-
   private def reportDataLoss0(
       failOnDataLoss: Boolean,
       finalMessage: String,
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
index 870ed7a..3f8d3d2 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
@@ -44,7 +44,7 @@ private[kafka010] class KafkaDataWriter(
     inputSchema: Seq[Attribute])
   extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] {
 
-  private var producer = CachedKafkaProducer.acquire(producerParams)
+  private lazy val producer = CachedKafkaProducer.getOrCreate(producerParams)
 
   def write(row: InternalRow): Unit = {
     checkForErrors()
@@ -55,36 +55,20 @@ private[kafka010] class KafkaDataWriter(
     // Send is asynchronous, but we can't commit until all rows are actually in Kafka.
     // This requires flushing and then checking that no callbacks produced errors.
     // We also check for errors before to fail as soon as possible - the check is cheap.
-    try {
-      checkForErrors()
-      producer.flush()
-      checkForErrors()
-    } finally {
-      releaseProducer()
-    }
+    checkForErrors()
+    producer.flush()
+    checkForErrors()
     KafkaDataWriterCommitMessage
   }
 
-  def abort(): Unit = {
-    close()
-  }
+  def abort(): Unit = {}
 
   def close(): Unit = {
-    try {
-      checkForErrors()
-      if (producer != null) {
-        producer.flush()
-        checkForErrors()
-      }
-    } finally {
-      releaseProducer()
-    }
-  }
-
-  private def releaseProducer(): Unit = {
+    checkForErrors()
     if (producer != null) {
-      CachedKafkaProducer.release(producer)
-      producer = null
+      producer.flush()
+      checkForErrors()
+      CachedKafkaProducer.close(producerParams)
     }
   }
 }
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
index cfe3e16..8b90706 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
@@ -39,13 +39,13 @@ private[kafka010] class KafkaWriteTask(
     inputSchema: Seq[Attribute],
     topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) {
   // used to synchronize with Kafka callbacks
-  private var producer: CachedKafkaProducer = _
+  private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _
 
   /**
    * Writes key value data out to topics.
    */
   def execute(iterator: Iterator[InternalRow]): Unit = {
-    producer = CachedKafkaProducer.acquire(producerConfiguration)
+    producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
     while (iterator.hasNext && failedWrite == null) {
       val currentRow = iterator.next()
       sendRow(currentRow, producer)
@@ -53,17 +53,11 @@ private[kafka010] class KafkaWriteTask(
   }
 
   def close(): Unit = {
-    try {
+    checkForErrors()
+    if (producer != null) {
+      producer.flush()
       checkForErrors()
-      if (producer != null) {
-        producer.flush()
-        checkForErrors()
-      }
-    } finally {
-      if (producer != null) {
-        CachedKafkaProducer.release(producer)
-        producer = null
-      }
+      producer = null
     }
   }
 }
@@ -89,7 +83,7 @@ private[kafka010] abstract class KafkaRowWriter(
    * assuming the row is in Kafka.
    */
   protected def sendRow(
-      row: InternalRow, producer: CachedKafkaProducer): Unit = {
+      row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = {
     val projectedRow = projection(row)
     val topic = projectedRow.getUTF8String(0)
     val key = projectedRow.getBinary(1)
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala
index f103b5b..6f6ae55 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala
@@ -26,6 +26,12 @@ package object kafka010 {   // scalastyle:ignore
   // ^^ scalastyle:ignore is for ignoring warnings about digits in package name
   type PartitionOffsetMap = Map[TopicPartition, Long]
 
+  private[kafka010] val PRODUCER_CACHE_TIMEOUT =
+    ConfigBuilder("spark.kafka.producer.cache.timeout")
+      .doc("The expire time to remove the unused producers.")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("10m")
+
   private[kafka010] val CONSUMER_CACHE_CAPACITY =
     ConfigBuilder("spark.kafka.consumer.cache.capacity")
       .doc("The maximum number of consumers cached. Please note it's a soft 
limit" +
@@ -68,32 +74,4 @@ package object kafka010 {   // scalastyle:ignore
         "When non-positive, no idle evictor thread will be run.")
       .timeConf(TimeUnit.MILLISECONDS)
       .createWithDefaultString("1m")
-
-  private[kafka010] val PRODUCER_CACHE_CAPACITY =
-    ConfigBuilder("spark.kafka.producer.cache.capacity")
-      .doc("The maximum number of producers cached. Please note it's a soft 
limit" +
-        " (check Structured Streaming Kafka integration guide for further 
details).")
-      .intConf
-      .createWithDefault(64)
-
-  private[kafka010] val PRODUCER_CACHE_JMX_ENABLED =
-    ConfigBuilder("spark.kafka.producer.cache.jmx.enable")
-      .doc("Enable or disable JMX for pools created with this configuration 
instance.")
-      .booleanConf
-      .createWithDefault(false)
-
-  private[kafka010] val PRODUCER_CACHE_TIMEOUT =
-    ConfigBuilder("spark.kafka.producer.cache.timeout")
-      .doc("The minimum amount of time a producer may sit idle in the pool 
before " +
-        "it is eligible for eviction by the evictor. " +
-        "When non-positive, no producers will be evicted from the pool due to 
idle time alone.")
-      .timeConf(TimeUnit.MILLISECONDS)
-      .createWithDefaultString("5m")
-
-  private[kafka010] val PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL =
-    ConfigBuilder("spark.kafka.producer.cache.evictorThreadRunInterval")
-      .doc("The interval of time between runs of the idle evictor thread for 
producer pool. " +
-        "When non-positive, no idle evictor thread will be run.")
-      .timeConf(TimeUnit.MILLISECONDS)
-      .createWithDefaultString("1m")
 }
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala
index 4506a40..7425a74 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala
@@ -17,133 +17,61 @@
 
 package org.apache.spark.sql.kafka010
 
-import java.util.concurrent.{Executors, TimeUnit}
+import java.{util => ju}
+import java.util.concurrent.ConcurrentMap
 
-import scala.collection.JavaConverters._
-import scala.util.Random
-
-import org.apache.kafka.clients.consumer.ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG
-import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata}
-import org.apache.kafka.clients.producer.ProducerConfig.{KEY_SERIALIZER_CLASS_CONFIG, VALUE_SERIALIZER_CLASS_CONFIG}
+import org.apache.kafka.clients.producer.KafkaProducer
 import org.apache.kafka.common.serialization.ByteArraySerializer
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{TaskContext, TaskContextImpl}
-import org.apache.spark.sql.kafka010.InternalKafkaProducerPool._
 import org.apache.spark.sql.test.SharedSparkSession
 
 class CachedKafkaProducerSuite extends SharedSparkSession with PrivateMethodTester with KafkaTest {
 
-  private var testUtils: KafkaTestUtils = _
-  private val topic = "topic" + Random.nextInt()
-  private var producerPool: InternalKafkaProducerPool = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    testUtils = new KafkaTestUtils(Map[String, Object]())
-    testUtils.setup()
-  }
-
-  override def afterAll(): Unit = {
-    if (testUtils != null) {
-      testUtils.teardown()
-      testUtils = null
-    }
-    super.afterAll()
-  }
+  type KP = KafkaProducer[Array[Byte], Array[Byte]]
 
-  override def beforeEach(): Unit = {
+  protected override def beforeEach(): Unit = {
     super.beforeEach()
-
-    producerPool = {
-      val internalKafkaConsumerPoolMethod = 
PrivateMethod[InternalKafkaProducerPool]('producerPool)
-      CachedKafkaProducer.invokePrivate(internalKafkaConsumerPoolMethod())
-    }
-
-    producerPool.reset()
+    CachedKafkaProducer.clear()
   }
 
-  private def getKafkaParams(acks: Int = 0) = Map[String, Object](
-    "acks" -> acks.toString,
+  test("Should return the cached instance on calling getOrCreate with same 
params.") {
+    val kafkaParams = new ju.HashMap[String, Object]()
+    kafkaParams.put("acks", "0")
     // Here only host should be resolvable, it does not need a running instance of kafka server.
-    BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
-    KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName,
-    VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName
-  ).asJava
-
-  test("acquire should return the cached instance with same params") {
-    val kafkaParams = getKafkaParams()
-
-    val producer1 = CachedKafkaProducer.acquire(kafkaParams)
-    CachedKafkaProducer.release(producer1)
-    val producer2 = CachedKafkaProducer.acquire(kafkaParams)
-    CachedKafkaProducer.release(producer2)
-
-    assert(producer1 === producer2)
-    assert(producerPool.size(toCacheKey(kafkaParams)) === 1)
+    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
+    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
+    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
+    val producer = CachedKafkaProducer.getOrCreate(kafkaParams)
+    val producer2 = CachedKafkaProducer.getOrCreate(kafkaParams)
+    assert(producer == producer2)
+
+    val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]](Symbol("getAsMap"))
+    val map = CachedKafkaProducer.invokePrivate(cacheMap())
+    assert(map.size == 1)
   }
 
-  test("acquire should return a new instance with different params") {
-    val kafkaParams1 = getKafkaParams()
-    val kafkaParams2 = getKafkaParams(1)
-
-    val producer1 = CachedKafkaProducer.acquire(kafkaParams1)
-    CachedKafkaProducer.release(producer1)
-    val producer2 = CachedKafkaProducer.acquire(kafkaParams2)
-    CachedKafkaProducer.release(producer2)
-
-    assert(producer1 !== producer2)
-    assert(producerPool.size(toCacheKey(kafkaParams1)) === 1)
-    assert(producerPool.size(toCacheKey(kafkaParams2)) === 1)
-  }
-
-  test("Concurrent use of CachedKafkaProducer") {
-    val data = (1 to 1000).map(_.toString)
-    testUtils.createTopic(topic, 1)
-
-    val kafkaParams = getKafkaParams()
-    val numThreads = 100
-    val numProducerUsages = 500
-
-    @volatile var error: Throwable = null
-
-    val callback = new Callback() {
-      override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
-        if (error == null && e != null) {
-          error = e
-        }
-      }
-    }
-
-    def produce(): Unit = {
-      val taskContext = if (Random.nextBoolean) {
-        new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null)
-      } else {
-        null
-      }
-      TaskContext.setTaskContext(taskContext)
-      val producer = CachedKafkaProducer.acquire(kafkaParams)
-      try {
-        data.foreach { d =>
-          val record = new ProducerRecord[Array[Byte], Array[Byte]](topic, 0, 
null, d.getBytes)
-          producer.send(record, callback)
-        }
-      } finally {
-        CachedKafkaProducer.release(producer)
-      }
-    }
-
-    val threadpool = Executors.newFixedThreadPool(numThreads)
-    try {
-      val futures = (1 to numProducerUsages).map { i =>
-        threadpool.submit(new Runnable {
-          override def run(): Unit = { produce() }
-        })
-      }
-      futures.foreach(_.get(1, TimeUnit.MINUTES))
-      assert(error == null)
-    } finally {
-      threadpool.shutdown()
-    }
+  test("Should close the correct kafka producer for the given kafkaPrams.") {
+    val kafkaParams = new ju.HashMap[String, Object]()
+    kafkaParams.put("acks", "0")
+    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
+    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
+    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
+    val producer: KP = CachedKafkaProducer.getOrCreate(kafkaParams)
+    kafkaParams.put("acks", "1")
+    val producer2: KP = CachedKafkaProducer.getOrCreate(kafkaParams)
+    // With updated conf, a new producer instance should be created.
+    assert(producer != producer2)
+
+    val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]](Symbol("getAsMap"))
+    val map = CachedKafkaProducer.invokePrivate(cacheMap())
+    assert(map.size == 2)
+
+    CachedKafkaProducer.close(kafkaParams)
+    val map2 = CachedKafkaProducer.invokePrivate(cacheMap())
+    assert(map2.size == 1)
+    import scala.collection.JavaConverters._
+    val (seq: Seq[(String, Object)], _producer: KP) = map2.asScala.toArray.apply(0)
+    assert(_producer == producer)
   }
 }
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConnectorPoolSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPoolSuite.scala
similarity index 96%
rename from external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConnectorPoolSuite.scala
rename to external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPoolSuite.scala
index 3143429..78d7fee 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConnectorPoolSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/InternalKafkaConsumerPoolSuite.scala
@@ -29,13 +29,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.sql.kafka010.KafkaDataConsumer.CacheKey
 import org.apache.spark.sql.test.SharedSparkSession
 
-/*
- * There are multiple implementations of [[InternalKafkaConnectorPool]] but they don't differ
- * significantly. Because of that only [[InternalKafkaConsumerPool]] used to test all the
- * functionality. If the behavior of implementations starts to differ it worth to add further
- * tests but for now it would be mainly copy-paste.
- */
-class InternalKafkaConnectorPoolSuite extends SharedSparkSession {
+class InternalKafkaConsumerPoolSuite extends SharedSparkSession {
 
   test("basic multiple borrows and returns for single key") {
     val pool = new InternalKafkaConsumerPool(new SparkConf())
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
index 6e1f10e..d229551 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
@@ -195,7 +195,7 @@ class KafkaDataConsumerSuite
 
     @volatile var error: Throwable = null
 
-    def consume(): Unit = {
+    def consume(i: Int): Unit = {
       val taskContext = if (Random.nextBoolean) {
         new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null)
       } else {
@@ -233,9 +233,9 @@ class KafkaDataConsumerSuite
 
     val threadpool = Executors.newFixedThreadPool(numThreads)
     try {
-      val futures = (1 to numConsumerUsages).map { _ =>
+      val futures = (1 to numConsumerUsages).map { i =>
         threadpool.submit(new Runnable {
-          override def run(): Unit = { consume() }
+          override def run(): Unit = { consume(i) }
         })
       }
       futures.foreach(_.get(1, TimeUnit.MINUTES))
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala
index 2900322..19acda9 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala
@@ -21,16 +21,12 @@ import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.SparkFunSuite
 
-/** A trait to clean cached Kafka connector in `afterAll` */
+/** A trait to clean cached Kafka producers in `afterAll` */
 trait KafkaTest extends BeforeAndAfterAll {
   self: SparkFunSuite =>
 
   override def afterAll(): Unit = {
-    try {
-      KafkaDataConsumer.clear()
-      CachedKafkaProducer.clear()
-    } finally {
-      super.afterAll()
-    }
+    super.afterAll()
+    CachedKafkaProducer.clear()
   }
 }
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
index 246672b..82913cf 100644
--- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
@@ -121,6 +121,8 @@ class KafkaDataConsumerSuite extends SparkFunSuite with MockitoSugar with Before
     val numThreads = 100
     val numConsumerUsages = 500
 
+    @volatile var error: Throwable = null
+
     def consume(i: Int): Unit = {
       val useCache = Random.nextBoolean
       val taskContext = if (Random.nextBoolean) {
@@ -136,6 +138,10 @@ class KafkaDataConsumerSuite extends SparkFunSuite with MockitoSugar with Before
           new String(bytes)
         }
         assert(rcvd == data)
+      } catch {
+        case e: Throwable =>
+          error = e
+          throw e
       } finally {
         consumer.release()
       }
@@ -149,6 +155,7 @@ class KafkaDataConsumerSuite extends SparkFunSuite with MockitoSugar with Before
         })
       }
       futures.foreach(_.get(1, TimeUnit.MINUTES))
+      assert(error == null)
     } finally {
       threadPool.shutdown()
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
