Updated Branches:
  refs/heads/master c988f4f40 -> dca9555ed

SAMZA-94; protect against misbehaving serdes in key value store, and disallow 
null keys and values in samza-kv.


Project: http://git-wip-us.apache.org/repos/asf/incubator-samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-samza/commit/dca9555e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-samza/tree/dca9555e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-samza/diff/dca9555e

Branch: refs/heads/master
Commit: dca9555ed4afebb16d10377b41d8e7d727a053b9
Parents: c988f4f
Author: Chris Riccomini <[email protected]>
Authored: Mon Dec 9 12:21:26 2013 -0800
Committer: Chris Riccomini <[email protected]>
Committed: Mon Dec 9 12:21:26 2013 -0800

----------------------------------------------------------------------
 build.gradle                                    |  3 +-
 .../apache/samza/storage/kv/KeyValueStore.java  |  9 ++++-
 .../storage/kv/SerializedKeyValueStore.scala    | 41 ++++++++++++++------
 .../samza/storage/kv/TestKeyValueStores.scala   | 30 ++++++++------
 .../scala/org/apache/samza/util/TestUtil.scala  | 14 +++++++
 5 files changed, 71 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-samza/blob/dca9555e/build.gradle
----------------------------------------------------------------------
diff --git a/build.gradle b/build.gradle
index 556a0a3..49a8459 100644
--- a/build.gradle
+++ b/build.gradle
@@ -169,6 +169,7 @@ project(":samza-kv_$scalaVersion") {
     compile "org.clapper:grizzled-slf4j_$scalaVersion:$grizzledVersion"
     compile "org.fusesource.leveldbjni:leveldbjni-all:$leveldbVersion"
     testCompile "junit:junit:$junitVersion"
+    testCompile project(":samza-test_$scalaVersion")
   }
 }
 
