Repository: samza
Updated Branches:
  refs/heads/master 9396ee5cc -> e5f31c57c


http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kafka/src/main/scala/org/apache/samza/system/kafka/KafkaSystemProducer.scala
----------------------------------------------------------------------
diff --git 
a/samza-kafka/src/main/scala/org/apache/samza/system/kafka/KafkaSystemProducer.scala
 
b/samza-kafka/src/main/scala/org/apache/samza/system/kafka/KafkaSystemProducer.scala
index 3769e10..5a16580 100644
--- 
a/samza-kafka/src/main/scala/org/apache/samza/system/kafka/KafkaSystemProducer.scala
+++ 
b/samza-kafka/src/main/scala/org/apache/samza/system/kafka/KafkaSystemProducer.scala
@@ -19,154 +19,181 @@
 
 package org.apache.samza.system.kafka
 
-import org.apache.samza.util.Logging
-import org.apache.kafka.clients.producer.{RecordMetadata, Callback, 
ProducerRecord, Producer}
-import org.apache.samza.system.SystemProducer
+
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.Future
+
+import org.apache.kafka.clients.producer.Callback
+import org.apache.kafka.clients.producer.Producer
+import org.apache.kafka.clients.producer.ProducerRecord
+import org.apache.kafka.clients.producer.RecordMetadata
+import org.apache.kafka.common.PartitionInfo
+import org.apache.samza.SamzaException
 import org.apache.samza.system.OutgoingMessageEnvelope
+import org.apache.samza.system.SystemProducer
 import org.apache.samza.util.ExponentialSleepStrategy
-import org.apache.samza.util.TimerUtils
 import org.apache.samza.util.KafkaUtil
-import java.util.concurrent.atomic.{AtomicInteger, AtomicReference, 
AtomicBoolean}
-import java.util.{Map => javaMap}
-import org.apache.samza.SamzaException
-import org.apache.kafka.common.errors.RetriableException
-import org.apache.kafka.common.PartitionInfo
-import java.util
-import java.util.concurrent.Future
-import scala.collection.JavaConversions._
+import org.apache.samza.util.Logging
+import org.apache.samza.util.TimerUtils
 
