Repository: spark
Updated Branches:
  refs/heads/branch-2.3 21b6de459 -> 6937571ab


[SPARK-23623][SS] Avoid concurrent use of cached consumers in CachedKafkaConsumer (branch-2.3)

This is a backport of #20767 to branch-2.3.

## What changes were proposed in this pull request?
CachedKafkaConsumer in the project `kafka-0-10-sql` is designed to maintain a 
pool of KafkaConsumers that can be reused. However, it was built with the 
assumption that only one task will try to read the same Kafka TopicPartition at 
any given time. Hence, the cache was keyed by the TopicPartition a consumer is 
supposed to read. For any case where this assumption may not hold, a SparkPlan 
flag disables the use of the cache. So it was up to the planner to correctly 
identify when it was not safe to use the cache and to set the flag accordingly.

Fundamentally, this is the wrong way to approach the problem. It is HARD for a 
high-level planner to reason about the low-level execution model, e.g., whether 
multiple tasks in the same query will try to read the same partition. Case in 
point: 2.3.0 introduced stream-stream joins, so you can build a streaming 
self-join query on Kafka. It is non-trivial to see how such a query leads to two 
tasks reading the same partition, possibly concurrently. Because of that, it is 
hard for the planner to detect the situation and set the flag to avoid the 
cache / consumer pool. This can inadvertently lead to 
ConcurrentModificationException or, worse, silent reading of incorrect data.

Here is a better way to design this. The planner shouldn't have to understand 
these low-level optimizations. Rather, the consumer pool should be smart enough 
to avoid concurrent use of a cached consumer. Currently, it tries to do so but 
incorrectly (the `inuse` flag is not checked when returning a cached consumer, 
see 
[this](https://github.com/apache/spark/blob/master/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala#L403)).
If there is another request for the same partition as a currently in-use 
consumer, the pool should automatically return a fresh consumer that is closed 
when the task is done. Then the planner does not need a flag to avoid reuse.
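The in-use check described above can be sketched in Scala as follows. This is a minimal, hypothetical model rather than the code in this patch: `PooledConsumer` and `ConsumerPool` are made-up names, a `String` key stands in for the (groupId, TopicPartition) cache key, and closing is simulated with a flag instead of closing a real KafkaConsumer.

```scala
import scala.collection.mutable

// Hypothetical stand-in for a pooled Kafka consumer; it only models pooling state.
final class PooledConsumer(val key: String) {
  @volatile var inUse: Boolean = true
  @volatile var closed: Boolean = false
  def close(): Unit = { closed = true }
}

object ConsumerPool {
  private val cache = mutable.HashMap.empty[String, PooledConsumer]

  /** Returns the consumer plus a flag telling the caller whether it came from the pool. */
  def getOrCreate(key: String): (PooledConsumer, Boolean) = synchronized {
    cache.get(key) match {
      case Some(cached) if !cached.inUse =>
        cached.inUse = true                // idle cached consumer: safe to reuse
        (cached, true)
      case Some(_) =>
        (new PooledConsumer(key), false)   // busy: fresh consumer, closed on release
      case None =>
        val created = new PooledConsumer(key)
        cache.put(key, created)
        (created, true)
    }
  }

  def release(consumer: PooledConsumer, fromPool: Boolean): Unit = synchronized {
    // Pooled consumers go back to the pool; fresh ones are simply closed.
    if (fromPool) consumer.inUse = false else consumer.close()
  }
}
```

The key difference from the old `getOrCreate` is the `if !cached.inUse` guard: a busy cached consumer is never handed out a second time.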

This PR is a step towards that goal. It does the following.
- There are effectively two kinds of consumers that may be generated:
  - Cached consumer - this should be returned to the pool at task end
  - Non-cached consumer - this should be closed at task end
- A trait called KafkaDataConsumer is introduced to hide this difference from 
the users of the consumer, so that the client code does not have to reason 
about whether to close the consumer or return it to the pool. They simply call 
`val consumer = KafkaDataConsumer.acquire` and then `consumer.release()`.
- If there is a request for a consumer that is in use, then a new consumer is 
generated.
- If there is a concurrent attempt of the same task, then a new consumer is 
generated, and the existing cached consumer is marked for close upon release.
- In addition, I renamed the classes because CachedKafkaConsumer is a misnomer 
given that what it returns may or may not be cached.
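A rough sketch of the acquire/release split, loosely modeled on the list above. All names (`DataConsumer`, `Internal`, `Cached`, `NonCached`) are hypothetical, and `attemptNumber` is passed explicitly as an assumption in place of reading `TaskContext`; the real `KafkaDataConsumer` in this patch wraps an actual KafkaConsumer and may differ in details.

```scala
import scala.collection.mutable

// Hypothetical stand-in for InternalKafkaConsumer: only pooling state is modeled.
final class Internal(val key: String) {
  @volatile var inUse: Boolean = true
  @volatile var markedForClose: Boolean = false
  @volatile var closed: Boolean = false
  def close(): Unit = { closed = true }
}

// Callers only see acquire/release; whether the consumer is pooled stays hidden.
sealed trait DataConsumer {
  protected def internal: Internal
  def release(): Unit
  def underlying: Internal = internal // exposed only so the sketch can be inspected
}

object DataConsumer {
  private val cache = mutable.HashMap.empty[String, Internal]

  private final case class Cached(internal: Internal) extends DataConsumer {
    def release(): Unit = releaseToPool(internal)
  }
  private final case class NonCached(internal: Internal) extends DataConsumer {
    def release(): Unit = internal.close()
  }

  // `attemptNumber` stands in for TaskContext.get.attemptNumber (an assumption).
  def acquire(key: String, useCache: Boolean, attemptNumber: Int): DataConsumer = synchronized {
    if (!useCache) {
      NonCached(new Internal(key))
    } else if (attemptNumber >= 1) {
      // Reattempt: invalidate the cached consumer. Close it now if idle; otherwise
      // mark it so that it is closed when its current holder releases it.
      cache.get(key).foreach { c => if (c.inUse) c.markedForClose = true else c.close() }
      cache.remove(key)
      NonCached(new Internal(key))
    } else {
      cache.get(key) match {
        case Some(c) if !c.inUse => c.inUse = true; Cached(c)
        case Some(_) => NonCached(new Internal(key)) // in use: never share it
        case None =>
          val c = new Internal(key)
          cache.put(key, c)
          Cached(c)
      }
    }
  }

  private def releaseToPool(c: Internal): Unit = synchronized {
    cache.get(c.key) match {
      case Some(pooled) if (pooled eq c) && !c.markedForClose => c.inUse = false
      case _ => c.close() // marked for close, or no longer the pooled instance
    }
  }
}
```

With this shape the caller never branches on the consumer kind: it always calls `release()`, which either returns a pooled consumer to the cache or closes a non-pooled one.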

To keep this patch safe enough for merging into branch-2.3, this PR does not 
remove the planner flag that avoids reuse. That can be done later in master only.

## How was this patch tested?
A new stress test that verifies it is safe to concurrently get consumers for 
the same partition from the consumer pool.
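This is not the test added in this patch; as an illustration of the idea only, the following self-contained sketch hammers a single-key pool from several threads and checks that no consumer instance is ever used by two threads at once. `Probe`, `ProbePool`, and `StressTest` are hypothetical names, and the thread and iteration counts are arbitrary.

```scala
import java.util.concurrent.{Executors, TimeUnit}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}

// Hypothetical probe consumer that detects overlapping use by two threads.
final class Probe {
  var inUse: Boolean = true                  // guarded by the ProbePool lock
  private val holders = new AtomicInteger(0) // threads currently inside use()
  def use(overlap: AtomicBoolean): Unit = {
    if (holders.incrementAndGet() > 1) overlap.set(true)
    Thread.sleep(1)                          // widen the window for races
    holders.decrementAndGet()
  }
}

// Single-key pool with the in-use check (one key keeps the sketch short).
object ProbePool {
  private var cached: Option[Probe] = None
  def acquire(): (Probe, Boolean) = synchronized {
    cached match {
      case Some(p) if !p.inUse => p.inUse = true; (p, true)
      case Some(_) => (new Probe, false)     // busy: fresh, uncached instance
      case None =>
        val p = new Probe
        cached = Some(p)
        (p, true)
    }
  }
  def release(p: Probe, fromPool: Boolean): Unit = synchronized {
    if (fromPool) p.inUse = false            // an uncached instance would be closed here
  }
}

object StressTest {
  /** Returns true if no Probe was ever used by two threads at the same time. */
  def run(threads: Int = 8, iterations: Int = 200): Boolean = {
    val overlap = new AtomicBoolean(false)
    val executor = Executors.newFixedThreadPool(threads)
    for (_ <- 1 to threads) {
      executor.execute(new Runnable {
        override def run(): Unit = for (_ <- 1 to iterations) {
          val (p, fromPool) = ProbePool.acquire()
          try p.use(overlap) finally ProbePool.release(p, fromPool)
        }
      })
    }
    executor.shutdown()
    executor.awaitTermination(1, TimeUnit.MINUTES)
    !overlap.get()
  }
}
```

Dropping the `!p.inUse` guard in `ProbePool.acquire` mimics the old behavior, and `run()` would then typically return `false`, although such races are inherently timing-dependent.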

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #20848 from tdas/SPARK-23623-2.3.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6937571a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6937571a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6937571a

Branch: refs/heads/branch-2.3
Commit: 6937571ab8818a62ec2457a373eb3f6f618985e1
Parents: 21b6de4
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Sat Mar 17 16:24:51 2018 -0700
Committer: Shixiong Zhu <zsxw...@gmail.com>
Committed: Sat Mar 17 16:24:51 2018 -0700

----------------------------------------------------------------------
 .../sql/kafka010/CachedKafkaConsumer.scala      | 438 ----------------
 .../sql/kafka010/KafkaContinuousReader.scala    |   4 +-
 .../spark/sql/kafka010/KafkaDataConsumer.scala  | 516 +++++++++++++++++++
 .../spark/sql/kafka010/KafkaSourceRDD.scala     |  23 +-
 .../sql/kafka010/CachedKafkaConsumerSuite.scala |  34 --
 .../sql/kafka010/KafkaDataConsumerSuite.scala   | 124 +++++
 6 files changed, 648 insertions(+), 491 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6937571a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
deleted file mode 100644
index 90ed7b1..0000000
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
+++ /dev/null
@@ -1,438 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.kafka010
-
-import java.{util => ju}
-import java.util.concurrent.TimeoutException
-
-import scala.collection.JavaConverters._
-
-import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer, OffsetOutOfRangeException}
-import org.apache.kafka.common.TopicPartition
-
-import org.apache.spark.{SparkEnv, SparkException, TaskContext}
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.kafka010.KafkaSource._
-import org.apache.spark.util.UninterruptibleThread
-
-
-/**
- * Consumer of single topicpartition, intended for cached reuse.
- * Underlying consumer is not threadsafe, so neither is this,
- * but processing the same topicpartition and group id in multiple threads is usually bad anyway.
- */
-private[kafka010] case class CachedKafkaConsumer private(
-    topicPartition: TopicPartition,
-    kafkaParams: ju.Map[String, Object]) extends Logging {
-  import CachedKafkaConsumer._
-
-  private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
-
-  private var consumer = createConsumer
-
-  /** indicates whether this consumer is in use or not */
-  private var inuse = true
-
-  /** Iterator to the already fetch data */
-  private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]]
-  private var nextOffsetInFetchedData = UNKNOWN_OFFSET
-
-  /** Create a KafkaConsumer to fetch records for `topicPartition` */
-  private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = {
-    val c = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
-    val tps = new ju.ArrayList[TopicPartition]()
-    tps.add(topicPartition)
-    c.assign(tps)
-    c
-  }
-
-  case class AvailableOffsetRange(earliest: Long, latest: Long)
-
-  private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match {
-    case ut: UninterruptibleThread =>
-      ut.runUninterruptibly(body)
-    case _ =>
-      logWarning("CachedKafkaConsumer is not running in UninterruptibleThread. " +
-        "It may hang when CachedKafkaConsumer's methods are interrupted because of KAFKA-1894")
-      body
-  }
-
-  /**
-   * Return the available offset range of the current partition. It's a pair of the earliest offset
-   * and the latest offset.
-   */
-  def getAvailableOffsetRange(): AvailableOffsetRange = runUninterruptiblyIfPossible {
-    consumer.seekToBeginning(Set(topicPartition).asJava)
-    val earliestOffset = consumer.position(topicPartition)
-    consumer.seekToEnd(Set(topicPartition).asJava)
-    val latestOffset = consumer.position(topicPartition)
-    AvailableOffsetRange(earliestOffset, latestOffset)
-  }
-
-  /**
-   * Get the record for the given offset if available. Otherwise it will either throw error
-   * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset),
-   * or null.
-   *
-   * @param offset the offset to fetch.
-   * @param untilOffset the max offset to fetch. Exclusive.
-   * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka.
-   * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at
-   *                       offset if available, or throw exception.when `failOnDataLoss` is `false`,
-   *                       this method will either return record at offset if available, or return
-   *                       the next earliest available record less than untilOffset, or null. It
-   *                       will not throw any exception.
-   */
-  def get(
-      offset: Long,
-      untilOffset: Long,
-      pollTimeoutMs: Long,
-      failOnDataLoss: Boolean):
-    ConsumerRecord[Array[Byte], Array[Byte]] = runUninterruptiblyIfPossible {
-    require(offset < untilOffset,
-      s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]")
-    logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset")
-    // The following loop is basically for `failOnDataLoss = false`. When `failOnDataLoss` is
-    // `false`, first, we will try to fetch the record at `offset`. If no such record exists, then
-    // we will move to the next available offset within `[offset, untilOffset)` and retry.
-    // If `failOnDataLoss` is `true`, the loop body will be executed only once.
-    var toFetchOffset = offset
-    var consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]] = null
-    // We want to break out of the while loop on a successful fetch to avoid using "return"
-    // which may causes a NonLocalReturnControl exception when this method is used as a function.
-    var isFetchComplete = false
-
-    while (toFetchOffset != UNKNOWN_OFFSET && !isFetchComplete) {
-      try {
-        consumerRecord = fetchData(toFetchOffset, untilOffset, pollTimeoutMs, failOnDataLoss)
-        isFetchComplete = true
-      } catch {
-        case e: OffsetOutOfRangeException =>
-          // When there is some error thrown, it's better to use a new consumer to drop all cached
-          // states in the old consumer. We don't need to worry about the performance because this
-          // is not a common path.
-          resetConsumer()
-          reportDataLoss(failOnDataLoss, s"Cannot fetch offset $toFetchOffset", e)
-          toFetchOffset = getEarliestAvailableOffsetBetween(toFetchOffset, untilOffset)
-      }
-    }
-
-    if (isFetchComplete) {
-      consumerRecord
-    } else {
-      resetFetchedData()
-      null
-    }
-  }
-
-  /**
-   * Return the next earliest available offset in [offset, untilOffset). If all offsets in
-   * [offset, untilOffset) are invalid (e.g., the topic is deleted and recreated), it will return
-   * `UNKNOWN_OFFSET`.
-   */
-  private def getEarliestAvailableOffsetBetween(offset: Long, untilOffset: Long): Long = {
-    val range = getAvailableOffsetRange()
-    logWarning(s"Some data may be lost. Recovering from the earliest offset: ${range.earliest}")
-    if (offset >= range.latest || range.earliest >= untilOffset) {
-      // [offset, untilOffset) and [earliestOffset, latestOffset) have no overlap,
-      // either
-      // --------------------------------------------------------
-      //         ^                 ^         ^         ^
-      //         |                 |         |         |
-      //   earliestOffset   latestOffset   offset   untilOffset
-      //
-      // or
-      // --------------------------------------------------------
-      //      ^          ^              ^                ^
-      //      |          |              |                |
-      //   offset   untilOffset   earliestOffset   latestOffset
-      val warningMessage =
-        s"""
-          |The current available offset range is $range.
-          | Offset ${offset} is out of range, and records in [$offset, $untilOffset) will be
-          | skipped ${additionalMessage(failOnDataLoss = false)}
-        """.stripMargin
-      logWarning(warningMessage)
-      UNKNOWN_OFFSET
-    } else if (offset >= range.earliest) {
-      // -----------------------------------------------------------------------------
-      //         ^            ^                  ^                                 ^
-      //         |            |                  |                                 |
-      //   earliestOffset   offset   min(untilOffset,latestOffset)   max(untilOffset, latestOffset)
-      //
-      // This will happen when a topic is deleted and recreated, and new data are pushed very fast,
-      // then we will see `offset` disappears first then appears again. Although the parameters
-      // are same, the state in Kafka cluster is changed, so the outer loop won't be endless.
-      logWarning(s"Found a disappeared offset $offset. " +
-        s"Some data may be lost ${additionalMessage(failOnDataLoss = false)}")
-      offset
-    } else {
-      // ------------------------------------------------------------------------------
-      //      ^           ^                       ^                                 ^
-      //      |           |                       |                                 |
-      //   offset   earliestOffset   min(untilOffset,latestOffset)   max(untilOffset, latestOffset)
-      val warningMessage =
-        s"""
-           |The current available offset range is $range.
-           | Offset ${offset} is out of range, and records in [$offset, ${range.earliest}) will be
-           | skipped ${additionalMessage(failOnDataLoss = false)}
-        """.stripMargin
-      logWarning(warningMessage)
-      range.earliest
-    }
-  }
-
-  /**
-   * Get the record for the given offset if available. Otherwise it will either throw error
-   * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset),
-   * or null.
-   *
-   * @throws OffsetOutOfRangeException if `offset` is out of range
-   * @throws TimeoutException if cannot fetch the record in `pollTimeoutMs` milliseconds.
-   */
-  private def fetchData(
-      offset: Long,
-      untilOffset: Long,
-      pollTimeoutMs: Long,
-      failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = {
-    if (offset != nextOffsetInFetchedData || !fetchedData.hasNext()) {
-      // This is the first fetch, or the last pre-fetched data has been drained.
-      // Seek to the offset because we may call seekToBeginning or seekToEnd before this.
-      seek(offset)
-      poll(pollTimeoutMs)
-    }
-
-    if (!fetchedData.hasNext()) {
-      // We cannot fetch anything after `poll`. Two possible cases:
-      // - `offset` is out of range so that Kafka returns nothing. Just throw
-      // `OffsetOutOfRangeException` to let the caller handle it.
-      // - Cannot fetch any data before timeout. TimeoutException will be thrown.
-      val range = getAvailableOffsetRange()
-      if (offset < range.earliest || offset >= range.latest) {
-        throw new OffsetOutOfRangeException(
-          Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava)
-      } else {
-        throw new TimeoutException(
-          s"Cannot fetch record for offset $offset in $pollTimeoutMs milliseconds")
-      }
-    } else {
-      val record = fetchedData.next()
-      nextOffsetInFetchedData = record.offset + 1
-      // In general, Kafka uses the specified offset as the start point, and tries to fetch the next
-      // available offset. Hence we need to handle offset mismatch.
-      if (record.offset > offset) {
-        // This may happen when some records aged out but their offsets already got verified
-        if (failOnDataLoss) {
-          reportDataLoss(true, s"Cannot fetch records in [$offset, ${record.offset})")
-          // Never happen as "reportDataLoss" will throw an exception
-          null
-        } else {
-          if (record.offset >= untilOffset) {
-            reportDataLoss(false, s"Skip missing records in [$offset, $untilOffset)")
-            null
-          } else {
-            reportDataLoss(false, s"Skip missing records in [$offset, ${record.offset})")
-            record
-          }
-        }
-      } else if (record.offset < offset) {
-        // This should not happen. If it does happen, then we probably misunderstand Kafka internal
-        // mechanism.
-        throw new IllegalStateException(
-          s"Tried to fetch $offset but the returned record offset was ${record.offset}")
-      } else {
-        record
-      }
-    }
-  }
-
-  /** Create a new consumer and reset cached states */
-  private def resetConsumer(): Unit = {
-    consumer.close()
-    consumer = createConsumer
-    resetFetchedData()
-  }
-
-  /** Reset the internal pre-fetched data. */
-  private def resetFetchedData(): Unit = {
-    nextOffsetInFetchedData = UNKNOWN_OFFSET
-    fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]]
-  }
-
-  /**
-   * Return an addition message including useful message and instruction.
-   */
-  private def additionalMessage(failOnDataLoss: Boolean): String = {
-    if (failOnDataLoss) {
-      s"(GroupId: $groupId, TopicPartition: $topicPartition). " +
-        s"$INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE"
-    } else {
-      s"(GroupId: $groupId, TopicPartition: $topicPartition). " +
-        s"$INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE"
-    }
-  }
-
-  /**
-   * Throw an exception or log a warning as per `failOnDataLoss`.
-   */
-  private def reportDataLoss(
-      failOnDataLoss: Boolean,
-      message: String,
-      cause: Throwable = null): Unit = {
-    val finalMessage = s"$message ${additionalMessage(failOnDataLoss)}"
-    reportDataLoss0(failOnDataLoss, finalMessage, cause)
-  }
-
-  def close(): Unit = consumer.close()
-
-  private def seek(offset: Long): Unit = {
-    logDebug(s"Seeking to $groupId $topicPartition $offset")
-    consumer.seek(topicPartition, offset)
-  }
-
-  private def poll(pollTimeoutMs: Long): Unit = {
-    val p = consumer.poll(pollTimeoutMs)
-    val r = p.records(topicPartition)
-    logDebug(s"Polled $groupId ${p.partitions()}  ${r.size}")
-    fetchedData = r.iterator
-  }
-}
-
-private[kafka010] object CachedKafkaConsumer extends Logging {
-
-  private val UNKNOWN_OFFSET = -2L
-
-  private case class CacheKey(groupId: String, topicPartition: TopicPartition)
-
-  private lazy val cache = {
-    val conf = SparkEnv.get.conf
-    val capacity = conf.getInt("spark.sql.kafkaConsumerCache.capacity", 64)
-    new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer](capacity, 0.75f, true) {
-      override def removeEldestEntry(
-        entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer]): Boolean = {
-        if (entry.getValue.inuse == false && this.size > capacity) {
-          logWarning(s"KafkaConsumer cache hitting max capacity of $capacity, " +
-            s"removing consumer for ${entry.getKey}")
-          try {
-            entry.getValue.close()
-          } catch {
-            case e: SparkException =>
-              logError(s"Error closing earliest Kafka consumer for ${entry.getKey}", e)
-          }
-          true
-        } else {
-          false
-        }
-      }
-    }
-  }
-
-  def releaseKafkaConsumer(
-      topic: String,
-      partition: Int,
-      kafkaParams: ju.Map[String, Object]): Unit = {
-    val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
-    val topicPartition = new TopicPartition(topic, partition)
-    val key = CacheKey(groupId, topicPartition)
-
-    synchronized {
-      val consumer = cache.get(key)
-      if (consumer != null) {
-        consumer.inuse = false
-      } else {
-        logWarning(s"Attempting to release consumer that does not exist")
-      }
-    }
-  }
-
-  /**
-   * Removes (and closes) the Kafka Consumer for the given topic, partition and group id.
-   */
-  def removeKafkaConsumer(
-      topic: String,
-      partition: Int,
-      kafkaParams: ju.Map[String, Object]): Unit = {
-    val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
-    val topicPartition = new TopicPartition(topic, partition)
-    val key = CacheKey(groupId, topicPartition)
-
-    synchronized {
-      val removedConsumer = cache.remove(key)
-      if (removedConsumer != null) {
-        removedConsumer.close()
-      }
-    }
-  }
-
-  /**
-   * Get a cached consumer for groupId, assigned to topic and partition.
-   * If matching consumer doesn't already exist, will be created using kafkaParams.
-   */
-  def getOrCreate(
-      topic: String,
-      partition: Int,
-      kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = synchronized {
-    val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
-    val topicPartition = new TopicPartition(topic, partition)
-    val key = CacheKey(groupId, topicPartition)
-
-    // If this is reattempt at running the task, then invalidate cache and start with
-    // a new consumer
-    if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) {
-      removeKafkaConsumer(topic, partition, kafkaParams)
-      val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams)
-      consumer.inuse = true
-      cache.put(key, consumer)
-      consumer
-    } else {
-      if (!cache.containsKey(key)) {
-        cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams))
-      }
-      val consumer = cache.get(key)
-      consumer.inuse = true
-      consumer
-    }
-  }
-
-  /** Create an [[CachedKafkaConsumer]] but don't put it into cache. */
-  def createUncached(
-      topic: String,
-      partition: Int,
-      kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = {
-    new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams)
-  }
-
-  private def reportDataLoss0(
-      failOnDataLoss: Boolean,
-      finalMessage: String,
-      cause: Throwable = null): Unit = {
-    if (failOnDataLoss) {
-      if (cause != null) {
-        throw new IllegalStateException(finalMessage, cause)
-      } else {
-        throw new IllegalStateException(finalMessage)
-      }
-    } else {
-      if (cause != null) {
-        logWarning(finalMessage, cause)
-      } else {
-        logWarning(finalMessage)
-      }
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/6937571a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
index a269a50..a2a4c83 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
@@ -189,7 +189,7 @@ class KafkaContinuousDataReader(
     failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] {
   private val topic = topicPartition.topic
   private val kafkaPartition = topicPartition.partition
-  private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams)
+  private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false)
 
   private val sharedRow = new UnsafeRow(7)
   private val bufferHolder = new BufferHolder(sharedRow)
@@ -255,6 +255,6 @@ class KafkaContinuousDataReader(
   }
 
   override def close(): Unit = {
-    consumer.close()
+    consumer.release()
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6937571a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
new file mode 100644
index 0000000..dcf2f63
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
@@ -0,0 +1,516 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+import java.util.concurrent.TimeoutException
+
+import scala.collection.JavaConverters._
+
+import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer, OffsetOutOfRangeException}
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.kafka010.KafkaDataConsumer.AvailableOffsetRange
+import org.apache.spark.sql.kafka010.KafkaSource._
+import org.apache.spark.util.UninterruptibleThread
+
+private[kafka010] sealed trait KafkaDataConsumer {
+  /**
+   * Get the record for the given offset if available. Otherwise it will either throw error
+   * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset),
+   * or null.
+   *
+   * @param offset         the offset to fetch.
+   * @param untilOffset    the max offset to fetch. Exclusive.
+   * @param pollTimeoutMs  timeout in milliseconds to poll data from Kafka.
+   * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at
+   *                       offset if available, or throw exception.when `failOnDataLoss` is `false`,
+   *                       this method will either return record at offset if available, or return
+   *                       the next earliest available record less than untilOffset, or null. It
+   *                       will not throw any exception.
+   */
+  def get(
+      offset: Long,
+      untilOffset: Long,
+      pollTimeoutMs: Long,
+      failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = {
+    internalConsumer.get(offset, untilOffset, pollTimeoutMs, failOnDataLoss)
+  }
+
+  /**
+   * Return the available offset range of the current partition. It's a pair of the earliest offset
+   * and the latest offset.
+   */
+  def getAvailableOffsetRange(): AvailableOffsetRange = internalConsumer.getAvailableOffsetRange()
+
+  /**
+   * Release this consumer from being further used. Depending on its implementation,
+   * this consumer will be either finalized, or reset for reuse later.
+   */
+  def release(): Unit
+
+  /** Reference to the internal implementation that this wrapper delegates to */
+  protected def internalConsumer: InternalKafkaConsumer
+}
+
+
+/**
+ * A wrapper around Kafka's KafkaConsumer that throws error when data loss is detected.
+ * This is not for direct use outside this file.
+ */
+private[kafka010] case class InternalKafkaConsumer(
+    topicPartition: TopicPartition,
+    kafkaParams: ju.Map[String, Object]) extends Logging {
+  import InternalKafkaConsumer._
+
+  private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
+
+  @volatile private var consumer = createConsumer
+
+  /** indicates whether this consumer is in use or not */
+  @volatile var inUse = true
+
+  /** indicate whether this consumer is going to be stopped in the next release */
+  @volatile var markedForClose = false
+
+  /** Iterator to the already fetch data */
+  @volatile private var fetchedData =
+    ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]]
+  @volatile private var nextOffsetInFetchedData = UNKNOWN_OFFSET
+
+  /** Create a KafkaConsumer to fetch records for `topicPartition` */
+  private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = {
+    val c = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
+    val tps = new ju.ArrayList[TopicPartition]()
+    tps.add(topicPartition)
+    c.assign(tps)
+    c
+  }
+
+  private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match {
+    case ut: UninterruptibleThread =>
+      ut.runUninterruptibly(body)
+    case _ =>
+      logWarning("CachedKafkaConsumer is not running in UninterruptibleThread. " +
+        "It may hang when CachedKafkaConsumer's methods are interrupted because of KAFKA-1894")
+      body
+  }
+
+  /**
+   * Return the available offset range of the current partition. It's a pair of the earliest offset
+   * and the latest offset.
+   */
+  def getAvailableOffsetRange(): AvailableOffsetRange = runUninterruptiblyIfPossible {
+    consumer.seekToBeginning(Set(topicPartition).asJava)
+    val earliestOffset = consumer.position(topicPartition)
+    consumer.seekToEnd(Set(topicPartition).asJava)
+    val latestOffset = consumer.position(topicPartition)
+    AvailableOffsetRange(earliestOffset, latestOffset)
+  }
+
+  /**
+   * Get the record for the given offset if available. Otherwise it will either throw error
+   * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset),
+   * or null.
+   *
+   * @param offset the offset to fetch.
+   * @param untilOffset the max offset to fetch. Exclusive.
+   * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka.
+   * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at
+   *                       offset if available, or throw exception.when `failOnDataLoss` is `false`,
+   *                       this method will either return record at offset if available, or return
+   *                       the next earliest available record less than untilOffset, or null. It
+   *                       will not throw any exception.
+   */
+  def get(
+      offset: Long,
+      untilOffset: Long,
+      pollTimeoutMs: Long,
+      failOnDataLoss: Boolean):
+    ConsumerRecord[Array[Byte], Array[Byte]] = runUninterruptiblyIfPossible {
+    require(offset < untilOffset,
+      s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]")
+    logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset")
+    // The following loop is basically for `failOnDataLoss = false`. When `failOnDataLoss` is
+    // `false`, first, we will try to fetch the record at `offset`. If no such record exists, then
+    // we will move to the next available offset within `[offset, untilOffset)` and retry.
+    // If `failOnDataLoss` is `true`, the loop body will be executed only once.
+    var toFetchOffset = offset
+    var consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]] = null
+    // We want to break out of the while loop on a successful fetch to avoid using "return"
+    // which may cause a NonLocalReturnControl exception when this method is used as a function.
+    var isFetchComplete = false
+
+    while (toFetchOffset != UNKNOWN_OFFSET && !isFetchComplete) {
+      try {
+        consumerRecord = fetchData(toFetchOffset, untilOffset, pollTimeoutMs, failOnDataLoss)
+        isFetchComplete = true
+      } catch {
+        case e: OffsetOutOfRangeException =>
+          // When there is some error thrown, it's better to use a new consumer to drop all cached
+          // states in the old consumer. We don't need to worry about the performance because this
+          // is not a common path.
+          resetConsumer()
+          reportDataLoss(failOnDataLoss, s"Cannot fetch offset $toFetchOffset", e)
+          toFetchOffset = getEarliestAvailableOffsetBetween(toFetchOffset, untilOffset)
+      }
+    }
+
+    if (isFetchComplete) {
+      consumerRecord
+    } else {
+      resetFetchedData()
+      null
+    }
+  }
+
+  /**
+   * Return the next earliest available offset in [offset, untilOffset). If all offsets in
+   * [offset, untilOffset) are invalid (e.g., the topic is deleted and recreated), it will return
+   * `UNKNOWN_OFFSET`.
+   */
+  private def getEarliestAvailableOffsetBetween(offset: Long, untilOffset: Long): Long = {
+    val range = getAvailableOffsetRange()
+    logWarning(s"Some data may be lost. Recovering from the earliest offset: ${range.earliest}")
+    if (offset >= range.latest || range.earliest >= untilOffset) {
+      // [offset, untilOffset) and [earliestOffset, latestOffset) have no overlap,
+      // either
+      // --------------------------------------------------------
+      //         ^                 ^         ^         ^
+      //         |                 |         |         |
+      //   earliestOffset   latestOffset   offset   untilOffset
+      //
+      // or
+      // --------------------------------------------------------
+      //      ^          ^              ^                ^
+      //      |          |              |                |
+      //   offset   untilOffset   earliestOffset   latestOffset
+      val warningMessage =
+        s"""
+          |The current available offset range is $range.
+          | Offset ${offset} is out of range, and records in [$offset, $untilOffset) will be
+          | skipped ${additionalMessage(failOnDataLoss = false)}
+        """.stripMargin
+      logWarning(warningMessage)
+      UNKNOWN_OFFSET
+    } else if (offset >= range.earliest) {
+      // -----------------------------------------------------------------------------
+      //         ^            ^                  ^                                 ^
+      //         |            |                  |                                 |
+      //   earliestOffset   offset   min(untilOffset,latestOffset)   max(untilOffset, latestOffset)
+      //
+      // This will happen when a topic is deleted and recreated, and new data are pushed very fast:
+      // we will see `offset` disappear first and then appear again. Although the parameters
+      // are the same, the state in the Kafka cluster has changed, so the outer loop won't be endless.
+      logWarning(s"Found a disappeared offset $offset. " +
+        s"Some data may be lost ${additionalMessage(failOnDataLoss = false)}")
+      offset
+    } else {
+      // ------------------------------------------------------------------------------
+      //      ^           ^                       ^                                 ^
+      //      |           |                       |                                 |
+      //   offset   earliestOffset   min(untilOffset,latestOffset)   max(untilOffset, latestOffset)
+      val warningMessage =
+        s"""
+           |The current available offset range is $range.
+           | Offset ${offset} is out of range, and records in [$offset, ${range.earliest}) will be
+           | skipped ${additionalMessage(failOnDataLoss = false)}
+        """.stripMargin
+      logWarning(warningMessage)
+      range.earliest
+    }
+  }
+
+  /**
+   * Get the record for the given offset if available. Otherwise it will either throw error
+   * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset),
+   * or null.
+   *
+   * @throws OffsetOutOfRangeException if `offset` is out of range
+   * @throws TimeoutException if cannot fetch the record in `pollTimeoutMs` milliseconds.
+   */
+  private def fetchData(
+      offset: Long,
+      untilOffset: Long,
+      pollTimeoutMs: Long,
+      failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = {
+    if (offset != nextOffsetInFetchedData || !fetchedData.hasNext()) {
+      // This is the first fetch, or the last pre-fetched data has been drained.
+      // Seek to the offset because we may call seekToBeginning or seekToEnd before this.
+      seek(offset)
+      poll(pollTimeoutMs)
+    }
+
+    if (!fetchedData.hasNext()) {
+      // We cannot fetch anything after `poll`. Two possible cases:
+      // - `offset` is out of range so that Kafka returns nothing. Just throw
+      // `OffsetOutOfRangeException` to let the caller handle it.
+      // - Cannot fetch any data before timeout. TimeoutException will be thrown.
+      val range = getAvailableOffsetRange()
+      if (offset < range.earliest || offset >= range.latest) {
+        throw new OffsetOutOfRangeException(
+          Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava)
+      } else {
+        throw new TimeoutException(
+          s"Cannot fetch record for offset $offset in $pollTimeoutMs milliseconds")
+      }
+    } else {
+      val record = fetchedData.next()
+      nextOffsetInFetchedData = record.offset + 1
+      // In general, Kafka uses the specified offset as the start point, and tries to fetch the next
+      // available offset. Hence we need to handle offset mismatch.
+      if (record.offset > offset) {
+        // This may happen when some records aged out but their offsets already got verified
+        if (failOnDataLoss) {
+          reportDataLoss(true, s"Cannot fetch records in [$offset, ${record.offset})")
+          // Never happen as "reportDataLoss" will throw an exception
+          null
+        } else {
+          if (record.offset >= untilOffset) {
+            reportDataLoss(false, s"Skip missing records in [$offset, $untilOffset)")
+            null
+          } else {
+            reportDataLoss(false, s"Skip missing records in [$offset, ${record.offset})")
+            record
+          }
+        }
+      } else if (record.offset < offset) {
+        // This should not happen. If it does happen, then we probably misunderstand Kafka internal
+        // mechanism.
+        throw new IllegalStateException(
+          s"Tried to fetch $offset but the returned record offset was ${record.offset}")
+      } else {
+        record
+      }
+    }
+  }
+
+  /** Create a new consumer and reset cached states */
+  private def resetConsumer(): Unit = {
+    consumer.close()
+    consumer = createConsumer
+    resetFetchedData()
+  }
+
+  /** Reset the internal pre-fetched data. */
+  private def resetFetchedData(): Unit = {
+    nextOffsetInFetchedData = UNKNOWN_OFFSET
+    fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]]
+  }
+
+  /**
+   * Return an additional message with useful context and instructions.
+   */
+  private def additionalMessage(failOnDataLoss: Boolean): String = {
+    if (failOnDataLoss) {
+      s"(GroupId: $groupId, TopicPartition: $topicPartition). " +
+        s"$INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE"
+    } else {
+      s"(GroupId: $groupId, TopicPartition: $topicPartition). " +
+        s"$INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE"
+    }
+  }
+
+  /**
+   * Throw an exception or log a warning as per `failOnDataLoss`.
+   */
+  private def reportDataLoss(
+      failOnDataLoss: Boolean,
+      message: String,
+      cause: Throwable = null): Unit = {
+    val finalMessage = s"$message ${additionalMessage(failOnDataLoss)}"
+    reportDataLoss0(failOnDataLoss, finalMessage, cause)
+  }
+
+  def close(): Unit = consumer.close()
+
+  private def seek(offset: Long): Unit = {
+    logDebug(s"Seeking to $groupId $topicPartition $offset")
+    consumer.seek(topicPartition, offset)
+  }
+
+  private def poll(pollTimeoutMs: Long): Unit = {
+    val p = consumer.poll(pollTimeoutMs)
+    val r = p.records(topicPartition)
+    logDebug(s"Polled $groupId ${p.partitions()}  ${r.size}")
+    fetchedData = r.iterator
+  }
+}
+
+
+private[kafka010] object KafkaDataConsumer extends Logging {
+
+  case class AvailableOffsetRange(earliest: Long, latest: Long)
+
+  private case class CachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer)
+    extends KafkaDataConsumer {
+    assert(internalConsumer.inUse) // make sure this has been set to true
+    override def release(): Unit = { KafkaDataConsumer.release(internalConsumer) }
+  }
+
+  private case class NonCachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer)
+    extends KafkaDataConsumer {
+    override def release(): Unit = { internalConsumer.close() }
+  }
+
+  private case class CacheKey(groupId: String, topicPartition: TopicPartition) {
+    def this(topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) =
+      this(kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String], topicPartition)
+  }
+
+  // This cache has the following important properties.
+  // - We make a best-effort attempt to maintain the max size of the cache as configured capacity.
+  //   The capacity is not guaranteed to be maintained, especially when there are more active
+  //   tasks simultaneously using consumers than the capacity.
+  private lazy val cache = {
+    val conf = SparkEnv.get.conf
+    val capacity = conf.getInt("spark.sql.kafkaConsumerCache.capacity", 64)
+    new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer](capacity, 0.75f, true) {
+      override def removeEldestEntry(
+        entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer]): Boolean = {
+
+        // Try to remove the least-recently-used entry if it's currently not in use.
+        //
+        // If you cannot remove it, then the cache will keep growing. In the worst case,
+        // the cache will grow to the max number of concurrent tasks that can run in the executor
+        // (that is, the number of task slots), after which it will never shrink. This is unlikely to
+        // be a serious problem because an executor with more than 64 (default) task slots is
+        // likely running on a beefy machine that can handle a large number of simultaneously
+        // active consumers.
+
+        if (entry.getValue.inUse == false && this.size > capacity) {
+          logWarning(
+            s"KafkaConsumer cache hitting max capacity of $capacity, " +
+              s"removing consumer for ${entry.getKey}")
+          try {
+            entry.getValue.close()
+          } catch {
+            case e: SparkException =>
+              logError(s"Error closing earliest Kafka consumer for ${entry.getKey}", e)
+          }
+          true
+        } else {
+          false
+        }
+      }
+    }
+  }
+
+  /**
+   * Get a cached consumer for groupId, assigned to topic and partition.
+   * If a matching consumer doesn't already exist, one will be created using kafkaParams.
+   * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]].
+   *
+   * Note: This method guarantees that the consumer returned is not currently in use by anyone
+   * else. Within this guarantee, this method will make a best-effort attempt to re-use consumers by
+   * caching them and tracking when they are in use.
+   */
+  def acquire(
+      topicPartition: TopicPartition,
+      kafkaParams: ju.Map[String, Object],
+      useCache: Boolean): KafkaDataConsumer = synchronized {
+    val key = new CacheKey(topicPartition, kafkaParams)
+    val existingInternalConsumer = cache.get(key)
+
+    lazy val newInternalConsumer = new InternalKafkaConsumer(topicPartition, kafkaParams)
+
+    if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) {
+      // If this is a reattempt at running the task, then invalidate the cached consumer if any and
+      // start with a new one.
+      if (existingInternalConsumer != null) {
+        // Consumer exists in cache. If it's in use, mark it for closing later, or close it now.
+        if (existingInternalConsumer.inUse) {
+          existingInternalConsumer.markedForClose = true
+        } else {
+          existingInternalConsumer.close()
+        }
+      }
+      cache.remove(key)  // Invalidate the cache in any case
+      NonCachedKafkaDataConsumer(newInternalConsumer)
+
+    } else if (!useCache) {
+      // If the planner asks not to reuse consumers, then do not use the cache; return a new consumer
+      NonCachedKafkaDataConsumer(newInternalConsumer)
+
+    } else if (existingInternalConsumer == null) {
+      // If consumer is not already cached, then put a new one in the cache and return it
+      cache.put(key, newInternalConsumer)
+      newInternalConsumer.inUse = true
+      CachedKafkaDataConsumer(newInternalConsumer)
+
+    } else if (existingInternalConsumer.inUse) {
+      // If consumer is already cached but is currently in use, then return a new consumer
+      NonCachedKafkaDataConsumer(newInternalConsumer)
+
+    } else {
+      // If consumer is already cached and is currently not in use, then return that consumer
+      existingInternalConsumer.inUse = true
+      CachedKafkaDataConsumer(existingInternalConsumer)
+    }
+  }
+
+  private def release(intConsumer: InternalKafkaConsumer): Unit = {
+    synchronized {
+
+      // Clear the consumer from the cache if this is indeed the consumer present in the cache
+      val key = new CacheKey(intConsumer.topicPartition, intConsumer.kafkaParams)
+      val cachedIntConsumer = cache.get(key)
+      if (intConsumer.eq(cachedIntConsumer)) {
+        // The released consumer is the same object as the cached one.
+        if (intConsumer.markedForClose) {
+          intConsumer.close()
+          cache.remove(key)
+        } else {
+          intConsumer.inUse = false
+        }
+      } else {
+        // The released consumer is either not the same one as in the cache, or not in the cache
+        // at all. This may happen if the cache was invalidated while this consumer was being used.
+        // Just close this consumer.
+        intConsumer.close()
+        logInfo(s"Released a supposedly cached consumer that was not found in the cache")
+      }
+    }
+  }
+}
+
+private[kafka010] object InternalKafkaConsumer extends Logging {
+
+  private val UNKNOWN_OFFSET = -2L
+
+  private def reportDataLoss0(
+      failOnDataLoss: Boolean,
+      finalMessage: String,
+      cause: Throwable = null): Unit = {
+    if (failOnDataLoss) {
+      if (cause != null) {
+        throw new IllegalStateException(finalMessage, cause)
+      } else {
+        throw new IllegalStateException(finalMessage)
+      }
+    } else {
+      if (cause != null) {
+        logWarning(finalMessage, cause)
+      } else {
+        logWarning(finalMessage)
+      }
+    }
+  }
+}
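
The three-way decision in `getEarliestAvailableOffsetBetween` above is easier to review in isolation. The following is a minimal standalone sketch, not Spark code: `OffsetRecoverySketch`, `AvailableRange` and `nextOffsetToFetch` are hypothetical names, logging and `failOnDataLoss` reporting are left out, and `-2L` is assumed to play the role of `UNKNOWN_OFFSET`.

```scala
// Hypothetical, standalone model of the decision made by getEarliestAvailableOffsetBetween.
// Logging and failOnDataLoss reporting are omitted.
object OffsetRecoverySketch {
  // Assumed to mirror the UNKNOWN_OFFSET sentinel (-2L) in InternalKafkaConsumer.
  val UnknownOffset: Long = -2L

  // [earliest, latest) as reported by the broker for one partition.
  final case class AvailableRange(earliest: Long, latest: Long)

  /** Next offset to try within [offset, untilOffset), or UnknownOffset if none is available. */
  def nextOffsetToFetch(offset: Long, untilOffset: Long, range: AvailableRange): Long =
    if (offset >= range.latest || range.earliest >= untilOffset) {
      // No overlap between [offset, untilOffset) and [earliest, latest): skip the whole range.
      UnknownOffset
    } else if (offset >= range.earliest) {
      // The requested offset is inside the available range again (e.g. topic recreated): retry it.
      offset
    } else {
      // Records in [offset, earliest) aged out: resume from the earliest available offset.
      range.earliest
    }

  def main(args: Array[String]): Unit = {
    val range = AvailableRange(earliest = 100L, latest = 200L)
    println(nextOffsetToFetch(50L, 150L, range))  // 100
    println(nextOffsetToFetch(120L, 150L, range)) // 120
    println(nextOffsetToFetch(250L, 300L, range)) // -2
    println(nextOffsetToFetch(10L, 90L, range))   // -2
  }
}
```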

http://git-wip-us.apache.org/repos/asf/spark/blob/6937571a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
index 66b3409..498e344 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
@@ -52,7 +52,7 @@ private[kafka010] case class KafkaSourceRDDPartition(
  * An RDD that reads data from Kafka based on offset ranges across multiple partitions.
  * Additionally, it allows preferred locations to be set for each topic + partition, so that
  * the [[KafkaSource]] can ensure the same executor always reads the same topic + partition
- * and cached KafkaConsumers (see [[CachedKafkaConsumer]] can be used read data efficiently.
+ * and cached KafkaConsumers (see [[KafkaDataConsumer]]) can be used to read data efficiently.
  *
  * @param sc the [[SparkContext]]
  * @param executorKafkaParams Kafka configuration for creating KafkaConsumer on the executors
@@ -126,14 +126,9 @@ private[kafka010] class KafkaSourceRDD(
     val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition]
     val topic = sourcePartition.offsetRange.topic
     val kafkaPartition = sourcePartition.offsetRange.partition
-    val consumer =
-      if (!reuseKafkaConsumer) {
-        // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. As here we
-        // uses `assign`, we don't need to worry about the "group.id" conflicts.
-        CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams)
-      } else {
-        CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams)
-      }
+    val consumer = KafkaDataConsumer.acquire(
+      sourcePartition.offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer)
+
     val range = resolveRange(consumer, sourcePartition.offsetRange)
     assert(
       range.fromOffset <= range.untilOffset,
@@ -167,13 +162,7 @@ private[kafka010] class KafkaSourceRDD(
         }
 
         override protected def close(): Unit = {
-          if (!reuseKafkaConsumer) {
-            // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster!
-            consumer.close()
-          } else {
-            // Indicate that we're no longer using this consumer
-            CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams)
-          }
+          consumer.release()
         }
       }
       // Release consumer, either by removing it or indicating we're no longer using it
@@ -184,7 +173,7 @@ private[kafka010] class KafkaSourceRDD(
     }
   }
 
-  private def resolveRange(consumer: CachedKafkaConsumer, range: KafkaSourceRDDOffsetRange) = {
+  private def resolveRange(consumer: KafkaDataConsumer, range: KafkaSourceRDDOffsetRange) = {
     if (range.fromOffset < 0 || range.untilOffset < 0) {
       // Late bind the offset range
       val availableOffsetRange = consumer.getAvailableOffsetRange()
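
The `acquire`/`release` pair that `KafkaSourceRDD` now calls can be modeled in a few lines to show the invariant this patch enforces. This is a simplified, hypothetical sketch: `SketchConsumer`, `SketchPool` and the `(groupId, partition)` tuple key are illustrative stand-ins rather than Spark APIs, and it omits the capacity-bounded LRU eviction and the task-reattempt invalidation that `KafkaDataConsumer.acquire` performs. What it keeps is the core rule: a cached consumer is never handed to a second caller while its `inUse` flag is set.

```scala
import scala.collection.mutable

// Hypothetical stand-in for InternalKafkaConsumer: it only tracks usage and closing.
final class SketchConsumer(val key: (String, Int)) {
  var inUse: Boolean = false
  var closed: Boolean = false
  def close(): Unit = { closed = true }
}

// Simplified model of KafkaDataConsumer.acquire/release, keyed by (groupId, partition).
// Omits capacity-based eviction and task-reattempt invalidation.
object SketchPool {
  private val cache = mutable.Map.empty[(String, Int), SketchConsumer]

  def acquire(key: (String, Int), useCache: Boolean): SketchConsumer = synchronized {
    cache.get(key) match {
      case _ if !useCache =>
        // Planner opted out of reuse: hand out a throwaway consumer.
        new SketchConsumer(key)
      case None =>
        // Nothing cached yet: cache a fresh consumer and mark it in use.
        val fresh = new SketchConsumer(key)
        fresh.inUse = true
        cache(key) = fresh
        fresh
      case Some(cached) if cached.inUse =>
        // Another caller holds the cached consumer: never share it concurrently.
        new SketchConsumer(key)
      case Some(cached) =>
        // Cached and idle: reuse it.
        cached.inUse = true
        cached
    }
  }

  def release(consumer: SketchConsumer): Unit = synchronized {
    cache.get(consumer.key) match {
      case Some(cached) if cached eq consumer =>
        // The exact cached instance: return it to the pool.
        consumer.inUse = false
      case _ =>
        // A throwaway, or a consumer whose cache entry was replaced: close it.
        consumer.close()
    }
  }
}
```

Releasing by identity (`eq`) mirrors `KafkaDataConsumer.release`: any consumer that is not the exact cached instance, whether it was created as a throwaway or its cache entry was invalidated, is simply closed.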

http://git-wip-us.apache.org/repos/asf/spark/blob/6937571a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala
deleted file mode 100644
index 7aa7dd0..0000000
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.kafka010
-
-import org.scalatest.PrivateMethodTester
-
-import org.apache.spark.sql.test.SharedSQLContext
-
-class CachedKafkaConsumerSuite extends SharedSQLContext with PrivateMethodTester {
-
-  test("SPARK-19886: Report error cause correctly in reportDataLoss") {
-    val cause = new Exception("D'oh!")
-    val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0)
-    val e = intercept[IllegalStateException] {
-      CachedKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause))
-    }
-    assert(e.getCause === cause)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/6937571a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
new file mode 100644
index 0000000..0d0fb9c
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.util.concurrent.{Executors, TimeUnit}
+
+import scala.collection.JavaConverters._
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
+import scala.util.Random
+
+import org.apache.kafka.clients.consumer.ConsumerConfig
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.serialization.ByteArrayDeserializer
+import org.scalatest.PrivateMethodTester
+
+import org.apache.spark.{TaskContext, TaskContextImpl}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.ThreadUtils
+
+class KafkaDataConsumerSuite extends SharedSQLContext with PrivateMethodTester {
+
+  protected var testUtils: KafkaTestUtils = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    testUtils = new KafkaTestUtils(Map[String, Object]())
+    testUtils.setup()
+  }
+
+  override def afterAll(): Unit = {
+    if (testUtils != null) {
+      testUtils.teardown()
+      testUtils = null
+    }
+    super.afterAll()
+  }
+
+  test("SPARK-19886: Report error cause correctly in reportDataLoss") {
+    val cause = new Exception("D'oh!")
+    val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0)
+    val e = intercept[IllegalStateException] {
+      InternalKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause))
+    }
+    assert(e.getCause === cause)
+  }
+
+  test("SPARK-23623: concurrent use of KafkaDataConsumer") {
+    val topic = "topic" + Random.nextInt()
+    val data = (1 to 1000).map(_.toString)
+    testUtils.createTopic(topic, 1)
+    testUtils.sendMessages(topic, data.toArray)
+    val topicPartition = new TopicPartition(topic, 0)
+
+    import ConsumerConfig._
+    val kafkaParams = Map[String, Object](
+      GROUP_ID_CONFIG -> "groupId",
+      BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
+      KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+      VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+      AUTO_OFFSET_RESET_CONFIG -> "earliest",
+      ENABLE_AUTO_COMMIT_CONFIG -> "false"
+    )
+
+    val numThreads = 100
+    val numConsumerUsages = 500
+
+    @volatile var error: Throwable = null
+
+    def consume(i: Int): Unit = {
+      val useCache = Random.nextBoolean
+      val taskContext = if (Random.nextBoolean) {
+        new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null)
+      } else {
+        null
+      }
+      TaskContext.setTaskContext(taskContext)
+      val consumer = KafkaDataConsumer.acquire(
+        topicPartition, kafkaParams.asJava, useCache)
+      try {
+        val range = consumer.getAvailableOffsetRange()
+        val rcvd = range.earliest until range.latest map { offset =>
+          val bytes = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false).value()
+          new String(bytes)
+        }
+        assert(rcvd == data)
+      } catch {
+        case e: Throwable =>
+          error = e
+          throw e
+      } finally {
+        consumer.release()
+      }
+    }
+
+    val threadpool = Executors.newFixedThreadPool(numThreads)
+    try {
+      val futures = (1 to numConsumerUsages).map { i =>
+        threadpool.submit(new Runnable {
+          override def run(): Unit = { consume(i) }
+        })
+      }
+      futures.foreach(_.get(1, TimeUnit.MINUTES))
+      assert(error == null)
+    } finally {
+      threadpool.shutdown()
+    }
+  }
+}

