Repository: samza Updated Branches: refs/heads/master 9396ee5cc -> e5f31c57c
http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kafka/src/main/scala/org/apache/samza/system/kafka/KafkaSystemProducer.scala ---------------------------------------------------------------------- diff --git a/samza-kafka/src/main/scala/org/apache/samza/system/kafka/KafkaSystemProducer.scala b/samza-kafka/src/main/scala/org/apache/samza/system/kafka/KafkaSystemProducer.scala index 3769e10..5a16580 100644 --- a/samza-kafka/src/main/scala/org/apache/samza/system/kafka/KafkaSystemProducer.scala +++ b/samza-kafka/src/main/scala/org/apache/samza/system/kafka/KafkaSystemProducer.scala @@ -19,154 +19,181 @@ package org.apache.samza.system.kafka -import org.apache.samza.util.Logging -import org.apache.kafka.clients.producer.{RecordMetadata, Callback, ProducerRecord, Producer} -import org.apache.samza.system.SystemProducer + +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.Future + +import org.apache.kafka.clients.producer.Callback +import org.apache.kafka.clients.producer.Producer +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.clients.producer.RecordMetadata +import org.apache.kafka.common.PartitionInfo +import org.apache.samza.SamzaException import org.apache.samza.system.OutgoingMessageEnvelope +import org.apache.samza.system.SystemProducer import org.apache.samza.util.ExponentialSleepStrategy -import org.apache.samza.util.TimerUtils import org.apache.samza.util.KafkaUtil -import java.util.concurrent.atomic.{AtomicInteger, AtomicReference, AtomicBoolean} -import java.util.{Map => javaMap} -import org.apache.samza.SamzaException -import org.apache.kafka.common.errors.RetriableException -import org.apache.kafka.common.PartitionInfo -import java.util -import java.util.concurrent.Future -import scala.collection.JavaConversions._ +import org.apache.samza.util.Logging +import org.apache.samza.util.TimerUtils +import scala.collection.JavaConversions._ class KafkaSystemProducer(systemName: String, 
retryBackoff: ExponentialSleepStrategy = new ExponentialSleepStrategy, getProducer: () => Producer[Array[Byte], Array[Byte]], metrics: KafkaSystemProducerMetrics, - val clock: () => Long = () => System.nanoTime, - val maxRetries: Int = 30) extends SystemProducer with Logging with TimerUtils + val clock: () => Long = () => System.nanoTime) extends SystemProducer with Logging with TimerUtils { - var producer: Producer[Array[Byte], Array[Byte]] = null - val latestFuture: javaMap[String, Future[RecordMetadata]] = new util.HashMap[String, Future[RecordMetadata]]() - val sendFailed: AtomicBoolean = new AtomicBoolean(false) - var exceptionThrown: AtomicReference[Exception] = new AtomicReference[Exception]() - val StreamNameNullOrEmptyErrorMsg = "Stream Name should be specified in the stream configuration file."; - - // Backward-compatible constructor for Java clients - def this(systemName: String, - retryBackoff: ExponentialSleepStrategy, - getProducer: () => Producer[Array[Byte], Array[Byte]], - metrics: KafkaSystemProducerMetrics, - clock: () => Long) = this(systemName, retryBackoff, getProducer, metrics, clock, 30) - - def start() { + + class SourceData { + /** + * lock to make send() and store its future atomic + */ + val sendLock: Object = new Object + /** + * The most recent send's Future handle + */ + @volatile + var latestFuture: Future[RecordMetadata] = null + /** + * exceptionThrown: to store the exception in case of any "ultimate" send failure (ie. failure + * after exhausting max_retries in Kafka producer) in the I/O thread, we do not continue to queue up more send + * requests from the samza thread. It helps the samza thread identify if the failure happened in I/O thread or not. + */ + @volatile + var exceptionThrown: SamzaException = null + } + + @volatile var producer: Producer[Array[Byte], Array[Byte]] = null + var producerLock: Object = new Object + val StreamNameNullOrEmptyErrorMsg = "Stream Name should be specified in the stream configuration file." 
+ val sources: ConcurrentHashMap[String, SourceData] = new ConcurrentHashMap[String, SourceData] + + def start(): Unit = { + producerLock.synchronized { + if (producer == null) { + info("Creating a new producer for system %s." format systemName) + producer = getProducer() + } + } } def stop() { - if (producer != null) { - latestFuture.keys.foreach(flush(_)) - producer.close - producer = null + producerLock.synchronized { + try { + if (producer != null) { + producer.close + producer = null + + sources.foreach {p => + if (p._2.exceptionThrown == null) { + flush(p._1) + } + } + } + } catch { + case e: Exception => logger.error(e.getMessage, e) + } } } def register(source: String) { - if(latestFuture.containsKey(source)) { + if(sources.putIfAbsent(source, new SourceData) != null) { throw new SamzaException("%s is already registered with the %s system producer" format (source, systemName)) } - latestFuture.put(source, null) } def send(source: String, envelope: OutgoingMessageEnvelope) { - var numRetries: AtomicInteger = new AtomicInteger(0) - trace("Enqueueing message: %s, %s." format (source, envelope)) - if(producer == null) { - info("Creating a new producer for system %s." format systemName) - producer = getProducer() - debug("Created a new producer for system %s." format systemName) - } - // Java-based Kafka producer API requires an "Integer" type partitionKey and does not allow custom overriding of Partitioners - // Any kind of custom partitioning has to be done on the client-side + trace("Enqueuing message: %s, %s." 
format (source, envelope)) + val topicName = envelope.getSystemStream.getStream if (topicName == null || topicName == "") { throw new IllegalArgumentException(StreamNameNullOrEmptyErrorMsg) } - val partitions: java.util.List[PartitionInfo] = producer.partitionsFor(topicName) - val partitionKey = if(envelope.getPartitionKey != null) KafkaUtil.getIntegerPartitionKey(envelope, partitions) else null + + val sourceData = sources.get(source) + if (sourceData == null) { + throw new IllegalArgumentException("Source %s must be registered first before send." format source) + } + + val exception = sourceData.exceptionThrown + if (exception != null) { + metrics.sendFailed.inc + throw exception + } + + val currentProducer = producer + if (currentProducer == null) { + throw new SamzaException("Kafka system producer is not available.") + } + + // Java-based Kafka producer API requires an "Integer" type partitionKey and does not allow custom overriding of Partitioners + // Any kind of custom partitioning has to be done on the client-side + val partitions: java.util.List[PartitionInfo] = currentProducer.partitionsFor(topicName) + val partitionKey = if (envelope.getPartitionKey != null) KafkaUtil.getIntegerPartitionKey(envelope, partitions) + else null val record = new ProducerRecord(envelope.getSystemStream.getStream, partitionKey, envelope.getKey.asInstanceOf[Array[Byte]], envelope.getMessage.asInstanceOf[Array[Byte]]) - sendFailed.set(false) - - retryBackoff.run( - loop => { - if(sendFailed.get()) { - throw exceptionThrown.get() - } + try { + sourceData.sendLock.synchronized { val futureRef: Future[RecordMetadata] = - producer.send(record, new Callback { + currentProducer.send(record, new Callback { def onCompletion(metadata: RecordMetadata, exception: Exception): Unit = { if (exception == null) { //send was successful. 
Don't retry metrics.sendSuccess.inc } else { - //If there is an exception in the callback, it means that the Kafka producer has exhausted the max-retries - //Hence, fail container! - exceptionThrown.compareAndSet(null, exception) - sendFailed.set(true) + //If there is an exception in the callback, fail container! + //Close producer. + currentProducer.close + sourceData.exceptionThrown = new SamzaException("Unable to send message from %s to system %s." format(source, systemName), + exception) + metrics.sendFailed.inc + logger.error("Unable to send message on Topic:%s Partition:%s" format(topicName, partitionKey), + exception) } } }) - latestFuture.put(source, futureRef) - metrics.sends.inc - if(!sendFailed.get()) - loop.done - }, - (exception, loop) => { - if((exception != null && !exception.isInstanceOf[RetriableException]) || numRetries.get() >= maxRetries) { - // Irrecoverable exceptions. - error("Exception detail : ", exception) - //Close producer - stop() - producer = null - //Mark loop as done as we are not going to retry - loop.done - metrics.sendFailed.inc - throw new SamzaException(("Failed to send message on Topic:%s Partition:%s NumRetries:%s Exception:\n %s,") - .format(topicName, partitionKey, numRetries, exception)) - } else { - numRetries.incrementAndGet() - warn(("Retrying send due to RetriableException - %s for Topic:%s Partition:%s. 
" + - "Turn on debugging to get a full stack trace").format(exception, topicName, partitionKey)) - debug("Exception detail:", exception) - metrics.retries.inc - } + sourceData.latestFuture = futureRef + } + metrics.sends.inc + } catch { + case e: Exception => { + currentProducer.close() + metrics.sendFailed.inc + throw new SamzaException(("Failed to send message on Topic:%s Partition:%s Exception:\n %s,") + .format(topicName, partitionKey, e)) } - ) + } } def flush(source: String) { updateTimer(metrics.flushNs) { metrics.flushes.inc + + val sourceData = sources.get(source) //if latestFuture is null, it probably means that there has been no calls to "send" messages //Hence, nothing to do in flush - if(latestFuture.get(source) != null) { - while (!latestFuture.get(source).isDone && !sendFailed.get()) { - //do nothing - } - if (sendFailed.get()) { - logger.error("Unable to send message from %s to system %s" format(source, systemName)) - //Close producer. - if (producer != null) { - producer.close + if(sourceData.latestFuture != null) { + while(!sourceData.latestFuture.isDone && sourceData.exceptionThrown == null) { + try { + sourceData.latestFuture.get() + } catch { + case t: Throwable => error(t.getMessage, t) } - producer = null + } + + if (sourceData.exceptionThrown != null) { metrics.flushFailed.inc - throw new SamzaException("Unable to send message from %s to system %s." format(source, systemName), exceptionThrown.get) + throw sourceData.exceptionThrown } else { trace("Flushed %s." 
format (source)) } } } } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaSystemProducerJava.java ---------------------------------------------------------------------- diff --git a/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaSystemProducerJava.java b/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaSystemProducerJava.java index 04c9113..224ca2f 100644 --- a/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaSystemProducerJava.java +++ b/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaSystemProducerJava.java @@ -50,9 +50,7 @@ public class TestKafkaSystemProducerJava { } }); - // Default value should have been used. - assertEquals(30, ksp.maxRetries()); long now = System.currentTimeMillis(); assertTrue((Long)ksp.clock().apply() >= now); } -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kafka/src/test/scala/org/apache/samza/system/kafka/TestKafkaSystemProducer.scala ---------------------------------------------------------------------- diff --git a/samza-kafka/src/test/scala/org/apache/samza/system/kafka/TestKafkaSystemProducer.scala b/samza-kafka/src/test/scala/org/apache/samza/system/kafka/TestKafkaSystemProducer.scala index 8e32bba..fab998a 100644 --- a/samza-kafka/src/test/scala/org/apache/samza/system/kafka/TestKafkaSystemProducer.scala +++ b/samza-kafka/src/test/scala/org/apache/samza/system/kafka/TestKafkaSystemProducer.scala @@ -140,7 +140,7 @@ class TestKafkaSystemProducer { systemProducer.flush("test") } assertTrue(thrown.isInstanceOf[SamzaException]) - assertEquals(2, mockProducer.getMsgsSent) + assertEquals(3, mockProducer.getMsgsSent) // msg1, msg2 and msg4 will be sent systemProducer.stop() } @@ -150,14 +150,12 @@ class TestKafkaSystemProducer { val msg2 = new OutgoingMessageEnvelope(new SystemStream("test", "test"), 
"b".getBytes) val msg3 = new OutgoingMessageEnvelope(new SystemStream("test", "test"), "c".getBytes) val msg4 = new OutgoingMessageEnvelope(new SystemStream("test", "test"), "d".getBytes) - val numMaxRetries = 3 val mockProducer = new MockKafkaProducer(1, "test", 1) val producerMetrics = new KafkaSystemProducerMetrics() val producer = new KafkaSystemProducer(systemName = "test", getProducer = () => mockProducer, - metrics = producerMetrics, - maxRetries = numMaxRetries) + metrics = producerMetrics) producer.register("test") producer.start() @@ -169,14 +167,15 @@ class TestKafkaSystemProducer { assertEquals(0, producerMetrics.retries.getCount) mockProducer.setErrorNext(true, new TimeoutException()) + producer.send("test", msg4) val thrown = intercept[SamzaException] { - producer.send("test", msg4) + producer.flush("test") } assertTrue(thrown.isInstanceOf[SamzaException]) assertTrue(thrown.getCause.isInstanceOf[TimeoutException]) - assertEquals(true, producer.sendFailed.get()) assertEquals(3, mockProducer.getMsgsSent) - assertEquals(numMaxRetries, producerMetrics.retries.getCount) + // retriable exception will be thrown immediately + assertEquals(0, producerMetrics.retries.getCount) producer.stop() } @@ -199,12 +198,12 @@ class TestKafkaSystemProducer { producer.send("test", msg3) mockProducer.setErrorNext(true, new RecordTooLargeException()) + producer.send("test", msg4) val thrown = intercept[SamzaException] { - producer.send("test", msg4) + producer.flush("test") } assertTrue(thrown.isInstanceOf[SamzaException]) assertTrue(thrown.getCause.isInstanceOf[RecordTooLargeException]) - assertEquals(true, producer.sendFailed.get()) assertEquals(3, mockProducer.getMsgsSent) assertEquals(0, producerMetrics.retries.getCount) producer.stop() http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala ---------------------------------------------------------------------- diff --git 
a/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala b/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala index 72f25a3..4c245b6 100644 --- a/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala +++ b/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala @@ -26,14 +26,14 @@ import java.util /** * In memory implementation of a key value store. * - * This uses a TreeMap to store the keys in order + * This uses a ConcurrentSkipListMap to store the keys in order * * @param metrics A metrics instance to publish key-value store related statistics */ class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics = new KeyValueStoreMetrics) extends KeyValueStore[Array[Byte], Array[Byte]] with Logging { - val underlying = new util.TreeMap[Array[Byte], Array[Byte]] (UnsignedBytes.lexicographicalComparator()) + val underlying = new util.concurrent.ConcurrentSkipListMap[Array[Byte], Array[Byte]] (UnsignedBytes.lexicographicalComparator()) override def flush(): Unit = { // No-op for In memory store. 
@@ -47,7 +47,7 @@ class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics = new KeyValueStor override def close(): Unit = Unit - override def remove(): Unit = iter.remove() + override def remove(): Unit = throw new UnsupportedOperationException("InMemoryKeyValueStore iterator doesn't support remove") override def next(): Entry[Array[Byte], Array[Byte]] = { val n = iter.next() http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala ---------------------------------------------------------------------- diff --git a/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala b/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala index 9b9b1f6..73b89f7 100644 --- a/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala +++ b/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala @@ -106,7 +106,6 @@ class RocksDbKeyValueStore( // after the directories are created, which happens much later from now. 
private lazy val db = RocksDbKeyValueStore.openDB(dir, options, storeConfig, isLoggedStore, storeName) private val lexicographic = new LexicographicComparator() - private var deletesSinceLastCompaction = 0 def get(key: Array[Byte]): Array[Byte] = { metrics.gets.inc @@ -141,7 +140,6 @@ class RocksDbKeyValueStore( require(key != null, "Null key not allowed.") if (value == null) { db.remove(writeOptions, key) - deletesSinceLastCompaction += 1 } else { metrics.bytesWritten.inc(key.size + value.size) db.put(writeOptions, key, value) @@ -168,7 +166,6 @@ class RocksDbKeyValueStore( } metrics.puts.inc(wrote) metrics.deletes.inc(deletes) - deletesSinceLastCompaction += deletes } def delete(key: Array[Byte]) { http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala ---------------------------------------------------------------------- diff --git a/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala b/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala index b7f1cdc..05d39ea 100644 --- a/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala +++ b/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala @@ -26,7 +26,7 @@ import java.util import org.apache.samza.config.MapConfig import org.apache.samza.util.ExponentialSleepStrategy import org.junit.{Assert, Test} -import org.rocksdb.{RocksDB, FlushOptions, Options} +import org.rocksdb.{RocksIterator, RocksDB, FlushOptions, Options} class TestRocksDbKeyValueStore { @@ -85,4 +85,61 @@ class TestRocksDbKeyValueStore rocksDB.close() rocksDBReadOnly.close() } + + @Test + def testIteratorWithRemoval(): Unit = { + val lock = new Object + + val map = new util.HashMap[String, String]() + val config = new MapConfig(map) + val options = new Options() + options.setCreateIfMissing(true) + val rocksDB = 
RocksDbKeyValueStore.openDB(new File(System.getProperty("java.io.tmpdir")), + options, + config, + false, + "dbStore") + + val key = "key".getBytes("UTF-8") + val key1 = "key1".getBytes("UTF-8") + val value = "val".getBytes("UTF-8") + val value1 = "val1".getBytes("UTF-8") + + var iter: RocksIterator = null + + lock.synchronized { + rocksDB.put(key, value) + rocksDB.put(key1, value1) + // SAMZA-836: Mysteriously,calling new FlushOptions() does not invoke the NativeLibraryLoader in rocksdbjni-3.13.1! + // Moving this line after calling new Options() resolve the issue. + val flushOptions = new FlushOptions().setWaitForFlush(true) + rocksDB.flush(flushOptions) + + iter = rocksDB.newIterator() + iter.seekToFirst() + } + + while (iter.isValid) { + iter.next() + } + iter.dispose() + + lock.synchronized { + rocksDB.remove(key) + iter = rocksDB.newIterator() + iter.seek(key) + } + + while (iter.isValid) { + iter.next() + } + iter.dispose() + + val dbDir = new File(System.getProperty("java.io.tmpdir")).toString + val rocksDBReadOnly = RocksDB.openReadOnly(options, dbDir) + Assert.assertEquals(new String(rocksDBReadOnly.get(key1), "UTF-8"), "val1") + Assert.assertEquals(rocksDBReadOnly.get(key), null) + rocksDB.close() + rocksDBReadOnly.close() + } } http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala ---------------------------------------------------------------------- diff --git a/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala b/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala index e7e4ede..44f96b4 100644 --- a/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala +++ b/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala @@ -37,7 +37,7 @@ import java.util.Arrays * that have not yet been written to disk. All writes go to the dirty list and when the list is long enough we call putAll on all those values at once. 
Dirty items * that time out of the cache before being written will also trigger a putAll of the dirty list. * - * This class is very non-thread safe. + * This class is thread safe. * * @param store The store to cache * @param cacheSize The number of entries to hold in the in memory-cache @@ -59,6 +59,9 @@ class CachedStore[K, V]( /** the list of items to be written out on flush from newest to oldest */ private var dirty = new mutable.DoubleLinkedList[K]() + /** the synchronization lock to protect access to the store from multiple threads **/ + private val lock = new Object + /** an lru cache of values that holds cacheEntries and calls putAll() on dirty entries if necessary when discarding */ private val cache = new java.util.LinkedHashMap[K, CacheEntry[K, V]]((cacheSize * 1.2).toInt, 1.0f, true) { override def removeEldestEntry(eldest: java.util.Map.Entry[K, CacheEntry[K, V]]): Boolean = { @@ -76,7 +79,7 @@ class CachedStore[K, V]( } /** tracks whether an array has been used as a key. since this is dangerous with LinkedHashMap, we want to warn on it. 
**/ - private var containsArrayKeys = false + @volatile private var containsArrayKeys = false // Use counters here, rather than directly accessing variables using .size // since metrics can be accessed in other threads, and cache.size is not @@ -85,7 +88,7 @@ class CachedStore[K, V]( metrics.setDirtyCount(() => dirtyCount) metrics.setCacheSize(() => cacheCount) - override def get(key: K) = { + override def get(key: K) = lock.synchronized({ metrics.gets.inc val c = cache.get(key) @@ -98,7 +101,7 @@ class CachedStore[K, V]( cacheCount = cache.size v } - } + }) private class CachedStoreIterator(val iter: KeyValueIterator[K, V]) extends KeyValueIterator[K, V] { @@ -107,10 +110,7 @@ class CachedStore[K, V]( override def close(): Unit = iter.close() - override def remove(): Unit = { - iter.remove() - delete(last.getKey) - } + override def remove(): Unit = throw new UnsupportedOperationException("CachedStore iterator doesn't support remove") override def next() = { last = iter.next() @@ -120,79 +120,85 @@ class CachedStore[K, V]( override def hasNext: Boolean = iter.hasNext } - override def range(from: K, to: K): KeyValueIterator[K, V] = { + override def range(from: K, to: K): KeyValueIterator[K, V] = lock.synchronized({ metrics.ranges.inc putAllDirtyEntries() new CachedStoreIterator(store.range(from, to)) - } + }) - override def all(): KeyValueIterator[K, V] = { + override def all(): KeyValueIterator[K, V] = lock.synchronized({ metrics.alls.inc putAllDirtyEntries() new CachedStoreIterator(store.all()) - } + }) override def put(key: K, value: V) { - metrics.puts.inc + lock.synchronized({ + metrics.puts.inc - checkKeyIsArray(key) + checkKeyIsArray(key) - // Add the key to the front of the dirty list (and remove any prior - // occurrences to dedupe). - val found = cache.get(key) - if (found == null || found.dirty == null) { - this.dirtyCount += 1 - } else { - // If we are removing the head of the list, move the head to the next - // element. See SAMZA-45 for details. 
- if (found.dirty.prev == null) { - this.dirty = found.dirty.next - this.dirty.prev = null + // Add the key to the front of the dirty list (and remove any prior + // occurrences to dedupe). + val found = cache.get(key) + if (found == null || found.dirty == null) { + this.dirtyCount += 1 } else { - found.dirty.remove() + // If we are removing the head of the list, move the head to the next + // element. See SAMZA-45 for details. + if (found.dirty.prev == null) { + this.dirty = found.dirty.next + this.dirty.prev = null + } else { + found.dirty.remove() + } } - } - this.dirty = new mutable.DoubleLinkedList(key, this.dirty) + this.dirty = new mutable.DoubleLinkedList(key, this.dirty) - // Add the key to the cache (but don't allocate a new cache entry if we - // already have one). - if (found == null) { - cache.put(key, new CacheEntry(value, this.dirty)) - cacheCount = cache.size - } else { - found.value = value - found.dirty = this.dirty - } + // Add the key to the cache (but don't allocate a new cache entry if we + // already have one). + if (found == null) { + cache.put(key, new CacheEntry(value, this.dirty)) + cacheCount = cache.size + } else { + found.value = value + found.dirty = this.dirty + } - // putAll() dirty values if the write list is full. - val purgeNeeded = if (dirtyCount >= writeBatchSize) { - debug("Dirty count %s >= write batch size %s. Calling putAll() on all dirty entries." format (dirtyCount, writeBatchSize)) - true - } else if (hasArrayKeys) { - // Flush every time to support the following legacy behavior: - // If array keys are used with a cached store, get() will always miss the cache because of array equality semantics - // However, it will fall back to the underlying store which does support arrays. - true - } else { - false - } + // putAll() dirty values if the write list is full. + val purgeNeeded = if (dirtyCount >= writeBatchSize) { + debug("Dirty count %s >= write batch size %s. Calling putAll() on all dirty entries." 
format (dirtyCount, writeBatchSize)) + true + } else if (hasArrayKeys) { + // Flush every time to support the following legacy behavior: + // If array keys are used with a cached store, get() will always miss the cache because of array equality semantics + // However, it will fall back to the underlying store which does support arrays. + true + } else { + false + } - if (purgeNeeded) { - putAllDirtyEntries() - } + if (purgeNeeded) { + putAllDirtyEntries() + } + }) } override def flush() { trace("Purging dirty entries from CachedStore.") - metrics.flushes.inc - putAllDirtyEntries() - trace("Flushing store.") - store.flush() + lock.synchronized({ + metrics.flushes.inc + putAllDirtyEntries() + store.flush() + }) trace("Flushed store.") } + /** + * The synchronization lock must be held before calling this method. + */ private def putAllDirtyEntries() { trace("Calling putAll() on dirty entries.") // write out the contents of the dirty list oldest first @@ -212,26 +218,34 @@ class CachedStore[K, V]( } override def putAll(entries: java.util.List[Entry[K, V]]) { - val iter = entries.iterator - while (iter.hasNext) { - val curr = iter.next - put(curr.getKey, curr.getValue) - } + lock.synchronized({ + val iter = entries.iterator + while (iter.hasNext) { + val curr = iter.next + put(curr.getKey, curr.getValue) + } + }) } override def delete(key: K) { - metrics.deletes.inc - put(key, null.asInstanceOf[V]) + lock.synchronized({ + metrics.deletes.inc + put(key, null.asInstanceOf[V]) + }) } override def close() { - trace("Closing.") - flush() - store.close() + lock.synchronized({ + trace("Closing.") + flush() + store.close() + }) } override def deleteAll(keys: java.util.List[K]) = { - KeyValueStore.Extension.deleteAll(this, keys) + lock.synchronized({ + KeyValueStore.Extension.deleteAll(this, keys) + }) } private def checkKeyIsArray(key: K) { @@ -243,30 +257,32 @@ class CachedStore[K, V]( } override def getAll(keys: java.util.List[K]): java.util.Map[K, V] = { - 
metrics.gets.inc(keys.size) - val returnValue = new java.util.HashMap[K, V](keys.size) - val misses = new java.util.ArrayList[K] - val keysIterator = keys.iterator - while (keysIterator.hasNext) { - val key = keysIterator.next - val cached = cache.get(key) - if (cached != null) { - metrics.cacheHits.inc - returnValue.put(key, cached.value) - } else { - misses.add(key) + lock.synchronized({ + metrics.gets.inc(keys.size) + val returnValue = new java.util.HashMap[K, V](keys.size) + val misses = new java.util.ArrayList[K] + val keysIterator = keys.iterator + while (keysIterator.hasNext) { + val key = keysIterator.next + val cached = cache.get(key) + if (cached != null) { + metrics.cacheHits.inc + returnValue.put(key, cached.value) + } else { + misses.add(key) + } } - } - if (!misses.isEmpty) { - val entryIterator = store.getAll(misses).entrySet.iterator - while (entryIterator.hasNext) { - val entry = entryIterator.next - returnValue.put(entry.getKey, entry.getValue) - cache.put(entry.getKey, new CacheEntry(entry.getValue, null)) + if (!misses.isEmpty) { + val entryIterator = store.getAll(misses).entrySet.iterator + while (entryIterator.hasNext) { + val entry = entryIterator.next + returnValue.put(entry.getKey, entry.getValue) + cache.put(entry.getKey, new CacheEntry(entry.getValue, null)) + } + cacheCount = cache.size // update outside the loop since it's used for metrics and not for time-sensitive logic } - cacheCount = cache.size // update outside the loop since it's used for metrics and not for time-sensitive logic - } - returnValue + returnValue + }) } def hasArrayKeys = containsArrayKeys http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala ---------------------------------------------------------------------- diff --git a/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala b/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala index 
96eb5fa..e16bdc0 100644 --- a/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala +++ b/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala @@ -127,21 +127,6 @@ class TestCachedStore { } assertFalse(iter.hasNext) - // test iterator remove - iter = store.all() - iter.next() - iter.remove() - - assertNull(kv.get(keys.get(0))) - assertNull(store.get(keys.get(0))) - - iter = store.range(keys.get(1), keys.get(2)) - iter.next() - iter.remove() - - assertFalse(iter.hasNext) - assertNull(kv.get(keys.get(1))) - assertNull(store.get(keys.get(1))) } @Test http://git-wip-us.apache.org/repos/asf/samza/blob/e5f31c57/samza-test/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala ---------------------------------------------------------------------- diff --git a/samza-test/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala b/samza-test/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala index fd4e762..d7d23ec 100644 --- a/samza-test/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala +++ b/samza-test/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala @@ -20,7 +20,9 @@ package org.apache.samza.storage.kv import java.io.File -import java.util.{Arrays, Random} +import java.util.Arrays +import java.util.Random +import java.util.concurrent.CountDownLatch import org.apache.samza.config.MapConfig import org.apache.samza.serializers.Serde @@ -29,7 +31,9 @@ import org.junit.Assert._ import org.junit.runner.RunWith import org.junit.runners.Parameterized import org.junit.runners.Parameterized.Parameters -import org.junit.{After, Before, Test} +import org.junit.After +import org.junit.Before +import org.junit.Test import org.scalatest.Assertions.intercept import scala.collection.JavaConversions._ @@ -136,8 +140,9 @@ class TestKeyValueStores(typeOfStore: String, storeConfig: String) { @Test def putAndGet() { - store.put(b("k"), b("v")) - assertArrayEquals(b("v"), 
store.get(b("k"))) + val k = b("k") + store.put(k, b("v")) + assertArrayEquals(b("v"), store.get(k)) } @Test @@ -379,6 +384,181 @@ class TestKeyValueStores(typeOfStore: String, storeConfig: String) { }) } + @Test + def testParallelReadWriteSameKey(): Unit = { + // Make test deterministic by seeding the random number generator. + val key = b("key") + val val1 = "val1" + val val2 = "val2" + + val runner1 = new Thread(new Runnable { + override def run(): Unit = { + store.put(key, b(val1)) + } + }) + + val runner2 = new Thread(new Runnable { + override def run(): Unit = { + while(!val1.equals({store.get(key) match { + case null => "" + case _ => { new String(store.get(key), "UTF-8") } + }})) {} + store.put(key, b(val2)) + } + }) + + runner2.start() + runner1.start() + + runner2.join(1000) + runner1.join(1000) + + assertEquals("val2", new String(store.get(key), "UTF-8")) + + store.delete(key) + store.flush() + } + + @Test + def testParallelReadWriteDiffKeys(): Unit = { + // Make test deterministic by seeding the random number generator. 
+ val key1 = b("key1") + val key2 = b("key2") + val val1 = "val1" + val val2 = "val2" + + val runner1 = new Thread(new Runnable { + override def run(): Unit = { + store.put(key1, b(val1)) + } + }) + + val runner2 = new Thread(new Runnable { + override def run(): Unit = { + while(!val1.equals({store.get(key1) match { + case null => "" + case _ => { new String(store.get(key1), "UTF-8") } + }})) {} + store.delete(key1) + } + }) + + val runner3 = new Thread(new Runnable { + override def run(): Unit = { + store.put(key2, b(val2)) + } + }) + + val runner4 = new Thread(new Runnable { + override def run(): Unit = { + while(!val2.equals({store.get(key2) match { + case null => "" + case _ => { new String(store.get(key2), "UTF-8") } + }})) {} + store.delete(key2) + } + }) + + runner2.start() + runner1.start() + runner3.start() + runner4.start() + + runner2.join(1000) + runner1.join(1000) + runner3.join(1000) + runner4.join(1000) + + assertNull(store.get(key1)) + assertNull(store.get(key2)) + + store.flush() + } + + @Test + def testParallelIteratorAndWrite(): Unit = { + // Latches order the threads: an iterator opened before the concurrent put/delete/flush is expected to still see the original values, while one opened afterwards is expected to see the updated value. 
+ val key1 = b("key1") + val key2 = b("key2") + val val1 = "val1" + val val2 = "val2" + @volatile var throwable: Throwable = null + + store.put(key1, b(val1)) + store.put(key2, b(val2)) + + val runner1StartLatch = new CountDownLatch(1) + val runner2StartLatch = new CountDownLatch(1) + + val runner1 = new Thread(new Runnable { + override def run(): Unit = { + runner1StartLatch.await() + store.put(key1, b("val1-2")) + store.delete(key2) + store.flush() + runner2StartLatch.countDown() + } + }) + + val runner2 = new Thread(new Runnable { + override def run(): Unit = { + runner2StartLatch.await() + val iter = store.all() //snapshot after change + try { + while (iter.hasNext) { + val e = iter.next() + if ("key1".equals(new String(e.getKey, "UTF-8"))) { + assertEquals("val1-2", new String(e.getValue, "UTF-8")) + } + System.out.println(String.format("iterator1: key: %s, value: %s", new String(e.getKey, "UTF-8"), new String(e.getValue, "UTF-8"))) + } + iter.close() + } catch { + case t: Throwable => throwable = t + } + } + }) + + val runner3 = new Thread(new Runnable { + override def run(): Unit = { + val iter = store.all() //snapshot + runner1StartLatch.countDown() + try { + while (iter.hasNext) { + val e = iter.next() + val key = new String(e.getKey, "UTF-8") + val value = new String(e.getValue, "UTF-8") + if (key.equals("key1")) { + assertEquals(val1, value) + } + else if (key.equals("key2") && !val2.equals(value)) { + assertEquals(val2, value) + } + else if (!key.equals("key1") && !key.equals("key2")) { + throw new Exception("unknow key " + new String(e.getKey, "UTF-8") + ", value: " + new String(e.getValue, "UTF-8")) + } + System.out.println(String.format("iterator2: key: %s, value: %s", new String(e.getKey, "UTF-8"), new String(e.getValue, "UTF-8"))) + } + iter.close() + } catch { + case t: Throwable => throwable = t + } + } + }) + + runner2.start() + runner3.start() + runner1.start() + + runner2.join() + runner1.join() + runner3.join() + + if(throwable != null) throw 
throwable + + store.flush() + } + def checkRange(vals: IndexedSeq[String], iter: KeyValueIterator[Array[Byte], Array[Byte]]) { for (v <- vals) { assertTrue(iter.hasNext) @@ -417,5 +597,6 @@ object TestKeyValueStores { Array("rocksdb","cache"), Array("rocksdb","serde"), Array("rocksdb","cache-and-serde"), - Array("rocksdb","none")) + Array("rocksdb","none") + ) }