+import scala.collection.JavaConversions._
 
 class KafkaSystemProducer(systemName: String,
                           retryBackoff: ExponentialSleepStrategy = new 
ExponentialSleepStrategy,
                           getProducer: () => Producer[Array[Byte], 
Array[Byte]],
                           metrics: KafkaSystemProducerMetrics,
-                          val clock: () => Long = () => System.nanoTime,
-                          val maxRetries: Int = 30) extends SystemProducer 
with Logging with TimerUtils
+                          val clock: () => Long = () => System.nanoTime) 
extends SystemProducer with Logging with TimerUtils
 {
-  var producer: Producer[Array[Byte], Array[Byte]] = null
-  val latestFuture: javaMap[String, Future[RecordMetadata]] = new 
util.HashMap[String, Future[RecordMetadata]]()
-  val sendFailed: AtomicBoolean = new AtomicBoolean(false)
-  var exceptionThrown: AtomicReference[Exception] = new 
AtomicReference[Exception]()
-  val StreamNameNullOrEmptyErrorMsg = "Stream Name should be specified in the 
stream configuration file.";
-
-  // Backward-compatible constructor for Java clients
-  def this(systemName: String,
-           retryBackoff: ExponentialSleepStrategy,
-           getProducer: () => Producer[Array[Byte], Array[Byte]],
-           metrics: KafkaSystemProducerMetrics,
-           clock: () => Long) = this(systemName, retryBackoff, getProducer, 
metrics, clock, 30)
-
-  def start() {
+
+  class SourceData {
+    /**
+     * lock to make a send() call and the storing of its Future handle atomic
+     */
    val sendLock: Object = new Object
+    /**
+     * The most recent send's Future handle
+     */
+    @volatile
+    var latestFuture: Future[RecordMetadata] = null
+    /**
+     * exceptionThrown: stores the exception in case of an "ultimate" send failure
+     * (i.e. a failure after exhausting max_retries in the Kafka producer) in the I/O thread,
+     * so that we do not continue to queue up more send requests from the samza thread.
+     * It also helps the samza thread identify whether the failure happened in the I/O thread or not.
+     */
+    @volatile
+    var exceptionThrown: SamzaException = null
+  }
+
+  @volatile var producer: Producer[Array[Byte], Array[Byte]] = null
+  var producerLock: Object = new Object
+  val StreamNameNullOrEmptyErrorMsg = "Stream Name should be specified in the 
stream configuration file."
+  val sources: ConcurrentHashMap[String, SourceData] = new 
ConcurrentHashMap[String, SourceData]
+
+  def start(): Unit = {
+    producerLock.synchronized {
+      if (producer == null) {
+        info("Creating a new producer for system %s." format systemName)
+        producer = getProducer()
+      }
+    }
   }
 
   def stop() {
-    if (producer != null) {
-      latestFuture.keys.foreach(flush(_))
-      producer.close
-      producer = null
+    producerLock.synchronized {
+      try {
+        if (producer != null) {
+          producer.close
+          producer = null
+
+          sources.foreach {p =>
+            if (p._2.exceptionThrown == null) {
+              flush(p._1)
+            }
+          }
+        }
+      } catch {
+        case e: Exception => logger.error(e.getMessage, e)
+      }
     }
   }
 
   def register(source: String) {
-    if(latestFuture.containsKey(source)) {
+    if(sources.putIfAbsent(source, new SourceData) != null) {
       throw new SamzaException("%s is already registered with the %s system 
producer" format (source, systemName))
     }
-    latestFuture.put(source, null)
   }
 
   def send(source: String, envelope: OutgoingMessageEnvelope) {
-    var numRetries: AtomicInteger = new AtomicInteger(0)
-    trace("Enqueueing message: %s, %s." format (source, envelope))
-    if(producer == null) {
-      info("Creating a new producer for system %s." format systemName)
-      producer = getProducer()
-      debug("Created a new producer for system %s." format systemName)
-    }
-    // Java-based Kafka producer API requires an "Integer" type partitionKey 
and does not allow custom overriding of Partitioners
-    // Any kind of custom partitioning has to be done on the client-side
+    trace("Enqueuing message: %s, %s." format (source, envelope))
+
     val topicName = envelope.getSystemStream.getStream
     if (topicName == null || topicName == "") {
       throw new IllegalArgumentException(StreamNameNullOrEmptyErrorMsg)
     }
-    val partitions: java.util.List[PartitionInfo]  = 
producer.partitionsFor(topicName)
-    val partitionKey = if(envelope.getPartitionKey != null) 
KafkaUtil.getIntegerPartitionKey(envelope, partitions) else null
+
+    val sourceData = sources.get(source)
+    if (sourceData == null) {
+      throw new IllegalArgumentException("Source %s must be registered first 
before send." format source)
+    }
+
+    val exception = sourceData.exceptionThrown
+    if (exception != null) {
+      metrics.sendFailed.inc
+      throw exception
+    }
+
+    val currentProducer = producer
+    if (currentProducer == null) {
+      throw new SamzaException("Kafka system producer is not available.")
+    }
+
+    // Java-based Kafka producer API requires an "Integer" type partitionKey 
and does not allow custom overriding of Partitioners
+    // Any kind of custom partitioning has to be done on the client-side
+    val partitions: java.util.List[PartitionInfo] = 
currentProducer.partitionsFor(topicName)
+    val partitionKey = if (envelope.getPartitionKey != null) 
KafkaUtil.getIntegerPartitionKey(envelope, partitions)
+    else null
     val record = new ProducerRecord(envelope.getSystemStream.getStream,
                                     partitionKey,
                                     envelope.getKey.asInstanceOf[Array[Byte]],
                                     
envelope.getMessage.asInstanceOf[Array[Byte]])
 
-    sendFailed.set(false)
-
-    retryBackoff.run(
-      loop => {
-        if(sendFailed.get()) {
-          throw exceptionThrown.get()
-        }
+    try {
+      sourceData.sendLock.synchronized {
         val futureRef: Future[RecordMetadata] =
-          producer.send(record, new Callback {
+          currentProducer.send(record, new Callback {
             def onCompletion(metadata: RecordMetadata, exception: Exception): 
Unit = {
               if (exception == null) {
                 //send was successful. Don't retry
                 metrics.sendSuccess.inc
               }
               else {
-                //If there is an exception in the callback, it means that the 
Kafka producer has exhausted the max-retries
-                //Hence, fail container!
-                exceptionThrown.compareAndSet(null, exception)
-                sendFailed.set(true)
+                //If there is an exception in the callback, fail container!
+                //Close producer.
+                currentProducer.close
+                sourceData.exceptionThrown = new SamzaException("Unable to 
send message from %s to system %s." format(source, systemName),
+                                                     exception)
+                metrics.sendFailed.inc
+                logger.error("Unable to send message on Topic:%s Partition:%s" 
format(topicName, partitionKey),
+                             exception)
               }
             }
           })
-        latestFuture.put(source, futureRef)
-        metrics.sends.inc
-        if(!sendFailed.get())
-          loop.done
-      },
-      (exception, loop) => {
-        if((exception != null && !exception.isInstanceOf[RetriableException]) 
|| numRetries.get() >= maxRetries) {
-          // Irrecoverable exceptions.
-          error("Exception detail : ", exception)
-          //Close producer
-          stop()
-          producer = null
-          //Mark loop as done as we are not going to retry
-          loop.done
-          metrics.sendFailed.inc
-          throw new SamzaException(("Failed to send message on Topic:%s 
Partition:%s NumRetries:%s Exception:\n %s,")
-            .format(topicName, partitionKey, numRetries, exception))
-        } else {
-          numRetries.incrementAndGet()
-          warn(("Retrying send due to RetriableException - %s for Topic:%s 
Partition:%s. " +
-            "Turn on debugging to get a full stack trace").format(exception, 
topicName, partitionKey))
-          debug("Exception detail:", exception)
-          metrics.retries.inc
-        }
+        sourceData.latestFuture = futureRef
+      }
+      metrics.sends.inc
+    } catch {
+      case e: Exception => {
+        currentProducer.close()
+        metrics.sendFailed.inc
+        throw new SamzaException(("Failed to send message on Topic:%s 
Partition:%s Exception:\n %s,")
+          .format(topicName, partitionKey, e))
       }
-    )
+    }
   }
 
   def flush(source: String) {
     updateTimer(metrics.flushNs) {
       metrics.flushes.inc
+
+      val sourceData = sources.get(source)
       //if latestFuture is null, it probably means that there has been no 
calls to "send" messages
       //Hence, nothing to do in flush
-      if(latestFuture.get(source) != null) {
-        while (!latestFuture.get(source).isDone && !sendFailed.get()) {
-          //do nothing
-        }
-        if (sendFailed.get()) {
-          logger.error("Unable to send message from %s to system %s" 
format(source, systemName))
-          //Close producer.
-          if (producer != null) {
-            producer.close
+      if(sourceData.latestFuture != null) {
+        while(!sourceData.latestFuture.isDone && sourceData.exceptionThrown == 
null) {
+          try {
+            sourceData.latestFuture.get()
+          } catch {
+            case t: Throwable => error(t.getMessage, t)
           }
-          producer = null
+        }
+
+        if (sourceData.exceptionThrown != null) {
           metrics.flushFailed.inc
-          throw new SamzaException("Unable to send message from %s to system 
%s." format(source, systemName), exceptionThrown.get)
+          throw sourceData.exceptionThrown
         } else {
           trace("Flushed %s." format (source))
         }
       }
     }
   }
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaSystemProducerJava.java
----------------------------------------------------------------------
diff --git 
a/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaSystemProducerJava.java
 
b/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaSystemProducerJava.java
index 04c9113..224ca2f 100644
--- 
a/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaSystemProducerJava.java
+++ 
b/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaSystemProducerJava.java
@@ -50,9 +50,7 @@ public class TestKafkaSystemProducerJava {
       }
     });
 
-    // Default value should have been used.
-    assertEquals(30, ksp.maxRetries());
     long now = System.currentTimeMillis();
     assertTrue((Long)ksp.clock().apply() >= now);
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kafka/src/test/scala/org/apache/samza/system/kafka/TestKafkaSystemProducer.scala
----------------------------------------------------------------------
diff --git 
a/samza-kafka/src/test/scala/org/apache/samza/system/kafka/TestKafkaSystemProducer.scala
 
b/samza-kafka/src/test/scala/org/apache/samza/system/kafka/TestKafkaSystemProducer.scala
index 8e32bba..fab998a 100644
--- 
a/samza-kafka/src/test/scala/org/apache/samza/system/kafka/TestKafkaSystemProducer.scala
+++ 
b/samza-kafka/src/test/scala/org/apache/samza/system/kafka/TestKafkaSystemProducer.scala
@@ -140,7 +140,7 @@ class TestKafkaSystemProducer {
       systemProducer.flush("test")
     }
     assertTrue(thrown.isInstanceOf[SamzaException])
-    assertEquals(2, mockProducer.getMsgsSent)
+    assertEquals(3, mockProducer.getMsgsSent) // msg1, msg2 and msg4 will be 
sent
     systemProducer.stop()
   }
 
@@ -150,14 +150,12 @@ class TestKafkaSystemProducer {
     val msg2 = new OutgoingMessageEnvelope(new SystemStream("test", "test"), 
"b".getBytes)
     val msg3 = new OutgoingMessageEnvelope(new SystemStream("test", "test"), 
"c".getBytes)
     val msg4 = new OutgoingMessageEnvelope(new SystemStream("test", "test"), 
"d".getBytes)
-    val numMaxRetries = 3
 
     val mockProducer = new MockKafkaProducer(1, "test", 1)
     val producerMetrics = new KafkaSystemProducerMetrics()
     val producer = new KafkaSystemProducer(systemName =  "test",
       getProducer = () => mockProducer,
-      metrics = producerMetrics,
-      maxRetries = numMaxRetries)
+      metrics = producerMetrics)
 
     producer.register("test")
     producer.start()
@@ -169,14 +167,15 @@ class TestKafkaSystemProducer {
     assertEquals(0, producerMetrics.retries.getCount)
     mockProducer.setErrorNext(true, new TimeoutException())
 
+    producer.send("test", msg4)
     val thrown = intercept[SamzaException] {
-      producer.send("test", msg4)
+      producer.flush("test")
     }
     assertTrue(thrown.isInstanceOf[SamzaException])
     assertTrue(thrown.getCause.isInstanceOf[TimeoutException])
-    assertEquals(true, producer.sendFailed.get())
     assertEquals(3, mockProducer.getMsgsSent)
-    assertEquals(numMaxRetries, producerMetrics.retries.getCount)
+    // retriable exception will be thrown immediately
+    assertEquals(0, producerMetrics.retries.getCount)
     producer.stop()
   }
 
@@ -199,12 +198,12 @@ class TestKafkaSystemProducer {
     producer.send("test", msg3)
     mockProducer.setErrorNext(true, new RecordTooLargeException())
 
+    producer.send("test", msg4)
     val thrown = intercept[SamzaException] {
-       producer.send("test", msg4)
+       producer.flush("test")
     }
     assertTrue(thrown.isInstanceOf[SamzaException])
     assertTrue(thrown.getCause.isInstanceOf[RecordTooLargeException])
-    assertEquals(true, producer.sendFailed.get())
     assertEquals(3, mockProducer.getMsgsSent)
     assertEquals(0, producerMetrics.retries.getCount)
     producer.stop()

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
----------------------------------------------------------------------
diff --git 
a/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
 
b/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
index 72f25a3..4c245b6 100644
--- 
a/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
+++ 
b/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
@@ -26,14 +26,14 @@ import java.util
 /**
  * In memory implementation of a key value store.
  *
- * This uses a TreeMap to store the keys in order
+ * This uses a ConcurrentSkipListMap to store the keys in order
  *
  * @param metrics A metrics instance to publish key-value store related 
statistics
  */
 class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics = new 
KeyValueStoreMetrics)
   extends KeyValueStore[Array[Byte], Array[Byte]] with Logging {
 
-  val underlying = new util.TreeMap[Array[Byte], Array[Byte]] 
(UnsignedBytes.lexicographicalComparator())
+  val underlying = new util.concurrent.ConcurrentSkipListMap[Array[Byte], 
Array[Byte]] (UnsignedBytes.lexicographicalComparator())
 
   override def flush(): Unit = {
     // No-op for In memory store.
@@ -47,7 +47,7 @@ class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics 
= new KeyValueStor
 
     override def close(): Unit = Unit
 
-    override def remove(): Unit = iter.remove()
+    override def remove(): Unit = throw new 
UnsupportedOperationException("InMemoryKeyValueStore iterator doesn't support 
remove")
 
     override def next(): Entry[Array[Byte], Array[Byte]] = {
       val n = iter.next()

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala
----------------------------------------------------------------------
diff --git 
a/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala
 
b/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala
index 9b9b1f6..73b89f7 100644
--- 
a/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala
+++ 
b/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala
@@ -106,7 +106,6 @@ class RocksDbKeyValueStore(
   // after the directories are created, which happens much later from now.
   private lazy val db = RocksDbKeyValueStore.openDB(dir, options, storeConfig, 
isLoggedStore, storeName)
   private val lexicographic = new LexicographicComparator()
-  private var deletesSinceLastCompaction = 0
 
   def get(key: Array[Byte]): Array[Byte] = {
     metrics.gets.inc
@@ -141,7 +140,6 @@ class RocksDbKeyValueStore(
     require(key != null, "Null key not allowed.")
     if (value == null) {
       db.remove(writeOptions, key)
-      deletesSinceLastCompaction += 1
     } else {
       metrics.bytesWritten.inc(key.size + value.size)
       db.put(writeOptions, key, value)
@@ -168,7 +166,6 @@ class RocksDbKeyValueStore(
     }
     metrics.puts.inc(wrote)
     metrics.deletes.inc(deletes)
-    deletesSinceLastCompaction += deletes
   }
 
   def delete(key: Array[Byte]) {

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala
----------------------------------------------------------------------
diff --git 
a/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala
 
b/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala
index b7f1cdc..05d39ea 100644
--- 
a/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala
+++ 
b/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala
@@ -26,7 +26,7 @@ import java.util
 import org.apache.samza.config.MapConfig
 import org.apache.samza.util.ExponentialSleepStrategy
 import org.junit.{Assert, Test}
-import org.rocksdb.{RocksDB, FlushOptions, Options}
+import org.rocksdb.{RocksIterator, RocksDB, FlushOptions, Options}
 
 class TestRocksDbKeyValueStore
 {
@@ -85,4 +85,61 @@ class TestRocksDbKeyValueStore
     rocksDB.close()
     rocksDBReadOnly.close()
   }
+
+  @Test
+  def testIteratorWithRemoval(): Unit = {
+    val lock = new Object
+
+    val map = new util.HashMap[String, String]()
+    val config = new MapConfig(map)
+    val options = new Options()
+    options.setCreateIfMissing(true)
+    val rocksDB = RocksDbKeyValueStore.openDB(new 
File(System.getProperty("java.io.tmpdir")),
+                                              options,
+                                              config,
+                                              false,
+                                              "dbStore")
+
+    val key = "key".getBytes("UTF-8")
+    val key1 = "key1".getBytes("UTF-8")
+    val value = "val".getBytes("UTF-8")
+    val value1 = "val1".getBytes("UTF-8")
+
+    var iter: RocksIterator = null
+
+    lock.synchronized {
+      rocksDB.put(key, value)
+      rocksDB.put(key1, value1)
+      // SAMZA-836: Mysteriously, calling new FlushOptions() does not invoke the NativeLibraryLoader in rocksdbjni-3.13.1!
+      // Moving this line after calling new Options() resolves the issue.
+      val flushOptions = new FlushOptions().setWaitForFlush(true)
+      rocksDB.flush(flushOptions)
+
+      iter = rocksDB.newIterator()
+      iter.seekToFirst()
+    }
+
+    while (iter.isValid) {
+      iter.next()
+    }
+    iter.dispose()
+
+    lock.synchronized {
+      rocksDB.remove(key)
+      iter = rocksDB.newIterator()
+      iter.seek(key)
+    }
+
+    while (iter.isValid) {
+      iter.next()
+    }
+    iter.dispose()
+
+    val dbDir = new File(System.getProperty("java.io.tmpdir")).toString
+    val rocksDBReadOnly = RocksDB.openReadOnly(options, dbDir)
+    Assert.assertEquals(new String(rocksDBReadOnly.get(key1), "UTF-8"), "val1")
+    Assert.assertEquals(rocksDBReadOnly.get(key), null)
+    rocksDB.close()
+    rocksDBReadOnly.close()
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala
----------------------------------------------------------------------
diff --git 
a/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala 
b/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala
index e7e4ede..44f96b4 100644
--- a/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala
+++ b/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala
@@ -37,7 +37,7 @@ import java.util.Arrays
  * that have not yet been written to disk. All writes go to the dirty list and 
when the list is long enough we call putAll on all those values at once. Dirty 
items
  * that time out of the cache before being written will also trigger a putAll 
of the dirty list.
  *
- * This class is very non-thread safe.
+ * This class is thread safe.
  *
  * @param store The store to cache
  * @param cacheSize The number of entries to hold in the in memory-cache
@@ -59,6 +59,9 @@ class CachedStore[K, V](
   /** the list of items to be written out on flush from newest to oldest */
   private var dirty = new mutable.DoubleLinkedList[K]()
 
+  /** the synchronization lock to protect access to the store from multiple 
threads **/
+  private val lock = new Object
+
   /** an lru cache of values that holds cacheEntries and calls putAll() on 
dirty entries if necessary when discarding */
   private val cache = new java.util.LinkedHashMap[K, CacheEntry[K, 
V]]((cacheSize * 1.2).toInt, 1.0f, true) {
     override def removeEldestEntry(eldest: java.util.Map.Entry[K, 
CacheEntry[K, V]]): Boolean = {
@@ -76,7 +79,7 @@ class CachedStore[K, V](
   }
 
   /** tracks whether an array has been used as a key. since this is dangerous 
with LinkedHashMap, we want to warn on it. **/
-  private var containsArrayKeys = false
+  @volatile private var containsArrayKeys = false
 
   // Use counters here, rather than directly accessing variables using .size
   // since metrics can be accessed in other threads, and cache.size is not
@@ -85,7 +88,7 @@ class CachedStore[K, V](
   metrics.setDirtyCount(() => dirtyCount)
   metrics.setCacheSize(() => cacheCount)
 
-  override def get(key: K) = {
+  override def get(key: K) = lock.synchronized({
     metrics.gets.inc
 
     val c = cache.get(key)
@@ -98,7 +101,7 @@ class CachedStore[K, V](
       cacheCount = cache.size
       v
     }
-  }
+  })
 
   private class CachedStoreIterator(val iter: KeyValueIterator[K, V])
     extends KeyValueIterator[K, V] {
@@ -107,10 +110,7 @@ class CachedStore[K, V](
 
     override def close(): Unit = iter.close()
 
-    override def remove(): Unit = {
-      iter.remove()
-      delete(last.getKey)
-    }
+    override def remove(): Unit = throw new 
UnsupportedOperationException("CachedStore iterator doesn't support remove")
 
     override def next() = {
       last = iter.next()
@@ -120,79 +120,85 @@ class CachedStore[K, V](
     override def hasNext: Boolean = iter.hasNext
   }
 
-  override def range(from: K, to: K): KeyValueIterator[K, V] = {
+  override def range(from: K, to: K): KeyValueIterator[K, V] = 
lock.synchronized({
     metrics.ranges.inc
     putAllDirtyEntries()
 
     new CachedStoreIterator(store.range(from, to))
-  }
+  })
 
-  override def all(): KeyValueIterator[K, V] = {
+  override def all(): KeyValueIterator[K, V] = lock.synchronized({
     metrics.alls.inc
     putAllDirtyEntries()
 
     new CachedStoreIterator(store.all())
-  }
+  })
 
   override def put(key: K, value: V) {
-    metrics.puts.inc
+    lock.synchronized({
+      metrics.puts.inc
 
-    checkKeyIsArray(key)
+      checkKeyIsArray(key)
 
-    // Add the key to the front of the dirty list (and remove any prior
-    // occurrences to dedupe).
-    val found = cache.get(key)
-    if (found == null || found.dirty == null) {
-      this.dirtyCount += 1
-    } else {
-      // If we are removing the head of the list, move the head to the next
-      // element. See SAMZA-45 for details.
-      if (found.dirty.prev == null) {
-        this.dirty = found.dirty.next
-        this.dirty.prev = null
+      // Add the key to the front of the dirty list (and remove any prior
+      // occurrences to dedupe).
+      val found = cache.get(key)
+      if (found == null || found.dirty == null) {
+        this.dirtyCount += 1
       } else {
-        found.dirty.remove()
+        // If we are removing the head of the list, move the head to the next
+        // element. See SAMZA-45 for details.
+        if (found.dirty.prev == null) {
+          this.dirty = found.dirty.next
+          this.dirty.prev = null
+        } else {
+          found.dirty.remove()
+        }
       }
-    }
-    this.dirty = new mutable.DoubleLinkedList(key, this.dirty)
+      this.dirty = new mutable.DoubleLinkedList(key, this.dirty)
 
-    // Add the key to the cache (but don't allocate a new cache entry if we
-    // already have one).
-    if (found == null) {
-      cache.put(key, new CacheEntry(value, this.dirty))
-      cacheCount = cache.size
-    } else {
-      found.value = value
-      found.dirty = this.dirty
-    }
+      // Add the key to the cache (but don't allocate a new cache entry if we
+      // already have one).
+      if (found == null) {
+        cache.put(key, new CacheEntry(value, this.dirty))
+        cacheCount = cache.size
+      } else {
+        found.value = value
+        found.dirty = this.dirty
+      }
 
-    // putAll() dirty values if the write list is full.
-    val purgeNeeded = if (dirtyCount >= writeBatchSize) {
-      debug("Dirty count %s >= write batch size %s. Calling putAll() on all 
dirty entries." format (dirtyCount, writeBatchSize))
-      true
-    } else if (hasArrayKeys) {
-      // Flush every time to support the following legacy behavior:
-      // If array keys are used with a cached store, get() will always miss 
the cache because of array equality semantics
-      // However, it will fall back to the underlying store which does support 
arrays.
-      true
-    } else {
-      false
-    }
+      // putAll() dirty values if the write list is full.
+      val purgeNeeded = if (dirtyCount >= writeBatchSize) {
+        debug("Dirty count %s >= write batch size %s. Calling putAll() on all 
dirty entries." format (dirtyCount, writeBatchSize))
+        true
+      } else if (hasArrayKeys) {
+        // Flush every time to support the following legacy behavior:
+        // If array keys are used with a cached store, get() will always miss 
the cache because of array equality semantics
+        // However, it will fall back to the underlying store which does 
support arrays.
+        true
+      } else {
+        false
+      }
 
-    if (purgeNeeded) {
-      putAllDirtyEntries()
-    }
+      if (purgeNeeded) {
+        putAllDirtyEntries()
+      }
+    })
   }
 
   override def flush() {
     trace("Purging dirty entries from CachedStore.")
-    metrics.flushes.inc
-    putAllDirtyEntries()
-    trace("Flushing store.")
-    store.flush()
+    lock.synchronized({
+      metrics.flushes.inc
+      putAllDirtyEntries()
+      store.flush()
+    })
     trace("Flushed store.")
   }
 
+  /**
+   * The synchronization lock must be held before calling this method.
+   */
   private def putAllDirtyEntries() {
     trace("Calling putAll() on dirty entries.")
     // write out the contents of the dirty list oldest first
@@ -212,26 +218,34 @@ class CachedStore[K, V](
   }
 
   override def putAll(entries: java.util.List[Entry[K, V]]) {
-    val iter = entries.iterator
-    while (iter.hasNext) {
-      val curr = iter.next
-      put(curr.getKey, curr.getValue)
-    }
+    lock.synchronized({
+      val iter = entries.iterator
+      while (iter.hasNext) {
+        val curr = iter.next
+        put(curr.getKey, curr.getValue)
+      }
+    })
   }
 
   override def delete(key: K) {
-    metrics.deletes.inc
-    put(key, null.asInstanceOf[V])
+    lock.synchronized({
+      metrics.deletes.inc
+      put(key, null.asInstanceOf[V])
+    })
   }
 
   override def close() {
-    trace("Closing.")
-    flush()
-    store.close()
+    lock.synchronized({
+      trace("Closing.")
+      flush()
+      store.close()
+    })
   }
 
   override def deleteAll(keys: java.util.List[K]) = {
-    KeyValueStore.Extension.deleteAll(this, keys)
+    lock.synchronized({
+      KeyValueStore.Extension.deleteAll(this, keys)
+    })
   }
 
   private def checkKeyIsArray(key: K) {
@@ -243,30 +257,32 @@ class CachedStore[K, V](
   }
 
   override def getAll(keys: java.util.List[K]): java.util.Map[K, V] = {
-    metrics.gets.inc(keys.size)
-    val returnValue = new java.util.HashMap[K, V](keys.size)
-    val misses = new java.util.ArrayList[K]
-    val keysIterator = keys.iterator
-    while (keysIterator.hasNext) {
-      val key = keysIterator.next
-      val cached = cache.get(key)
-      if (cached != null) {
-        metrics.cacheHits.inc
-        returnValue.put(key, cached.value)
-      } else {
-        misses.add(key)
+    lock.synchronized({
+      metrics.gets.inc(keys.size)
+      val returnValue = new java.util.HashMap[K, V](keys.size)
+      val misses = new java.util.ArrayList[K]
+      val keysIterator = keys.iterator
+      while (keysIterator.hasNext) {
+        val key = keysIterator.next
+        val cached = cache.get(key)
+        if (cached != null) {
+          metrics.cacheHits.inc
+          returnValue.put(key, cached.value)
+        } else {
+          misses.add(key)
+        }
       }
-    }
-    if (!misses.isEmpty) {
-      val entryIterator = store.getAll(misses).entrySet.iterator
-      while (entryIterator.hasNext) {
-        val entry = entryIterator.next
-        returnValue.put(entry.getKey, entry.getValue)
-        cache.put(entry.getKey, new CacheEntry(entry.getValue, null))
+      if (!misses.isEmpty) {
+        val entryIterator = store.getAll(misses).entrySet.iterator
+        while (entryIterator.hasNext) {
+          val entry = entryIterator.next
+          returnValue.put(entry.getKey, entry.getValue)
+          cache.put(entry.getKey, new CacheEntry(entry.getValue, null))
+        }
+        cacheCount = cache.size // update outside the loop since it's used for metrics and not for time-sensitive logic
       }
-      cacheCount = cache.size // update outside the loop since it's used for metrics and not for time-sensitive logic
-    }
-    returnValue
+      returnValue
+    })
   }
 
   def hasArrayKeys = containsArrayKeys

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala
----------------------------------------------------------------------
diff --git 
a/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala 
b/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala
index 96eb5fa..e16bdc0 100644
--- a/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala
+++ b/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala
@@ -127,21 +127,6 @@ class TestCachedStore {
     }
     assertFalse(iter.hasNext)
 
-    // test iterator remove
-    iter = store.all()
-    iter.next()
-    iter.remove()
-
-    assertNull(kv.get(keys.get(0)))
-    assertNull(store.get(keys.get(0)))
-
-    iter = store.range(keys.get(1), keys.get(2))
-    iter.next()
-    iter.remove()
-
-    assertFalse(iter.hasNext)
-    assertNull(kv.get(keys.get(1)))
-    assertNull(store.get(keys.get(1)))
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-test/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala
----------------------------------------------------------------------
diff --git 
a/samza-test/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala
 
b/samza-test/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala
index fd4e762..d7d23ec 100644
--- 
a/samza-test/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala
+++ 
b/samza-test/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala
@@ -20,7 +20,9 @@
 package org.apache.samza.storage.kv
 
 import java.io.File
-import java.util.{Arrays, Random}
+import java.util.Arrays
+import java.util.Random
+import java.util.concurrent.CountDownLatch
 
 import org.apache.samza.config.MapConfig
 import org.apache.samza.serializers.Serde
@@ -29,7 +31,9 @@ import org.junit.Assert._
 import org.junit.runner.RunWith
 import org.junit.runners.Parameterized
 import org.junit.runners.Parameterized.Parameters
-import org.junit.{After, Before, Test}
+import org.junit.After
+import org.junit.Before
+import org.junit.Test
 import org.scalatest.Assertions.intercept
 
 import scala.collection.JavaConversions._
@@ -136,8 +140,9 @@ class TestKeyValueStores(typeOfStore: String, storeConfig: String) {
 
   @Test
   def putAndGet() {
-    store.put(b("k"), b("v"))
-    assertArrayEquals(b("v"), store.get(b("k")))
+    val k = b("k")
+    store.put(k, b("v"))
+    assertArrayEquals(b("v"), store.get(k))
   }
 
   @Test
@@ -379,6 +384,181 @@ class TestKeyValueStores(typeOfStore: String, storeConfig: String) {
     })
   }
 
+  @Test
+  def testParallelReadWriteSameKey(): Unit = {
+    // Uses fixed keys and values so the test is deterministic (no randomness involved).
+    val key = b("key")
+    val val1 = "val1"
+    val val2 = "val2"
+
+    val runner1 = new Thread(new Runnable {
+      override def run(): Unit = {
+        store.put(key, b(val1))
+      }
+    })
+
+    val runner2 = new Thread(new Runnable {
+      override def run(): Unit = {
+        while(!val1.equals({store.get(key) match {
+          case null => ""
+          case _ => { new String(store.get(key), "UTF-8") }
+        }})) {}
+        store.put(key, b(val2))
+      }
+    })
+
+    runner2.start()
+    runner1.start()
+
+    runner2.join(1000)
+    runner1.join(1000)
+
+    assertEquals("val2", new String(store.get(key), "UTF-8"))
+
+    store.delete(key)
+    store.flush()
+  }
+
+  @Test
+  def testParallelReadWriteDiffKeys(): Unit = {
+    // Uses fixed keys and values so the test is deterministic (no randomness involved).
+    val key1 = b("key1")
+    val key2 = b("key2")
+    val val1 = "val1"
+    val val2 = "val2"
+
+    val runner1 = new Thread(new Runnable {
+      override def run(): Unit = {
+        store.put(key1, b(val1))
+      }
+    })
+
+    val runner2 = new Thread(new Runnable {
+      override def run(): Unit = {
+        while(!val1.equals({store.get(key1) match {
+          case null => ""
+          case _ => { new String(store.get(key1), "UTF-8") }
+        }})) {}
+        store.delete(key1)
+      }
+    })
+
+    val runner3 = new Thread(new Runnable {
+      override def run(): Unit = {
+        store.put(key2, b(val2))
+      }
+    })
+
+    val runner4 = new Thread(new Runnable {
+      override def run(): Unit = {
+        while(!val2.equals({store.get(key2) match {
+          case null => ""
+          case _ => { new String(store.get(key2), "UTF-8") }
+        }})) {}
+        store.delete(key2)
+      }
+    })
+
+    runner2.start()
+    runner1.start()
+    runner3.start()
+    runner4.start()
+
+    runner2.join(1000)
+    runner1.join(1000)
+    runner3.join(1000)
+    runner4.join(1000)
+
+    assertNull(store.get(key1))
+    assertNull(store.get(key2))
+
+    store.flush()
+  }
+
+  @Test
+  def testParallelIteratorAndWrite(): Unit = {
+    // Uses fixed keys and values so the test is deterministic (no randomness involved).
+    val key1 = b("key1")
+    val key2 = b("key2")
+    val val1 = "val1"
+    val val2 = "val2"
+    @volatile var throwable: Throwable = null
+
+    store.put(key1, b(val1))
+    store.put(key2, b(val2))
+
+    val runner1StartLatch = new CountDownLatch(1)
+    val runner2StartLatch = new CountDownLatch(1)
+
+    val runner1 = new Thread(new Runnable {
+      override def run(): Unit = {
+        runner1StartLatch.await()
+        store.put(key1, b("val1-2"))
+        store.delete(key2)
+        store.flush()
+        runner2StartLatch.countDown()
+      }
+    })
+
+    val runner2 = new Thread(new Runnable {
+      override def run(): Unit = {
+        runner2StartLatch.await()
+        val iter = store.all() //snapshot after change
+        try {
+          while (iter.hasNext) {
+            val e = iter.next()
+            if ("key1".equals(new String(e.getKey, "UTF-8"))) {
+              assertEquals("val1-2", new String(e.getValue, "UTF-8"))
+            }
+            System.out.println(String.format("iterator1: key: %s, value: %s", new String(e.getKey, "UTF-8"), new String(e.getValue, "UTF-8")))
+          }
+          iter.close()
+        } catch {
+          case t: Throwable => throwable = t
+        }
+      }
+    })
+
+    val runner3 = new Thread(new Runnable {
+      override def run(): Unit = {
+        val iter = store.all()  //snapshot
+        runner1StartLatch.countDown()
+        try {
+          while (iter.hasNext) {
+            val e = iter.next()
+            val key = new String(e.getKey, "UTF-8")
+            val value = new String(e.getValue, "UTF-8")
+            if (key.equals("key1")) {
+              assertEquals(val1, value)
+            }
+            else if (key.equals("key2") && !val2.equals(value)) {
+              assertEquals(val2, value)
+            }
+            else if (!key.equals("key1") && !key.equals("key2")) {
+              throw new Exception("unknow key " + new String(e.getKey, "UTF-8") + ", value: " + new String(e.getValue, "UTF-8"))
+            }
+            System.out.println(String.format("iterator2: key: %s, value: %s", new String(e.getKey, "UTF-8"), new String(e.getValue, "UTF-8")))
+          }
+          iter.close()
+        } catch {
+          case t: Throwable => throwable = t
+        }
+      }
+    })
+
+    runner2.start()
+    runner3.start()
+    runner1.start()
+
+    runner2.join()
+    runner1.join()
+    runner3.join()
+
+    if(throwable != null) throw throwable
+
+    store.flush()
+  }
+
   def checkRange(vals: IndexedSeq[String], iter: KeyValueIterator[Array[Byte], 
Array[Byte]]) {
     for (v <- vals) {
       assertTrue(iter.hasNext)
@@ -417,5 +597,6 @@ object TestKeyValueStores {
       Array("rocksdb","cache"),
       Array("rocksdb","serde"),
       Array("rocksdb","cache-and-serde"),
-      Array("rocksdb","none"))
+      Array("rocksdb","none")
+  )
 }

Reply via email to