@@ -182,8 +183,8 @@ project(":samza-test_$scalaVersion") {
     compile "org.clapper:grizzled-slf4j_$scalaVersion:$grizzledVersion"
     compile "net.sf.jopt-simple:jopt-simple:$joptSimpleVersion"
     compile "javax.mail:mail:1.4"
+    compile "junit:junit:$junitVersion"
     compile files("../samza-kafka/lib/kafka_$scalaVersion-" + kafkaVersion + 
".jar")
-    testCompile "junit:junit:$junitVersion"
     testCompile files("../samza-kafka/lib/kafka_$scalaVersion-" + kafkaVersion 
+ "-test.jar")
     testCompile "com.101tec:zkclient:$zkClientVersion"
     testCompile project(":samza-core_$scalaVersion")

http://git-wip-us.apache.org/repos/asf/incubator-samza/blob/dca9555e/samza-kv/src/main/java/org/apache/samza/storage/kv/KeyValueStore.java
----------------------------------------------------------------------
diff --git 
a/samza-kv/src/main/java/org/apache/samza/storage/kv/KeyValueStore.java 
b/samza-kv/src/main/java/org/apache/samza/storage/kv/KeyValueStore.java
index 49089df..bdaa234 100644
--- a/samza-kv/src/main/java/org/apache/samza/storage/kv/KeyValueStore.java
+++ b/samza-kv/src/main/java/org/apache/samza/storage/kv/KeyValueStore.java
@@ -22,8 +22,8 @@ package org.apache.samza.storage.kv;
 import java.util.List;
 
 /**
- * A key-value store that supports put/get/delete and range queries
- *
+ * A key-value store that supports put/get/delete and range queries.
+ * 
  * @param <K> The key type
  * @param <V> The value type
  */
@@ -33,6 +33,7 @@ public interface KeyValueStore<K, V> {
    * Get the value corresponding to this key
    * @param key The key to fetch
    * @return The value or null if no value is found.
+   * @throws NullPointerException If null is used for key.
    */
   public V get(K key);
   
@@ -40,17 +41,20 @@ public interface KeyValueStore<K, V> {
    * Update the value associated with this key
    * @param key They key to associate the value to
    * @param value The value
+   * @throws NullPointerException If null is used for key or value.
    */
   public void put(K key, V value);
   
   /**
    * Update all the given key/value pairs
+   * @throws NullPointerException If null is used for any key or value.
    */
   public void putAll(List<Entry<K,V>> entries);
   
   /**
    * Delete the value from the store (if there is one)
    * @param key The key
+   * @throws NullPointerException If null is used for key.
    */
   public void delete(K key);
   
@@ -59,6 +63,7 @@ public interface KeyValueStore<K, V> {
    * @param from The first key that could be in the range
    * @param to The last key that could be in the range
    * @return The iterator for this range.
+   * @throws NullPointerException If null is used for from or to.
    */
   public KeyValueIterator<K,V> range(K from, K to);
   

http://git-wip-us.apache.org/repos/asf/incubator-samza/blob/dca9555e/samza-kv/src/main/scala/org/apache/samza/storage/kv/SerializedKeyValueStore.scala
----------------------------------------------------------------------
diff --git 
a/samza-kv/src/main/scala/org/apache/samza/storage/kv/SerializedKeyValueStore.scala
 
b/samza-kv/src/main/scala/org/apache/samza/storage/kv/SerializedKeyValueStore.scala
index 53a5cbe..c3bf4dc 100644
--- 
a/samza-kv/src/main/scala/org/apache/samza/storage/kv/SerializedKeyValueStore.scala
+++ 
b/samza-kv/src/main/scala/org/apache/samza/storage/kv/SerializedKeyValueStore.scala
@@ -33,7 +33,7 @@ class SerializedKeyValueStore[K, V](
   metrics: SerializedKeyValueStoreMetrics = new 
SerializedKeyValueStoreMetrics) extends KeyValueStore[K, V] with Logging {
 
   def get(key: K): V = {
-    val keyBytes = keySerde.toBytes(key)
+    val keyBytes = bytesNotNull(key, keySerde)
     val found = store.get(keyBytes)
     metrics.gets.inc
     metrics.bytesSerialized.inc(keyBytes.size)
@@ -47,8 +47,8 @@ class SerializedKeyValueStore[K, V](
 
   def put(key: K, value: V) {
     metrics.puts.inc
-    val keyBytes = keySerde.toBytes(key)
-    val valBytes = msgSerde.toBytes(value)
+    val keyBytes = bytesNotNull(key, keySerde)
+    val valBytes = bytesNotNull(value, msgSerde)
     metrics.bytesSerialized.inc(keyBytes.size + valBytes.size)
     store.put(keyBytes, valBytes)
   }
@@ -59,8 +59,8 @@ class SerializedKeyValueStore[K, V](
     var bytesSerialized = 0L
     while (iter.hasNext) {
       val curr = iter.next
-      val keyBytes = keySerde.toBytes(curr.getKey)
-      val valBytes = msgSerde.toBytes(curr.getValue)
+      val keyBytes = bytesNotNull(curr.getKey, keySerde)
+      val valBytes = bytesNotNull(curr.getValue, msgSerde)
       bytesSerialized += keyBytes.size
       if (valBytes != null) {
         bytesSerialized += valBytes.size
@@ -74,15 +74,15 @@ class SerializedKeyValueStore[K, V](
 
   def delete(key: K) {
     metrics.deletes.inc
-    val keyBytes = keySerde.toBytes(key)
+    val keyBytes = bytesNotNull(key, keySerde)
     metrics.bytesSerialized.inc(keyBytes.size)
     store.delete(keyBytes)
   }
 
   def range(from: K, to: K): KeyValueIterator[K, V] = {
     metrics.ranges.inc
-    val fromBytes = keySerde.toBytes(from)
-    val toBytes = keySerde.toBytes(to)
+    val fromBytes = bytesNotNull(from, keySerde)
+    val toBytes = bytesNotNull(to, keySerde)
     metrics.bytesSerialized.inc(fromBytes.size + toBytes.size)
     new DeserializingIterator(store.range(fromBytes, toBytes))
   }
@@ -100,11 +100,20 @@ class SerializedKeyValueStore[K, V](
       val nxt = iter.next()
       val keyBytes = nxt.getKey
       val valBytes = nxt.getValue
-      metrics.bytesDeserialized.inc(keyBytes.size)
-      if (valBytes != null) {
+      val key = if (keyBytes != null) {
+        metrics.bytesDeserialized.inc(keyBytes.size)
+        keySerde.fromBytes(keyBytes).asInstanceOf[K]
+      } else {
+        warn("Got a null key while iterating over a store. This is highly 
unexpected, since null in key and value is disallowed for key value stores.")
+        null.asInstanceOf[K]
+      }
+      val value = if (valBytes != null) {
         metrics.bytesDeserialized.inc(valBytes.size)
+        msgSerde.fromBytes(valBytes)
+      } else {
+        null.asInstanceOf[V]
       }
-      new Entry(keySerde.fromBytes(keyBytes).asInstanceOf[K], 
msgSerde.fromBytes(valBytes).asInstanceOf[V])
+      new Entry(key, value)
     }
   }
 
@@ -121,4 +130,14 @@ class SerializedKeyValueStore[K, V](
 
     store.close
   }
+
+  /**
+   * Null is not allowed for keys and values because some change log systems
+   * (Kafka) model deletes as null.
+   */
+  private def bytesNotNull[T](t: T, serde: Serde[T]): Array[Byte] = if (t != 
null) {
+    serde.toBytes(t)
+  } else {
+    throw new NullPointerException("Null is not a valid key or value.")
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-samza/blob/dca9555e/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala
----------------------------------------------------------------------
diff --git 
a/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala 
b/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala
index 2e5f6a3..e8db3f2 100644
--- 
a/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala
+++ 
b/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala
@@ -22,9 +22,7 @@ package org.apache.samza.storage.kv
 import java.io.File
 import java.util.Arrays
 import java.util.Random
-
 import scala.collection.JavaConversions._
-
 import org.apache.samza.serializers.IntegerSerde
 import org.iq80.leveldb.Options
 import org.junit.After
@@ -34,6 +32,8 @@ import org.junit.Test
 import org.junit.runner.RunWith
 import org.junit.runners.Parameterized
 import org.junit.runners.Parameterized.Parameters
+import org.apache.samza.serializers.StringSerde
+import org.apache.samza.util.TestUtil._
 
 @RunWith(value = classOf[Parameterized])
 class TestKeyValueStores(cache: Boolean) {
@@ -82,6 +82,22 @@ class TestKeyValueStores(cache: Boolean) {
   }
 
   @Test
+  def testNulls() {
+    val stringSerde = new StringSerde("UTF-8")
+    val serializedStore = new SerializedKeyValueStore(store, stringSerde, 
stringSerde)
+    val expectedNPEMessage = Some("Null is not a valid key or value.")
+
+    expect(classOf[NullPointerException], expectedNPEMessage) { 
serializedStore.get(null) }
+    expect(classOf[NullPointerException], expectedNPEMessage) { 
serializedStore.delete(null) }
+    expect(classOf[NullPointerException], expectedNPEMessage) { 
serializedStore.put(null, "") }
+    expect(classOf[NullPointerException], expectedNPEMessage) { 
serializedStore.put("", null) }
+    expect(classOf[NullPointerException], expectedNPEMessage) { 
serializedStore.putAll(List(new Entry("", ""), new Entry[String, String]("", 
null))) }
+    expect(classOf[NullPointerException], expectedNPEMessage) { 
serializedStore.putAll(List(new Entry[String, String](null, ""))) }
+    expect(classOf[NullPointerException], expectedNPEMessage) { 
serializedStore.range("", null) }
+    expect(classOf[NullPointerException], expectedNPEMessage) { 
serializedStore.range(null, "") }
+  }
+
+  @Test
   def testPutAll() {
     // Use CacheSize - 1 so we fully fill the cache, but don't write any data 
     // out. Our check (below) uses == for cached entries, and using 
@@ -139,16 +155,6 @@ class TestKeyValueStores(cache: Boolean) {
     vals.foreach(v => assertNull(store.get(v)))
   }
 
-  @Test
-  def testSerializedValueIsNull {
-    val serializedStore = new SerializedKeyValueStore(
-      store,
-      new IntegerSerde,
-      new IntegerSerde)
-
-    serializedStore.putAll(List(new Entry[java.lang.Integer, 
java.lang.Integer](0, null)))
-  }
-
   /**
    * This test specifically targets an issue in Scala 2.8.1's DoubleLinkedList
    * implementation. The issue is that it doesn't work. More specifically,

http://git-wip-us.apache.org/repos/asf/incubator-samza/blob/dca9555e/samza-test/src/main/scala/org/apache/samza/util/TestUtil.scala
----------------------------------------------------------------------
diff --git a/samza-test/src/main/scala/org/apache/samza/util/TestUtil.scala 
b/samza-test/src/main/scala/org/apache/samza/util/TestUtil.scala
new file mode 100644
index 0000000..69c085a
--- /dev/null
+++ b/samza-test/src/main/scala/org/apache/samza/util/TestUtil.scala
@@ -0,0 +1,14 @@
+package org.apache.samza.util
+
+import org.junit.Assert._
+
+object TestUtil {
+  def expect[T](exception: Class[T], msg: Option[String] = None)(block: => 
Unit) = try {
+    block
+  } catch {
+    case e => if (msg.isDefined) {
+      assertEquals(msg.get, e.getMessage)
+    }
+    case _ => fail("Expected an NPE.")
+  }
+}
\ No newline at end of file

Reply via email to