Repository: spark
Updated Branches:
  refs/heads/master bde1d6a61 -> dedbceec1


http://git-wip-us.apache.org/repos/asf/spark/blob/dedbceec/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java
new file mode 100644
index 0000000..aba45f5
--- /dev/null
+++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010;
+
+import java.io.Serializable;
+import java.util.*;
+
+import scala.collection.JavaConverters;
+
+import org.apache.kafka.common.TopicPartition;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class JavaConsumerStrategySuite implements Serializable {
+
+  @Test
+  public void testConsumerStrategyConstructors() {
+    final String topic1 = "topic1";
+    final Collection<String> topics = Arrays.asList(topic1);
+    final scala.collection.Iterable<String> sTopics =
+      JavaConverters.collectionAsScalaIterableConverter(topics).asScala();
+    final TopicPartition tp1 = new TopicPartition(topic1, 0);
+    final TopicPartition tp2 = new TopicPartition(topic1, 1);
+    final Collection<TopicPartition> parts = Arrays.asList(tp1, tp2);
+    final scala.collection.Iterable<TopicPartition> sParts =
+      JavaConverters.collectionAsScalaIterableConverter(parts).asScala();
+    final Map<String, Object> kafkaParams = new HashMap<String, Object>();
+    kafkaParams.put("bootstrap.servers", "not used");
+    final scala.collection.Map<String, Object> sKafkaParams =
+      JavaConverters.mapAsScalaMapConverter(kafkaParams).asScala();
+    final Map<TopicPartition, Object> offsets = new HashMap<>();
+    offsets.put(tp1, 23L);
+    final scala.collection.Map<TopicPartition, Object> sOffsets =
+      JavaConverters.mapAsScalaMapConverter(offsets).asScala();
+
+    // make sure constructors can be called from java
+    final ConsumerStrategy<String, String> sub0 =
+      Subscribe.<String, String>apply(topics, kafkaParams, offsets);
+    final ConsumerStrategy<String, String> sub1 =
+      Subscribe.<String, String>apply(sTopics, sKafkaParams, sOffsets);
+    final ConsumerStrategy<String, String> sub2 =
+      Subscribe.<String, String>apply(sTopics, sKafkaParams);
+    final ConsumerStrategy<String, String> sub3 =
+      Subscribe.<String, String>create(topics, kafkaParams, offsets);
+    final ConsumerStrategy<String, String> sub4 =
+      Subscribe.<String, String>create(topics, kafkaParams);
+
+    Assert.assertEquals(
+      sub1.executorKafkaParams().get("bootstrap.servers"),
+      sub3.executorKafkaParams().get("bootstrap.servers"));
+
+    final ConsumerStrategy<String, String> asn0 =
+      Assign.<String, String>apply(parts, kafkaParams, offsets);
+    final ConsumerStrategy<String, String> asn1 =
+      Assign.<String, String>apply(sParts, sKafkaParams, sOffsets);
+    final ConsumerStrategy<String, String> asn2 =
+      Assign.<String, String>apply(sParts, sKafkaParams);
+    final ConsumerStrategy<String, String> asn3 =
+      Assign.<String, String>create(parts, kafkaParams, offsets);
+    final ConsumerStrategy<String, String> asn4 =
+      Assign.<String, String>create(parts, kafkaParams);
+
+    Assert.assertEquals(
+      asn1.executorKafkaParams().get("bootstrap.servers"),
+      asn3.executorKafkaParams().get("bootstrap.servers"));
+  }
+
+}
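
A minimal Scala sketch of the consumer-strategy constructors the suite above exercises; the broker address, topic name, and group id are placeholders, not values from this patch:

  import org.apache.kafka.common.TopicPartition
  import org.apache.kafka.common.serialization.StringDeserializer
  import org.apache.spark.streaming.kafka010.{Assign, Subscribe}

  // Kafka params are plain String -> Object pairs, as in the test above.
  val kafkaParams = Map[String, Object](
    "bootstrap.servers" -> "localhost:9092",            // placeholder
    "key.deserializer" -> classOf[StringDeserializer],
    "value.deserializer" -> classOf[StringDeserializer],
    "group.id" -> "example-group")                      // placeholder

  // Subscribe by topic name; starting offsets fall back to auto.offset.reset.
  val byTopic = Subscribe[String, String](List("topicA"), kafkaParams)

  // Assign explicit partitions, optionally with per-partition starting offsets.
  val tp = new TopicPartition("topicA", 0)
  val byPartition = Assign[String, String](List(tp), kafkaParams, Map(tp -> 23L))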

http://git-wip-us.apache.org/repos/asf/spark/blob/dedbceec/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaDirectKafkaStreamSuite.java
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaDirectKafkaStreamSuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaDirectKafkaStreamSuite.java
new file mode 100644
index 0000000..e57ede7
--- /dev/null
+++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaDirectKafkaStreamSuite.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010;
+
+import java.io.Serializable;
+import java.util.*;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.apache.kafka.common.serialization.StringDeserializer;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.VoidFunction;
+import org.apache.spark.streaming.Durations;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaInputDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+
+public class JavaDirectKafkaStreamSuite implements Serializable {
+  private transient JavaStreamingContext ssc = null;
+  private transient KafkaTestUtils kafkaTestUtils = null;
+
+  @Before
+  public void setUp() {
+    kafkaTestUtils = new KafkaTestUtils();
+    kafkaTestUtils.setup();
+    SparkConf sparkConf = new SparkConf()
+      .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
+    ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200));
+  }
+
+  @After
+  public void tearDown() {
+    if (ssc != null) {
+      ssc.stop();
+      ssc = null;
+    }
+
+    if (kafkaTestUtils != null) {
+      kafkaTestUtils.teardown();
+      kafkaTestUtils = null;
+    }
+  }
+
+  @Test
+  public void testKafkaStream() throws InterruptedException {
+    final String topic1 = "topic1";
+    final String topic2 = "topic2";
+    // hold a reference to the current offset ranges, so it can be used downstream
+    final AtomicReference<OffsetRange[]> offsetRanges = new AtomicReference<>();
+
+    String[] topic1data = createTopicAndSendData(topic1);
+    String[] topic2data = createTopicAndSendData(topic2);
+
+    Set<String> sent = new HashSet<>();
+    sent.addAll(Arrays.asList(topic1data));
+    sent.addAll(Arrays.asList(topic2data));
+
+    Random random = new Random();
+
+    final Map<String, Object> kafkaParams = new HashMap<>();
+    kafkaParams.put("bootstrap.servers", kafkaTestUtils.brokerAddress());
+    kafkaParams.put("key.deserializer", StringDeserializer.class);
+    kafkaParams.put("value.deserializer", StringDeserializer.class);
+    kafkaParams.put("auto.offset.reset", "earliest");
+    kafkaParams.put("group.id", "java-test-consumer-" + random.nextInt() +
+      "-" + System.currentTimeMillis());
+
+    JavaInputDStream<ConsumerRecord<String, String>> istream1 = KafkaUtils.createDirectStream(
+        ssc,
+        PreferConsistent.create(),
+        Subscribe.<String, String>create(Arrays.asList(topic1), kafkaParams)
+    );
+
+    JavaDStream<String> stream1 = istream1.transform(
+      // Make sure you can get offset ranges from the rdd
+      new Function<JavaRDD<ConsumerRecord<String, String>>,
+        JavaRDD<ConsumerRecord<String, String>>>() {
+          @Override
+          public JavaRDD<ConsumerRecord<String, String>> call(
+            JavaRDD<ConsumerRecord<String, String>> rdd
+          ) {
+            OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+            offsetRanges.set(offsets);
+            Assert.assertEquals(topic1, offsets[0].topic());
+            return rdd;
+          }
+        }
+    ).map(
+        new Function<ConsumerRecord<String, String>, String>() {
+          @Override
+          public String call(ConsumerRecord<String, String> r) {
+            return r.value();
+          }
+        }
+    );
+
+    final Map<String, Object> kafkaParams2 = new HashMap<>(kafkaParams);
+    kafkaParams2.put("group.id", "java-test-consumer-" + random.nextInt() +
+      "-" + System.currentTimeMillis());
+
+    JavaInputDStream<ConsumerRecord<String, String>> istream2 = KafkaUtils.createDirectStream(
+        ssc,
+        PreferConsistent.create(),
+        Subscribe.<String, String>create(Arrays.asList(topic2), kafkaParams2)
+    );
+
+    JavaDStream<String> stream2 = istream2.transform(
+      // Make sure you can get offset ranges from the rdd
+      new Function<JavaRDD<ConsumerRecord<String, String>>,
+        JavaRDD<ConsumerRecord<String, String>>>() {
+          @Override
+          public JavaRDD<ConsumerRecord<String, String>> call(
+            JavaRDD<ConsumerRecord<String, String>> rdd
+          ) {
+            OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+            offsetRanges.set(offsets);
+            Assert.assertEquals(topic2, offsets[0].topic());
+            return rdd;
+          }
+        }
+    ).map(
+        new Function<ConsumerRecord<String, String>, String>() {
+          @Override
+          public String call(ConsumerRecord<String, String> r) {
+            return r.value();
+          }
+        }
+    );
+
+    JavaDStream<String> unifiedStream = stream1.union(stream2);
+
+    final Set<String> result = Collections.synchronizedSet(new HashSet<String>());
+    unifiedStream.foreachRDD(new VoidFunction<JavaRDD<String>>() {
+          @Override
+          public void call(JavaRDD<String> rdd) {
+            result.addAll(rdd.collect());
+          }
+        }
+    );
+    ssc.start();
+    long startTime = System.currentTimeMillis();
+    boolean matches = false;
+    while (!matches && System.currentTimeMillis() - startTime < 20000) {
+      matches = sent.size() == result.size();
+      Thread.sleep(50);
+    }
+    Assert.assertEquals(sent, result);
+    ssc.stop();
+  }
+
+  private  String[] createTopicAndSendData(String topic) {
+    String[] data = { topic + "-1", topic + "-2", topic + "-3"};
+    kafkaTestUtils.createTopic(topic);
+    kafkaTestUtils.sendMessages(topic, data);
+    return data;
+  }
+}
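
A condensed Scala sketch of what this Java suite drives: create a direct stream, then read the per-batch offset ranges off each RDD. The broker address, topic, and group id are placeholders:

  import org.apache.kafka.common.serialization.StringDeserializer
  import org.apache.spark.SparkConf
  import org.apache.spark.streaming.{Milliseconds, StreamingContext}
  import org.apache.spark.streaming.kafka010.{HasOffsetRanges, KafkaUtils, PreferConsistent, Subscribe}

  val ssc = new StreamingContext(
    new SparkConf().setMaster("local[4]").setAppName("direct-stream-sketch"), Milliseconds(200))

  val kafkaParams = Map[String, Object](
    "bootstrap.servers" -> "localhost:9092",            // placeholder
    "key.deserializer" -> classOf[StringDeserializer],
    "value.deserializer" -> classOf[StringDeserializer],
    "group.id" -> "example-group",                      // placeholder
    "auto.offset.reset" -> "earliest")

  val stream = KafkaUtils.createDirectStream[String, String](
    ssc, PreferConsistent, Subscribe[String, String](List("topicA"), kafkaParams))

  stream.foreachRDD { rdd =>
    // Each batch RDD carries the Kafka offset ranges it was built from.
    val ranges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
    ranges.foreach(r => println(s"${r.topic} ${r.partition} ${r.fromOffset} ${r.untilOffset}"))
  }

  ssc.start()
  ssc.awaitTermination()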

http://git-wip-us.apache.org/repos/asf/spark/blob/dedbceec/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java
new file mode 100644
index 0000000..548ba13
--- /dev/null
+++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.kafka.common.serialization.StringDeserializer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+
+public class JavaKafkaRDDSuite implements Serializable {
+  private transient JavaSparkContext sc = null;
+  private transient KafkaTestUtils kafkaTestUtils = null;
+
+  @Before
+  public void setUp() {
+    kafkaTestUtils = new KafkaTestUtils();
+    kafkaTestUtils.setup();
+    SparkConf sparkConf = new SparkConf()
+      .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
+    sc = new JavaSparkContext(sparkConf);
+  }
+
+  @After
+  public void tearDown() {
+    if (sc != null) {
+      sc.stop();
+      sc = null;
+    }
+
+    if (kafkaTestUtils != null) {
+      kafkaTestUtils.teardown();
+      kafkaTestUtils = null;
+    }
+  }
+
+  @Test
+  public void testKafkaRDD() throws InterruptedException {
+    String topic1 = "topic1";
+    String topic2 = "topic2";
+
+    createTopicAndSendData(topic1);
+    createTopicAndSendData(topic2);
+
+    Map<String, Object> kafkaParams = new HashMap<>();
+    kafkaParams.put("bootstrap.servers", kafkaTestUtils.brokerAddress());
+    kafkaParams.put("key.deserializer", StringDeserializer.class);
+    kafkaParams.put("value.deserializer", StringDeserializer.class);
+
+    OffsetRange[] offsetRanges = {
+      OffsetRange.create(topic1, 0, 0, 1),
+      OffsetRange.create(topic2, 0, 0, 1)
+    };
+
+    Map<TopicPartition, String> leaders = new HashMap<>();
+    String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":");
+    String broker = hostAndPort[0];
+    leaders.put(offsetRanges[0].topicPartition(), broker);
+    leaders.put(offsetRanges[1].topicPartition(), broker);
+
+    Function<ConsumerRecord<String, String>, String> handler =
+      new Function<ConsumerRecord<String, String>, String>() {
+        @Override
+        public String call(ConsumerRecord<String, String> r) {
+          return r.value();
+        }
+      };
+
+    JavaRDD<String> rdd1 = KafkaUtils.<String, String>createRDD(
+        sc,
+        kafkaParams,
+        offsetRanges,
+        PreferFixed.create(leaders)
+    ).map(handler);
+
+    JavaRDD<String> rdd2 = KafkaUtils.<String, String>createRDD(
+        sc,
+        kafkaParams,
+        offsetRanges,
+        PreferConsistent.create()
+    ).map(handler);
+
+    // just making sure the java user apis work; the scala tests handle logic corner cases
+    long count1 = rdd1.count();
+    long count2 = rdd2.count();
+    Assert.assertTrue(count1 > 0);
+    Assert.assertEquals(count1, count2);
+  }
+
+  private  String[] createTopicAndSendData(String topic) {
+    String[] data = { topic + "-1", topic + "-2", topic + "-3"};
+    kafkaTestUtils.createTopic(topic);
+    kafkaTestUtils.sendMessages(topic, data);
+    return data;
+  }
+}
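
A minimal Scala sketch of the batch API the suite above exercises, createRDD over explicit offset ranges; the broker address, topic, group id, and offsets are placeholders:

  import java.{util => ju}
  import org.apache.kafka.common.serialization.StringDeserializer
  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.streaming.kafka010.{KafkaUtils, OffsetRange, PreferConsistent}

  val sc = new SparkContext(
    new SparkConf().setMaster("local[4]").setAppName("kafka-rdd-sketch"))

  val kafkaParams = new ju.HashMap[String, Object]()
  kafkaParams.put("bootstrap.servers", "localhost:9092")            // placeholder
  kafkaParams.put("key.deserializer", classOf[StringDeserializer])
  kafkaParams.put("value.deserializer", classOf[StringDeserializer])
  kafkaParams.put("group.id", "example-group")                      // placeholder

  // Read offsets [0, 100) of partition 0 of "topicA" as a batch RDD of values.
  val ranges = Array(OffsetRange("topicA", 0, 0, 100))
  val values = KafkaUtils.createRDD[String, String](sc, kafkaParams, ranges, PreferConsistent)
    .map(_.value)
  println(values.count)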

http://git-wip-us.apache.org/repos/asf/spark/blob/dedbceec/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaLocationStrategySuite.java
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaLocationStrategySuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaLocationStrategySuite.java
new file mode 100644
index 0000000..7873c09
--- /dev/null
+++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaLocationStrategySuite.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010;
+
+import java.io.Serializable;
+import java.util.*;
+
+import scala.collection.JavaConverters;
+
+import org.apache.kafka.common.TopicPartition;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class JavaLocationStrategySuite implements Serializable {
+
+  @Test
+  public void testLocationStrategyConstructors() {
+    final String topic1 = "topic1";
+    final TopicPartition tp1 = new TopicPartition(topic1, 0);
+    final TopicPartition tp2 = new TopicPartition(topic1, 1);
+    final Map<TopicPartition, String> hosts = new HashMap<>();
+    hosts.put(tp1, "node1");
+    hosts.put(tp2, "node2");
+    final scala.collection.Map<TopicPartition, String> sHosts =
+      JavaConverters.mapAsScalaMapConverter(hosts).asScala();
+
+    // make sure constructors can be called from java
+    final LocationStrategy c1 = PreferConsistent.create();
+    final LocationStrategy c2 = PreferConsistent$.MODULE$;
+    Assert.assertEquals(c1, c2);
+
+    final LocationStrategy c3 = PreferBrokers.create();
+    final LocationStrategy c4 = PreferBrokers$.MODULE$;
+    Assert.assertEquals(c3, c4);
+
+    final LocationStrategy c5 = PreferFixed.create(hosts);
+    final LocationStrategy c6 = PreferFixed.apply(sHosts);
+    Assert.assertEquals(c5, c6);
+
+  }
+
+}
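
For reference, a short Scala sketch of the three location strategies the suite above constructs; the host names and topic are placeholders:

  import org.apache.kafka.common.TopicPartition
  import org.apache.spark.streaming.kafka010.{PreferBrokers, PreferConsistent, PreferFixed}

  // Spread partitions evenly across available executors (the common default).
  val consistent = PreferConsistent

  // Prefer executors that are co-located with the Kafka brokers.
  val brokers = PreferBrokers

  // Pin specific partitions to specific hosts.
  val fixed = PreferFixed(Map(new TopicPartition("topicA", 0) -> "node1"))   // placeholders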

http://git-wip-us.apache.org/repos/asf/spark/blob/dedbceec/external/kafka-0-10/src/test/resources/log4j.properties
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/test/resources/log4j.properties b/external/kafka-0-10/src/test/resources/log4j.properties
new file mode 100644
index 0000000..75e3b53
--- /dev/null
+++ b/external/kafka-0-10/src/test/resources/log4j.properties
@@ -0,0 +1,28 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file target/unit-tests.log
+log4j.rootCategory=INFO, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=true
+log4j.appender.file.file=target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.spark-project.jetty=WARN
+

http://git-wip-us.apache.org/repos/asf/spark/blob/dedbceec/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
new file mode 100644
index 0000000..776d11a
--- /dev/null
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
@@ -0,0 +1,612 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.io.File
+import java.util.{ Arrays, HashMap => JHashMap, Map => JMap }
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import scala.collection.JavaConverters._
+import scala.concurrent.duration._
+import scala.language.postfixOps
+import scala.util.Random
+
+import org.apache.kafka.clients.consumer._
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.serialization.StringDeserializer
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.scheduler._
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
+import org.apache.spark.util.Utils
+
+class DirectKafkaStreamSuite
+  extends SparkFunSuite
+  with BeforeAndAfter
+  with BeforeAndAfterAll
+  with Eventually
+  with Logging {
+  val sparkConf = new SparkConf()
+    .setMaster("local[4]")
+    .setAppName(this.getClass.getSimpleName)
+
+  private var sc: SparkContext = _
+  private var ssc: StreamingContext = _
+  private var testDir: File = _
+
+  private var kafkaTestUtils: KafkaTestUtils = _
+
+  override def beforeAll {
+    kafkaTestUtils = new KafkaTestUtils
+    kafkaTestUtils.setup()
+  }
+
+  override def afterAll {
+    if (kafkaTestUtils != null) {
+      kafkaTestUtils.teardown()
+      kafkaTestUtils = null
+    }
+  }
+
+  after {
+    if (ssc != null) {
+      ssc.stop()
+      sc = null
+    }
+    if (sc != null) {
+      sc.stop()
+    }
+    if (testDir != null) {
+      Utils.deleteRecursively(testDir)
+    }
+  }
+
+  def getKafkaParams(extra: (String, Object)*): JHashMap[String, Object] = {
+    val kp = new JHashMap[String, Object]()
+    kp.put("bootstrap.servers", kafkaTestUtils.brokerAddress)
+    kp.put("key.deserializer", classOf[StringDeserializer])
+    kp.put("value.deserializer", classOf[StringDeserializer])
+    kp.put("group.id", 
s"test-consumer-${Random.nextInt}-${System.currentTimeMillis}")
+    extra.foreach(e => kp.put(e._1, e._2))
+    kp
+  }
+
+  val preferredHosts = PreferConsistent
+
+  test("basic stream receiving with multiple topics and smallest starting 
offset") {
+    val topics = List("basic1", "basic2", "basic3")
+    val data = Map("a" -> 7, "b" -> 9)
+    topics.foreach { t =>
+      kafkaTestUtils.createTopic(t)
+      kafkaTestUtils.sendMessages(t, data)
+    }
+    val totalSent = data.values.sum * topics.size
+    val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
+
+    ssc = new StreamingContext(sparkConf, Milliseconds(200))
+    val stream = withClue("Error creating direct stream") {
+      KafkaUtils.createDirectStream[String, String](
+        ssc, preferredHosts, Subscribe[String, String](topics, kafkaParams.asScala))
+    }
+    val allReceived = new ConcurrentLinkedQueue[(String, String)]()
+
+    // hold a reference to the current offset ranges, so it can be used downstream
+    var offsetRanges = Array[OffsetRange]()
+    val tf = stream.transform { rdd =>
+      // Get the offset ranges in the RDD
+      offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+      rdd.map(r => (r.key, r.value))
+    }
+
+    tf.foreachRDD { rdd =>
+      for (o <- offsetRanges) {
+        logInfo(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}")
+      }
+      val collected = rdd.mapPartitionsWithIndex { (i, iter) =>
+      // For each partition, get size of the range in the partition,
+      // and the number of items in the partition
+        val off = offsetRanges(i)
+        val all = iter.toSeq
+        val partSize = all.size
+        val rangeSize = off.untilOffset - off.fromOffset
+        Iterator((partSize, rangeSize))
+      }.collect
+
+      // Verify whether number of elements in each partition
+      // matches with the corresponding offset range
+      collected.foreach { case (partSize, rangeSize) =>
+        assert(partSize === rangeSize, "offset ranges are wrong")
+      }
+    }
+
+    stream.foreachRDD { rdd =>
+      allReceived.addAll(Arrays.asList(rdd.map(r => (r.key, r.value)).collect(): _*))
+    }
+    ssc.start()
+    eventually(timeout(20000.milliseconds), interval(200.milliseconds)) {
+      assert(allReceived.size === totalSent,
+        "didn't get expected number of messages, messages:\n" +
+          allReceived.asScala.mkString("\n"))
+    }
+    ssc.stop()
+  }
+
+  test("receiving from largest starting offset") {
+    val topic = "latest"
+    val topicPartition = new TopicPartition(topic, 0)
+    val data = Map("a" -> 10)
+    kafkaTestUtils.createTopic(topic)
+    val kafkaParams = getKafkaParams("auto.offset.reset" -> "latest")
+    val kc = new KafkaConsumer(kafkaParams)
+    kc.assign(Arrays.asList(topicPartition))
+    def getLatestOffset(): Long = {
+      kc.seekToEnd(Arrays.asList(topicPartition))
+      kc.position(topicPartition)
+    }
+
+    // Send some initial messages before starting context
+    kafkaTestUtils.sendMessages(topic, data)
+    eventually(timeout(10 seconds), interval(20 milliseconds)) {
+      assert(getLatestOffset() > 3)
+    }
+    val offsetBeforeStart = getLatestOffset()
+    kc.close()
+
+    // Setup context and kafka stream with largest offset
+    ssc = new StreamingContext(sparkConf, Milliseconds(200))
+    val stream = withClue("Error creating direct stream") {
+      val s = new DirectKafkaInputDStream[String, String](
+        ssc, preferredHosts, Subscribe[String, String](List(topic), kafkaParams.asScala))
+      s.consumer.poll(0)
+      assert(
+        s.consumer.position(topicPartition) >= offsetBeforeStart,
+        "Start offset not from latest"
+      )
+      s
+    }
+
+    val collectedData = new ConcurrentLinkedQueue[String]()
+    stream.map { _.value }.foreachRDD { rdd =>
+      collectedData.addAll(Arrays.asList(rdd.collect(): _*))
+    }
+    ssc.start()
+    val newData = Map("b" -> 10)
+    kafkaTestUtils.sendMessages(topic, newData)
+    eventually(timeout(10 seconds), interval(50 milliseconds)) {
+      collectedData.contains("b")
+    }
+    assert(!collectedData.contains("a"))
+  }
+
+
+  test("creating stream by offset") {
+    val topic = "offset"
+    val topicPartition = new TopicPartition(topic, 0)
+    val data = Map("a" -> 10)
+    kafkaTestUtils.createTopic(topic)
+    val kafkaParams = getKafkaParams("auto.offset.reset" -> "latest")
+    val kc = new KafkaConsumer(kafkaParams)
+    kc.assign(Arrays.asList(topicPartition))
+    def getLatestOffset(): Long = {
+      kc.seekToEnd(Arrays.asList(topicPartition))
+      kc.position(topicPartition)
+    }
+
+    // Send some initial messages before starting context
+    kafkaTestUtils.sendMessages(topic, data)
+    eventually(timeout(10 seconds), interval(20 milliseconds)) {
+      assert(getLatestOffset() >= 10)
+    }
+    val offsetBeforeStart = getLatestOffset()
+    kc.close()
+
+    // Setup context and kafka stream with largest offset
+    ssc = new StreamingContext(sparkConf, Milliseconds(200))
+    val stream = withClue("Error creating direct stream") {
+      val s = new DirectKafkaInputDStream[String, String](ssc, preferredHosts,
+        Assign[String, String](
+          List(topicPartition),
+          kafkaParams.asScala,
+          Map(topicPartition -> 11L)))
+      s.consumer.poll(0)
+      assert(
+        s.consumer.position(topicPartition) >= offsetBeforeStart,
+        "Start offset not from latest"
+      )
+      s
+    }
+
+    val collectedData = new ConcurrentLinkedQueue[String]()
+    stream.map(_.value).foreachRDD { rdd => collectedData.addAll(Arrays.asList(rdd.collect(): _*)) }
+    ssc.start()
+    val newData = Map("b" -> 10)
+    kafkaTestUtils.sendMessages(topic, newData)
+    eventually(timeout(10 seconds), interval(50 milliseconds)) {
+      collectedData.contains("b")
+    }
+    assert(!collectedData.contains("a"))
+  }
+
+  // Test to verify the offset ranges can be recovered from the checkpoints
+  test("offset recovery") {
+    val topic = "recovery"
+    kafkaTestUtils.createTopic(topic)
+    testDir = Utils.createTempDir()
+
+    val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
+
+    // Send data to Kafka
+    def sendData(data: Seq[Int]) {
+      val strings = data.map { _.toString}
+      kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap)
+    }
+
+    // Setup the streaming context
+    ssc = new StreamingContext(sparkConf, Milliseconds(100))
+    val kafkaStream = withClue("Error creating direct stream") {
+      KafkaUtils.createDirectStream[String, String](
+        ssc, preferredHosts, Subscribe[String, String](List(topic), kafkaParams.asScala))
+    }
+    val keyedStream = kafkaStream.map { r => "key" -> r.value.toInt }
+    val stateStream = keyedStream.updateStateByKey { (values: Seq[Int], state: Option[Int]) =>
+      Some(values.sum + state.getOrElse(0))
+    }
+    ssc.checkpoint(testDir.getAbsolutePath)
+
+    // This is to ensure all the data is eventually received, and received only once
+    stateStream.foreachRDD { (rdd: RDD[(String, Int)]) =>
+      rdd.collect().headOption.foreach { x =>
+        DirectKafkaStreamSuite.total.set(x._2)
+      }
+    }
+
+    ssc.start()
+
+    // Send some data
+    for (i <- (1 to 10).grouped(4)) {
+      sendData(i)
+    }
+
+    eventually(timeout(10 seconds), interval(50 milliseconds)) {
+      assert(DirectKafkaStreamSuite.total.get === (1 to 10).sum)
+    }
+
+    ssc.stop()
+
+    // Verify that offset ranges were generated
+    val offsetRangesBeforeStop = getOffsetRanges(kafkaStream)
+    assert(offsetRangesBeforeStop.size >= 1, "No offset ranges generated")
+    assert(
+      offsetRangesBeforeStop.head._2.forall { _.fromOffset === 0 },
+      "starting offset not zero"
+    )
+
+    logInfo("====== RESTARTING ========")
+
+    // Recover context from checkpoints
+    ssc = new StreamingContext(testDir.getAbsolutePath)
+    val recoveredStream =
+      ssc.graph.getInputStreams().head.asInstanceOf[DStream[ConsumerRecord[String, String]]]
+
+    // Verify offset ranges have been recovered
+    val recoveredOffsetRanges = getOffsetRanges(recoveredStream).map { x => (x._1, x._2.toSet) }
+    assert(recoveredOffsetRanges.size > 0, "No offset ranges recovered")
+    val earlierOffsetRanges = offsetRangesBeforeStop.map { x => (x._1, x._2.toSet) }
+    assert(
+      recoveredOffsetRanges.forall { or =>
+        earlierOffsetRanges.contains((or._1, or._2))
+      },
+      "Recovered ranges are not the same as the ones generated\n" +
+        earlierOffsetRanges + "\n" + recoveredOffsetRanges
+    )
+    // Restart context, give more data and verify the total at the end
+    // If the total is right, it means each record has been received only once
+    ssc.start()
+    for (i <- (11 to 20).grouped(4)) {
+      sendData(i)
+    }
+
+    eventually(timeout(10 seconds), interval(50 milliseconds)) {
+      assert(DirectKafkaStreamSuite.total.get === (1 to 20).sum)
+    }
+    ssc.stop()
+  }
+
+    // Test to verify the offsets can be recovered from Kafka
+  test("offset recovery from kafka") {
+    val topic = "recoveryfromkafka"
+    kafkaTestUtils.createTopic(topic)
+
+    val kafkaParams = getKafkaParams(
+      "auto.offset.reset" -> "earliest",
+      ("enable.auto.commit", false: java.lang.Boolean)
+    )
+
+    val collectedData = new ConcurrentLinkedQueue[String]()
+    val committed = new JHashMap[TopicPartition, OffsetAndMetadata]()
+
+    // Send data to Kafka and wait for it to be received
+    def sendDataAndWaitForReceive(data: Seq[Int]) {
+      val strings = data.map { _.toString}
+      kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap)
+      eventually(timeout(10 seconds), interval(50 milliseconds)) {
+        assert(strings.forall { collectedData.contains })
+      }
+    }
+
+    // Setup the streaming context
+    ssc = new StreamingContext(sparkConf, Milliseconds(100))
+    withClue("Error creating direct stream") {
+      val kafkaStream = KafkaUtils.createDirectStream[String, String](
+        ssc, preferredHosts, Subscribe[String, String](List(topic), kafkaParams.asScala))
+      kafkaStream.foreachRDD { (rdd: RDD[ConsumerRecord[String, String]], time: Time) =>
+        val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+        val data = rdd.map(_.value).collect()
+        collectedData.addAll(Arrays.asList(data: _*))
+        kafkaStream.asInstanceOf[CanCommitOffsets]
+          .commitAsync(offsets, new OffsetCommitCallback() {
+            def onComplete(m: JMap[TopicPartition, OffsetAndMetadata], e: Exception) {
+              if (null != e) {
+                logError("commit failed", e)
+              } else {
+                committed.putAll(m)
+              }
+            }
+          })
+      }
+    }
+    ssc.start()
+    // Send some data and wait for them to be received
+    for (i <- (1 to 10).grouped(4)) {
+      sendDataAndWaitForReceive(i)
+    }
+    ssc.stop()
+    assert(! committed.isEmpty)
+    val consumer = new KafkaConsumer[String, String](kafkaParams)
+    consumer.subscribe(Arrays.asList(topic))
+    consumer.poll(0)
+    committed.asScala.foreach {
+      case (k, v) =>
+        // commits are async, not exactly once
+        assert(v.offset > 0)
+        assert(consumer.position(k) >= v.offset)
+    }
+  }
+
+
+  test("Direct Kafka stream report input information") {
+    val topic = "report-test"
+    val data = Map("a" -> 7, "b" -> 9)
+    kafkaTestUtils.createTopic(topic)
+    kafkaTestUtils.sendMessages(topic, data)
+
+    val totalSent = data.values.sum
+    val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
+
+    import DirectKafkaStreamSuite._
+    ssc = new StreamingContext(sparkConf, Milliseconds(200))
+    val collector = new InputInfoCollector
+    ssc.addStreamingListener(collector)
+
+    val stream = withClue("Error creating direct stream") {
+      KafkaUtils.createDirectStream[String, String](
+        ssc, preferredHosts, Subscribe[String, String](List(topic), kafkaParams.asScala))
+    }
+
+    val allReceived = new ConcurrentLinkedQueue[(String, String)]
+
+    stream.map(r => (r.key, r.value))
+      .foreachRDD { rdd => allReceived.addAll(Arrays.asList(rdd.collect(): _*)) }
+    ssc.start()
+    eventually(timeout(20000.milliseconds), interval(200.milliseconds)) {
+      assert(allReceived.size === totalSent,
+        "didn't get expected number of messages, messages:\n" +
+          allReceived.asScala.mkString("\n"))
+
+      // Check the number of records reported to the StreamingListener.
+      assert(collector.numRecordsSubmitted.get() === totalSent)
+      assert(collector.numRecordsStarted.get() === totalSent)
+      assert(collector.numRecordsCompleted.get() === totalSent)
+    }
+    ssc.stop()
+  }
+
+  test("maxMessagesPerPartition with backpressure disabled") {
+    val topic = "maxMessagesPerPartition"
+    val kafkaStream = getDirectKafkaStream(topic, None)
+
+    val input = Map(new TopicPartition(topic, 0) -> 50L, new TopicPartition(topic, 1) -> 50L)
+    assert(kafkaStream.maxMessagesPerPartition(input).get ==
+      Map(new TopicPartition(topic, 0) -> 10L, new TopicPartition(topic, 1) -> 10L))
+  }
+
+  test("maxMessagesPerPartition with no lag") {
+    val topic = "maxMessagesPerPartition"
+    val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 100))
+    val kafkaStream = getDirectKafkaStream(topic, rateController)
+
+    val input = Map(new TopicPartition(topic, 0) -> 0L, new TopicPartition(topic, 1) -> 0L)
+    assert(kafkaStream.maxMessagesPerPartition(input).isEmpty)
+  }
+
+  test("maxMessagesPerPartition respects max rate") {
+    val topic = "maxMessagesPerPartition"
+    val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 1000))
+    val kafkaStream = getDirectKafkaStream(topic, rateController)
+
+    val input = Map(new TopicPartition(topic, 0) -> 1000L, new TopicPartition(topic, 1) -> 1000L)
+    assert(kafkaStream.maxMessagesPerPartition(input).get ==
+      Map(new TopicPartition(topic, 0) -> 10L, new TopicPartition(topic, 1) -> 10L))
+  }
+
+  test("using rate controller") {
+    val topic = "backpressure"
+    val topicPartitions = Set(new TopicPartition(topic, 0), new TopicPartition(topic, 1))
+    kafkaTestUtils.createTopic(topic, 2)
+    val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
+    val executorKafkaParams = new JHashMap[String, Object](kafkaParams)
+    KafkaUtils.fixKafkaParams(executorKafkaParams)
+
+    val batchIntervalMilliseconds = 100
+    val estimator = new ConstantEstimator(100)
+    val messages = Map("foo" -> 200)
+    kafkaTestUtils.sendMessages(topic, messages)
+
+    val sparkConf = new SparkConf()
+      // Safe, even with streaming, because we're using the direct API.
+      // Using 1 core is useful to make the test more predictable.
+      .setMaster("local[1]")
+      .setAppName(this.getClass.getSimpleName)
+      .set("spark.streaming.kafka.maxRatePerPartition", "100")
+
+    // Setup the streaming context
+    ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds))
+
+    val kafkaStream = withClue("Error creating direct stream") {
+      new DirectKafkaInputDStream[String, String](
+        ssc, preferredHosts, Subscribe[String, String](List(topic), kafkaParams.asScala)) {
+        override protected[streaming] val rateController =
+          Some(new DirectKafkaRateController(id, estimator))
+      }.map(r => (r.key, r.value))
+    }
+
+    val collectedData = new ConcurrentLinkedQueue[Array[String]]()
+
+    // Used for assertion failure messages.
+    def dataToString: String =
+      collectedData.asScala.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}")
+
+    // This is to collect the raw data received from Kafka
+    kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) =>
+      val data = rdd.map { _._2 }.collect()
+      collectedData.add(data)
+    }
+
+    ssc.start()
+
+    // Try different rate limits.
+    // Wait for arrays of data to appear matching the rate.
+    Seq(100, 50, 20).foreach { rate =>
+      collectedData.clear()       // Empty this buffer on each pass.
+      estimator.updateRate(rate)  // Set a new rate.
+      // Expect blocks of data equal to "rate", scaled by the interval length in secs.
+      val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001)
+      eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) {
+        // Assert that rate estimator values are used to determine maxMessagesPerPartition.
+        // Funky "-" in message makes the complete assertion message read better.
+        assert(collectedData.asScala.exists(_.size == expectedSize),
+          s" - No arrays of size $expectedSize for rate $rate found in $dataToString")
+      }
+    }
+
+    ssc.stop()
+  }
+
+  /** Get the generated offset ranges from the DirectKafkaStream */
+  private def getOffsetRanges[K, V](
+      kafkaStream: DStream[ConsumerRecord[K, V]]): Seq[(Time, Array[OffsetRange])] = {
+    kafkaStream.generatedRDDs.mapValues { rdd =>
+      rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+    }.toSeq.sortBy { _._1 }
+  }
+
+  private def getDirectKafkaStream(topic: String, mockRateController: Option[RateController]) = {
+    val batchIntervalMilliseconds = 100
+
+    val sparkConf = new SparkConf()
+      .setMaster("local[1]")
+      .setAppName(this.getClass.getSimpleName)
+      .set("spark.streaming.kafka.maxRatePerPartition", "100")
+
+    // Setup the streaming context
+    ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds))
+
+    val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
+    val ekp = new JHashMap[String, Object](kafkaParams)
+    KafkaUtils.fixKafkaParams(ekp)
+
+    val s = new DirectKafkaInputDStream[String, String](
+      ssc,
+      preferredHosts,
+      new ConsumerStrategy[String, String] {
+        def executorKafkaParams = ekp
+        def onStart(currentOffsets: Map[TopicPartition, Long]): Consumer[String, String] = {
+          val consumer = new KafkaConsumer[String, String](kafkaParams)
+          val tps = List(new TopicPartition(topic, 0), new TopicPartition(topic, 1))
+          consumer.assign(Arrays.asList(tps: _*))
+          tps.foreach(tp => consumer.seek(tp, 0))
+          consumer
+        }
+      }
+    ) {
+        override protected[streaming] val rateController = mockRateController
+    }
+    // manual start necessary because we aren't consuming the stream, just checking its state
+    s.start()
+    s
+  }
+}
+
+object DirectKafkaStreamSuite {
+  val total = new AtomicLong(-1L)
+
+  class InputInfoCollector extends StreamingListener {
+    val numRecordsSubmitted = new AtomicLong(0L)
+    val numRecordsStarted = new AtomicLong(0L)
+    val numRecordsCompleted = new AtomicLong(0L)
+
+    override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = {
+      numRecordsSubmitted.addAndGet(batchSubmitted.batchInfo.numRecords)
+    }
+
+    override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = {
+      numRecordsStarted.addAndGet(batchStarted.batchInfo.numRecords)
+    }
+
+    override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
+      numRecordsCompleted.addAndGet(batchCompleted.batchInfo.numRecords)
+    }
+  }
+}
+
+private[streaming] class ConstantEstimator(@volatile private var rate: Long)
+  extends RateEstimator {
+
+  def updateRate(newRate: Long): Unit = {
+    rate = newRate
+  }
+
+  def compute(
+      time: Long,
+      elements: Long,
+      processingDelay: Long,
+      schedulingDelay: Long): Option[Double] = Some(rate)
+}
+
+private[streaming] class ConstantRateController(id: Int, estimator: RateEstimator, rate: Long)
+  extends RateController(id, estimator) {
+  override def publish(rate: Long): Unit = ()
+  override def getLatestRate(): Long = rate
+}
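
To make the commit-back-to-Kafka pattern from the "offset recovery from kafka" test above easier to lift out, here is a self-contained Scala sketch; the broker address, topic, and group id are placeholders:

  import java.{util => ju}
  import org.apache.kafka.clients.consumer.{OffsetAndMetadata, OffsetCommitCallback}
  import org.apache.kafka.common.TopicPartition
  import org.apache.kafka.common.serialization.StringDeserializer
  import org.apache.spark.SparkConf
  import org.apache.spark.streaming.{Milliseconds, StreamingContext}
  import org.apache.spark.streaming.kafka010._

  val ssc = new StreamingContext(
    new SparkConf().setMaster("local[4]").setAppName("commit-sketch"), Milliseconds(500))

  val kafkaParams = Map[String, Object](
    "bootstrap.servers" -> "localhost:9092",            // placeholder
    "key.deserializer" -> classOf[StringDeserializer],
    "value.deserializer" -> classOf[StringDeserializer],
    "group.id" -> "example-group",                      // placeholder
    "auto.offset.reset" -> "earliest",
    "enable.auto.commit" -> (false: java.lang.Boolean))

  val stream = KafkaUtils.createDirectStream[String, String](
    ssc, PreferConsistent, Subscribe[String, String](List("topicA"), kafkaParams))

  stream.foreachRDD { rdd =>
    val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
    // ... process the batch here, then commit its offsets back to Kafka ...
    stream.asInstanceOf[CanCommitOffsets].commitAsync(offsets, new OffsetCommitCallback {
      override def onComplete(m: ju.Map[TopicPartition, OffsetAndMetadata], e: Exception): Unit =
        if (e != null) println(s"async commit failed: $e")  // commits are async, not exactly-once
    })
  }

  ssc.start()
  ssc.awaitTermination()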

http://git-wip-us.apache.org/repos/asf/spark/blob/dedbceec/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala
new file mode 100644
index 0000000..3d2546d
--- /dev/null
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.{ util => ju }
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.serialization.StringDeserializer
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark._
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+
+class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
+
+  private var kafkaTestUtils: KafkaTestUtils = _
+
+  private val sparkConf = new SparkConf().setMaster("local[4]")
+    .setAppName(this.getClass.getSimpleName)
+  private var sc: SparkContext = _
+
+  override def beforeAll {
+    sc = new SparkContext(sparkConf)
+    kafkaTestUtils = new KafkaTestUtils
+    kafkaTestUtils.setup()
+  }
+
+  override def afterAll {
+    if (sc != null) {
+      sc.stop
+      sc = null
+    }
+
+    if (kafkaTestUtils != null) {
+      kafkaTestUtils.teardown()
+      kafkaTestUtils = null
+    }
+  }
+
+  private def getKafkaParams() = Map[String, Object](
+    "bootstrap.servers" -> kafkaTestUtils.brokerAddress,
+    "key.deserializer" -> classOf[StringDeserializer],
+    "value.deserializer" -> classOf[StringDeserializer],
+    "group.id" -> 
s"test-consumer-${Random.nextInt}-${System.currentTimeMillis}"
+  ).asJava
+
+  private val preferredHosts = PreferConsistent
+
+  test("basic usage") {
+    val topic = s"topicbasic-${Random.nextInt}-${System.currentTimeMillis}"
+    kafkaTestUtils.createTopic(topic)
+    val messages = Array("the", "quick", "brown", "fox")
+    kafkaTestUtils.sendMessages(topic, messages)
+
+    val kafkaParams = getKafkaParams()
+
+    val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size))
+
+    val rdd = KafkaUtils.createRDD[String, String](sc, kafkaParams, offsetRanges, preferredHosts)
+      .map(_.value)
+
+    val received = rdd.collect.toSet
+    assert(received === messages.toSet)
+
+    // size-related method optimizations return sane results
+    assert(rdd.count === messages.size)
+    assert(rdd.countApprox(0).getFinalValue.mean === messages.size)
+    assert(!rdd.isEmpty)
+    assert(rdd.take(1).size === 1)
+    assert(rdd.take(1).head === messages.head)
+    assert(rdd.take(messages.size + 10).size === messages.size)
+
+    val emptyRdd = KafkaUtils.createRDD[String, String](
+      sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0)), preferredHosts)
+
+    assert(emptyRdd.isEmpty)
+
+    // invalid offset ranges throw exceptions
+    val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1))
+    intercept[SparkException] {
+      val result = KafkaUtils.createRDD[String, String](sc, kafkaParams, badRanges, preferredHosts)
+        .map(_.value)
+        .collect()
+    }
+  }
+
+  test("iterator boundary conditions") {
+    // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd
+    val topic = s"topicboundary-${Random.nextInt}-${System.currentTimeMillis}"
+    val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
+    kafkaTestUtils.createTopic(topic)
+
+    val kafkaParams = getKafkaParams()
+
+    // this is the "lots of messages" case
+    kafkaTestUtils.sendMessages(topic, sent)
+    var sentCount = sent.values.sum
+
+    val rdd = KafkaUtils.createRDD[String, String](sc, kafkaParams,
+      Array(OffsetRange(topic, 0, 0, sentCount)), preferredHosts)
+
+    val ranges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+    val rangeCount = ranges.map(o => o.untilOffset - o.fromOffset).sum
+
+    assert(rangeCount === sentCount, "offset range didn't include all sent messages")
+    assert(rdd.map(_.offset).collect.sorted === (0 until sentCount).toArray,
+      "didn't get all sent messages")
+
+    // this is the "0 messages" case
+    val rdd2 = KafkaUtils.createRDD[String, String](sc, kafkaParams,
+      Array(OffsetRange(topic, 0, sentCount, sentCount)), preferredHosts)
+
+    // shouldn't get anything, since message is sent after rdd was defined
+    val sentOnlyOne = Map("d" -> 1)
+
+    kafkaTestUtils.sendMessages(topic, sentOnlyOne)
+
+    assert(rdd2.map(_.value).collect.size === 0, "got messages when there shouldn't be any")
+
+    // this is the "exactly 1 message" case, namely the single message from 
sentOnlyOne above
+    val rdd3 = KafkaUtils.createRDD[String, String](sc, kafkaParams,
+      Array(OffsetRange(topic, 0, sentCount, sentCount + 1)), preferredHosts)
+
+    // send lots of messages after rdd was defined, they shouldn't show up
+    kafkaTestUtils.sendMessages(topic, Map("extra" -> 22))
+
+    assert(rdd3.map(_.value).collect.head === sentOnlyOne.keys.head,
+      "didn't get exactly one message")
+  }
+
+  test("executor sorting") {
+    val kafkaParams = new ju.HashMap[String, Object](getKafkaParams())
+    kafkaParams.put("auto.offset.reset", "none")
+    val rdd = new KafkaRDD[String, String](
+      sc,
+      kafkaParams,
+      Array(OffsetRange("unused", 0, 1, 2)),
+      ju.Collections.emptyMap[TopicPartition, String](),
+      true)
+    val a3 = ExecutorCacheTaskLocation("a", "3")
+    val a4 = ExecutorCacheTaskLocation("a", "4")
+    val b1 = ExecutorCacheTaskLocation("b", "1")
+    val b2 = ExecutorCacheTaskLocation("b", "2")
+
+    val correct = Array(b2, b1, a4, a3)
+
+    correct.permutations.foreach { p =>
+      assert(p.sortWith(rdd.compareExecutors) === correct)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/dedbceec/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 89ed87f..c99d786 100644
--- a/pom.xml
+++ b/pom.xml
@@ -109,6 +109,8 @@
     <module>launcher</module>
     <module>external/kafka-0-8</module>
     <module>external/kafka-0-8-assembly</module>
+    <module>external/kafka-0-10</module>
+    <module>external/kafka-0-10-assembly</module>
   </modules>
 
   <properties>

http://git-wip-us.apache.org/repos/asf/spark/blob/dedbceec/project/SparkBuild.scala
----------------------------------------------------------------------
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 4c01ad3..8e3dcc2 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -44,9 +44,9 @@ object BuildCommons {
   ).map(ProjectRef(buildLocation, _))
 
   val streamingProjects@Seq(
-    streaming, streamingFlumeSink, streamingFlume, streamingKafka
+    streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingKafka010
  ) = Seq(
-    "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-8"
+    "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-8", "streaming-kafka-0-10"
   ).map(ProjectRef(buildLocation, _))
 
   val allProjects@Seq(
@@ -61,8 +61,8 @@ object BuildCommons {
     Seq("yarn", "java8-tests", "ganglia-lgpl", "streaming-kinesis-asl",
       "docker-integration-tests").map(ProjectRef(buildLocation, _))
 
-  val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) =
-    Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kinesis-asl-assembly")
+  val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) =
+    Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly")
       .map(ProjectRef(buildLocation, _))
 
   val copyJarsProjects@Seq(assembly, examples) = Seq("assembly", "examples")
@@ -352,7 +352,7 @@ object SparkBuild extends PomBuild {
   val mimaProjects = allProjects.filterNot { x =>
     Seq(
      spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn,
-      unsafe, tags, sketch, mllibLocal
+      unsafe, tags, sketch, mllibLocal, streamingKafka010
     ).contains(x)
   }
 
@@ -608,7 +608,7 @@ object Assembly {
         .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String])
     },
    jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) =>
-      if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-0-8-assembly") || mName.contains("streaming-kinesis-asl-assembly")) {
+      if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-0-8-assembly") || mName.contains("streaming-kafka-0-10-assembly") || mName.contains("streaming-kinesis-asl-assembly")) {
        // This must match the same name used in maven (see external/kafka-0-8-assembly/pom.xml)
         s"${mName}-${v}.jar"
       } else {

