Repository: spark
Updated Branches:
  refs/heads/master ab1029fb8 -> 9abe09bfc


[SPARK-24127][SS] Continuous text socket source

## What changes were proposed in this pull request?

Adds support for the text socket stream in Spark Structured Streaming "continuous" mode. 
The implementation is roughly based on the idea of ContinuousMemoryStream, where the 
executors query data from the driver over an RPC endpoint.

This makes it possible to build a Structured Streaming continuous pipeline that ingests 
data via "nc" and to run the examples.
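
For reference, a minimal usage sketch of such a pipeline (not part of this patch): the host, port, and trigger interval below are placeholders, and `nc -lk 9999` is assumed to be running locally to feed lines into the socket.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder()
  .appName("ContinuousSocketExample")
  .master("local[*]")
  .getOrCreate()

// The "socket" source; host and port are required options.
val lines = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", "9999")
  .load()

// Trigger.Continuous switches the query from micro-batch to continuous processing.
val query = lines.writeStream
  .format("console")
  .trigger(Trigger.Continuous("1 second"))
  .start()

query.awaitTermination()
```

With continuous processing enabled, each executor polls the driver-side RPC endpoint for records buffered from the socket, as described above.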

## How was this patch tested?

Added unit tests and ran the Spark examples in Structured Streaming continuous mode.

Closes #21199 from arunmahadevan/SPARK-24127.

Authored-by: Arun Mahadevan <ar...@apache.org>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9abe09bf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9abe09bf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9abe09bf

Branch: refs/heads/master
Commit: 9abe09bfc18580233acad676d1241684c7d8768d
Parents: ab1029f
Author: Arun Mahadevan <ar...@apache.org>
Authored: Fri Aug 10 15:53:31 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Fri Aug 10 15:53:31 2018 +0800

----------------------------------------------------------------------
 .../streaming/ContinuousRecordEndpoint.scala    |  69 +++++
 .../continuous/ContinuousTextSocketSource.scala | 292 +++++++++++++++++++
 .../sources/ContinuousMemoryStream.scala        |  32 +-
 .../execution/streaming/sources/socket.scala    |  25 +-
 .../sources/TextSocketStreamSuite.scala         |  98 ++++++-
 5 files changed, 482 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9abe09bf/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala
new file mode 100644
index 0000000..c9c2ebc
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset
+
+case class ContinuousRecordPartitionOffset(partitionId: Int, offset: Int) 
extends PartitionOffset
+case class GetRecord(offset: ContinuousRecordPartitionOffset)
+
+/**
+ * An RPC endpoint that continuous readers poll for records from the driver.
+ *
+ * @param buckets the data buckets. Each bucket contains a sequence of items to be
+ *                returned for a partition. The number of buckets should be equal
+ *                to the number of partitions.
+ * @param lock a lock object used to synchronize reads of the buckets
+ */
+class ContinuousRecordEndpoint(buckets: Seq[Seq[Any]], lock: Object)
+  extends ThreadSafeRpcEndpoint {
+
+  private var startOffsets: Seq[Int] = List.fill(buckets.size)(0)
+
+  /**
+   * Sets the start offset.
+   *
+   * @param offsets the base offset per partition to be used
+   *                while retrieving the data in {#receiveAndReply}.
+   */
+  def setStartOffsets(offsets: Seq[Int]): Unit = {
+    lock.synchronized {
+      startOffsets = offsets
+    }
+  }
+
+  override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv
+
+  /**
+   * Process messages from `RpcEndpointRef.ask`. If an unmatched message is
+   * received, a `SparkException` will be thrown and sent to `onError`.
+   */
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, 
Unit] = {
+    case GetRecord(ContinuousRecordPartitionOffset(partitionId, offset)) =>
+      lock.synchronized {
+        val bufOffset = offset - startOffsets(partitionId)
+        val buf = buckets(partitionId)
+        val record = if (buf.size <= bufOffset) None else Some(buf(bufOffset))
+
+        context.reply(record.map(InternalRow(_)))
+      }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/9abe09bf/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
new file mode 100644
index 0000000..1dbdfd5
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.io.{BufferedReader, InputStreamReader, IOException}
+import java.net.Socket
+import java.sql.Timestamp
+import java.util.{Calendar, List => JList}
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
+
+import org.json4s.{DefaultFormats, NoTypeHints}
+import org.json4s.jackson.Serialization
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.streaming.{ContinuousRecordEndpoint, 
ContinuousRecordPartitionOffset, GetRecord}
+import org.apache.spark.sql.execution.streaming.sources.TextSocketReader
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, 
InputPartitionReader, SupportsDeprecatedScanRow}
+import 
org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader,
 ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.types.{StringType, StructField, StructType, 
TimestampType}
+import org.apache.spark.util.RpcUtils
+
+
+/**
+ * A ContinuousReader that reads text lines through a TCP socket, designed only for tutorials
+ * and debugging. This ContinuousReader will *not* work in production applications for several
+ * reasons, including the lack of support for fault recovery.
+ *
+ * The driver maintains a socket connection to the given host and port, keeps the received
+ * messages in buckets, and serves the messages to the executors via an RPC endpoint.
+ */
+class TextSocketContinuousReader(options: DataSourceOptions) extends 
ContinuousReader with Logging {
+  implicit val defaultFormats: DefaultFormats = DefaultFormats
+
+  private val host: String = options.get("host").get()
+  private val port: Int = options.get("port").get().toInt
+
+  assert(SparkSession.getActiveSession.isDefined)
+  private val spark = SparkSession.getActiveSession.get
+  private val numPartitions = spark.sparkContext.defaultParallelism
+
+  @GuardedBy("this")
+  private var socket: Socket = _
+
+  @GuardedBy("this")
+  private var readThread: Thread = _
+
+  @GuardedBy("this")
+  private val buckets = Seq.fill(numPartitions)(new ListBuffer[(String, 
Timestamp)])
+
+  @GuardedBy("this")
+  private var currentOffset: Int = -1
+
+  private var startOffset: TextSocketOffset = _
+
+  private val recordEndpoint = new ContinuousRecordEndpoint(buckets, this)
+  @volatile private var endpointRef: RpcEndpointRef = _
+
+  initialize()
+
+  override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
+    assert(offsets.length == numPartitions)
+    val offs = offsets
+      .map(_.asInstanceOf[ContinuousRecordPartitionOffset])
+      .sortBy(_.partitionId)
+      .map(_.offset)
+      .toList
+    TextSocketOffset(offs)
+  }
+
+  override def deserializeOffset(json: String): Offset = {
+    TextSocketOffset(Serialization.read[List[Int]](json))
+  }
+
+  override def setStartOffset(offset: java.util.Optional[Offset]): Unit = {
+    this.startOffset = offset
+      .orElse(TextSocketOffset(List.fill(numPartitions)(0)))
+      .asInstanceOf[TextSocketOffset]
+    recordEndpoint.setStartOffsets(startOffset.offsets)
+  }
+
+  override def getStartOffset: Offset = startOffset
+
+  override def readSchema(): StructType = {
+    if (includeTimestamp) {
+      TextSocketReader.SCHEMA_TIMESTAMP
+    } else {
+      TextSocketReader.SCHEMA_REGULAR
+    }
+  }
+
+  override def planInputPartitions(): JList[InputPartition[InternalRow]] = {
+
+    val endpointName = 
s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}"
+    endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, 
recordEndpoint)
+
+    val offsets = startOffset match {
+      case off: TextSocketOffset => off.offsets
+      case off =>
+        throw new IllegalArgumentException(
+          s"invalid offset type ${off.getClass} for 
TextSocketContinuousReader")
+    }
+
+    if (offsets.size != numPartitions) {
+      throw new IllegalArgumentException(
+        s"The previous run contained ${offsets.size} partitions, but" +
+          s" $numPartitions partitions are currently configured. The 
numPartitions option" +
+          " cannot be changed.")
+    }
+
+    startOffset.offsets.zipWithIndex.map {
+      case (offset, i) =>
+        TextSocketContinuousInputPartition(
+          endpointName, i, offset, includeTimestamp): 
InputPartition[InternalRow]
+    }.asJava
+
+  }
+
+  override def commit(end: Offset): Unit = synchronized {
+    val endOffset = end match {
+      case off: TextSocketOffset => off
+      case _ => throw new IllegalArgumentException(s"TextSocketContinuousReader.commit() " +
+        s"received an offset ($end) that did not originate with an instance of this class")
+    }
+
+    endOffset.offsets.zipWithIndex.foreach {
+      case (offset, partition) =>
+        val max = startOffset.offsets(partition) + buckets(partition).size
+        if (offset > max) {
+          throw new IllegalStateException("Invalid offset " + offset + " to 
commit" +
+          " for partition " + partition + ". Max valid offset: " + max)
+        }
+        val n = offset - startOffset.offsets(partition)
+        buckets(partition).trimStart(n)
+    }
+    startOffset = endOffset
+    recordEndpoint.setStartOffsets(startOffset.offsets)
+  }
+
+  /** Stop this source. */
+  override def stop(): Unit = synchronized {
+    if (socket != null) {
+      try {
+        // Unfortunately, BufferedReader.readLine() cannot be interrupted, so 
the only way to
+        // stop the readThread is to close the socket.
+        socket.close()
+      } catch {
+        case e: IOException =>
+      }
+      socket = null
+    }
+    if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef)
+  }
+
+  private def initialize(): Unit = synchronized {
+    socket = new Socket(host, port)
+    val reader = new BufferedReader(new 
InputStreamReader(socket.getInputStream))
+    // Thread continuously reads from a socket and inserts data into buckets
+    readThread = new Thread(s"TextSocketContinuousReader($host, $port)") {
+      setDaemon(true)
+
+      override def run(): Unit = {
+        try {
+          while (true) {
+            val line = reader.readLine()
+            if (line == null) {
+              // End of file reached
+              logWarning(s"Stream closed by $host:$port")
+              return
+            }
+            TextSocketContinuousReader.this.synchronized {
+              currentOffset += 1
+              val newData = (line,
+                Timestamp.valueOf(
+                  
TextSocketReader.DATE_FORMAT.format(Calendar.getInstance().getTime()))
+              )
+              buckets(currentOffset % numPartitions) += newData
+            }
+          }
+        } catch {
+          case e: IOException =>
+        }
+      }
+    }
+
+    readThread.start()
+  }
+
+  override def toString: String = s"TextSocketContinuousReader[host: $host, 
port: $port]"
+
+  private def includeTimestamp: Boolean = 
options.getBoolean("includeTimestamp", false)
+
+}
+
+/**
+ * Continuous text socket input partition.
+ */
+case class TextSocketContinuousInputPartition(
+    driverEndpointName: String,
+    partitionId: Int,
+    startOffset: Int,
+    includeTimestamp: Boolean)
+extends InputPartition[InternalRow] {
+
+  override def createPartitionReader(): InputPartitionReader[InternalRow] =
+    new TextSocketContinuousInputPartitionReader(driverEndpointName, 
partitionId, startOffset,
+      includeTimestamp)
+}
+
+/**
+ * Continuous text socket input partition reader.
+ *
+ * Polls the driver endpoint for new records.
+ */
+class TextSocketContinuousInputPartitionReader(
+    driverEndpointName: String,
+    partitionId: Int,
+    startOffset: Int,
+    includeTimestamp: Boolean)
+  extends ContinuousInputPartitionReader[InternalRow] {
+
+  private val endpoint = RpcUtils.makeDriverRef(
+    driverEndpointName,
+    SparkEnv.get.conf,
+    SparkEnv.get.rpcEnv)
+
+  private var currentOffset = startOffset
+  private var current: Option[InternalRow] = None
+
+  override def next(): Boolean = {
+    try {
+      current = getRecord
+      while (current.isEmpty) {
+        Thread.sleep(100)
+        current = getRecord
+      }
+      currentOffset += 1
+    } catch {
+      case _: InterruptedException =>
+        // Someone's trying to end the task; just let them.
+        return false
+    }
+    true
+  }
+
+  override def get(): InternalRow = {
+    current.get
+  }
+
+  override def close(): Unit = {}
+
+  override def getOffset: PartitionOffset =
+    ContinuousRecordPartitionOffset(partitionId, currentOffset)
+
+  private def getRecord: Option[InternalRow] =
+    endpoint.askSync[Option[InternalRow]](GetRecord(
+      ContinuousRecordPartitionOffset(partitionId, currentOffset))).map(rec =>
+      if (includeTimestamp) {
+        rec
+      } else {
+        InternalRow(rec.get(0, TextSocketReader.SCHEMA_TIMESTAMP)
+          .asInstanceOf[(String, Timestamp)]._1)
+      }
+    )
+}
+
+case class TextSocketOffset(offsets: List[Int]) extends Offset {
+  private implicit val formats = Serialization.formats(NoTypeHints)
+  override def json: String = Serialization.write(offsets)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/9abe09bf/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
index 711f094..4a32217 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -33,7 +33,6 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, 
RpcEnv, ThreadSafeR
 import org.apache.spark.sql.{Encoder, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.streaming._
-import 
org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord
 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, 
DataSourceOptions}
 import org.apache.spark.sql.sources.v2.reader.InputPartition
 import 
org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader,
 ContinuousReader, Offset, PartitionOffset}
@@ -63,7 +62,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, 
sqlContext: SQLContext, numPa
   @GuardedBy("this")
   private var startOffset: ContinuousMemoryStreamOffset = _
 
-  private val recordEndpoint = new RecordEndpoint()
+  private val recordEndpoint = new ContinuousRecordEndpoint(records, this)
   @volatile private var endpointRef: RpcEndpointRef = _
 
   def addData(data: TraversableOnce[A]): Offset = synchronized {
@@ -94,7 +93,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, 
sqlContext: SQLContext, numPa
   override def mergeOffsets(offsets: Array[PartitionOffset]): 
ContinuousMemoryStreamOffset = {
     ContinuousMemoryStreamOffset(
       offsets.map {
-        case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num)
+        case ContinuousRecordPartitionOffset(part, num) => (part, num)
       }.toMap
     )
   }
@@ -127,27 +126,9 @@ class ContinuousMemoryStream[A : Encoder](id: Int, 
sqlContext: SQLContext, numPa
       options: DataSourceOptions): ContinuousReader = {
     this
   }
-
-  /**
-   * Endpoint for executors to poll for records.
-   */
-  private class RecordEndpoint extends ThreadSafeRpcEndpoint {
-    override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv
-
-    override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
-      case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) =>
-        ContinuousMemoryStream.this.synchronized {
-          val buf = records(part)
-          val record = if (buf.size <= index) None else Some(buf(index))
-
-          context.reply(record.map(r => encoder.toRow(r).copy()))
-        }
-    }
-  }
 }
 
 object ContinuousMemoryStream {
-  case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset)
   protected val memoryStreamId = new AtomicInteger(0)
 
   def apply[A : Encoder](implicit sqlContext: SQLContext): 
ContinuousMemoryStream[A] =
@@ -207,12 +188,12 @@ class ContinuousMemoryStreamInputPartitionReader(
 
   override def close(): Unit = {}
 
-  override def getOffset: ContinuousMemoryStreamPartitionOffset =
-    ContinuousMemoryStreamPartitionOffset(partition, currentOffset)
+  override def getOffset: ContinuousRecordPartitionOffset =
+    ContinuousRecordPartitionOffset(partition, currentOffset)
 
   private def getRecord: Option[InternalRow] =
     endpoint.askSync[Option[InternalRow]](
-      GetRecord(ContinuousMemoryStreamPartitionOffset(partition, 
currentOffset)))
+      GetRecord(ContinuousRecordPartitionOffset(partition, currentOffset)))
 }
 
 case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int])
@@ -220,6 +201,3 @@ case class ContinuousMemoryStreamOffset(partitionNums: 
Map[Int, Int])
   private implicit val formats = Serialization.formats(NoTypeHints)
   override def json(): String = Serialization.write(partitionNums)
 }
-
-case class ContinuousMemoryStreamPartitionOffset(partition: Int, numProcessed: 
Int)
-  extends PartitionOffset

http://git-wip-us.apache.org/repos/asf/spark/blob/9abe09bf/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
index 9f53a18..874c479 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
@@ -33,14 +33,16 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.streaming.LongOffset
+import 
org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReader
 import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, 
MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, 
DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
 import org.apache.spark.sql.sources.v2.reader.{InputPartition, 
InputPartitionReader}
-import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, 
Offset}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, 
MicroBatchReader, Offset}
 import org.apache.spark.sql.types.{StringType, StructField, StructType, 
TimestampType}
 import org.apache.spark.unsafe.types.UTF8String
 
-object TextSocketMicroBatchReader {
+// Shared object for micro-batch and continuous reader
+object TextSocketReader {
   val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil)
   val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) ::
     StructField("timestamp", TimestampType) :: Nil)
@@ -137,9 +139,9 @@ class TextSocketMicroBatchReader(options: 
DataSourceOptions) extends MicroBatchR
 
   override def readSchema(): StructType = {
     if (options.getBoolean("includeTimestamp", false)) {
-      TextSocketMicroBatchReader.SCHEMA_TIMESTAMP
+      TextSocketReader.SCHEMA_TIMESTAMP
     } else {
-      TextSocketMicroBatchReader.SCHEMA_REGULAR
+      TextSocketReader.SCHEMA_REGULAR
     }
   }
 
@@ -226,7 +228,7 @@ class TextSocketMicroBatchReader(options: 
DataSourceOptions) extends MicroBatchR
 }
 
 class TextSocketSourceProvider extends DataSourceV2
-  with MicroBatchReadSupport with DataSourceRegister with Logging {
+  with MicroBatchReadSupport with ContinuousReadSupport with 
DataSourceRegister with Logging {
 
   private def checkParameters(params: DataSourceOptions): Unit = {
     logWarning("The socket source should not be used for production 
applications! " +
@@ -258,6 +260,17 @@ class TextSocketSourceProvider extends DataSourceV2
     new TextSocketMicroBatchReader(options)
   }
 
+  override def createContinuousReader(
+      schema: Optional[StructType],
+      checkpointLocation: String,
+      options: DataSourceOptions): ContinuousReader = {
+    checkParameters(options)
+    if (schema.isPresent) {
+      throw new AnalysisException("The socket source does not support a 
user-specified schema.")
+    }
+    new TextSocketContinuousReader(options)
+  }
+
   /** String that represents the format that this data source provider uses. */
   override def shortName(): String = "socket"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9abe09bf/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
index 52e8386..48e5cf7 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
@@ -32,12 +32,13 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, 
MicroBatchReadSupport}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, 
Offset}
 import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest}
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{StringType, StructField, StructType, 
TimestampType}
+import org.apache.spark.sql.types._
 
 class TextSocketStreamSuite extends StreamTest with SharedSQLContext with 
BeforeAndAfterEach {
 
@@ -300,6 +301,101 @@ class TextSocketStreamSuite extends StreamTest with 
SharedSQLContext with Before
     }
   }
 
+  test("continuous data") {
+    serverThread = new ServerThread()
+    serverThread.start()
+
+    val reader = new TextSocketContinuousReader(
+      new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost",
+        "port" -> serverThread.port.toString).asJava))
+    reader.setStartOffset(Optional.empty())
+    val tasks = reader.planInputPartitions()
+    assert(tasks.size == 2)
+
+    val numRecords = 10
+    val data = scala.collection.mutable.ListBuffer[Int]()
+    val offsets = scala.collection.mutable.ListBuffer[Int]()
+    import org.scalatest.time.SpanSugar._
+    failAfter(5 seconds) {
+      // inject rows, read and check the data and offsets
+      for (i <- 0 until numRecords) {
+        serverThread.enqueue(i.toString)
+      }
+      tasks.asScala.foreach {
+        case t: TextSocketContinuousInputPartition =>
+          val r = 
t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader]
+          for (i <- 0 until numRecords / 2) {
+            r.next()
+            
offsets.append(r.getOffset().asInstanceOf[ContinuousRecordPartitionOffset].offset)
+            data.append(r.get().get(0, 
DataTypes.StringType).asInstanceOf[String].toInt)
+            // commit the offsets in the middle and verify that processing continues
+            if (i == 2) {
+              commitOffset(t.partitionId, i + 1)
+            }
+          }
+          assert(offsets.toSeq == Range.inclusive(1, 5))
+          assert(data.toSeq == Range(t.partitionId, 10, 2))
+          offsets.clear()
+          data.clear()
+        case _ => throw new IllegalStateException("Unexpected task type")
+      }
+      assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == 
List(3, 3))
+      reader.commit(TextSocketOffset(List(5, 5)))
+      assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == 
List(5, 5))
+    }
+
+    def commitOffset(partition: Int, offset: Int): Unit = {
+      val offsetsToCommit = 
reader.getStartOffset.asInstanceOf[TextSocketOffset]
+        .offsets.updated(partition, offset)
+      reader.commit(TextSocketOffset(offsetsToCommit))
+      assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == 
offsetsToCommit)
+    }
+  }
+
+  test("continuous data - invalid commit") {
+    serverThread = new ServerThread()
+    serverThread.start()
+
+    val reader = new TextSocketContinuousReader(
+      new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost",
+        "port" -> serverThread.port.toString).asJava))
+    reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5))))
+    // ok to commit same offset
+    reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5))))
+    assertThrows[IllegalStateException] {
+      reader.commit(TextSocketOffset(List(6, 6)))
+    }
+  }
+
+  test("continuous data with timestamp") {
+    serverThread = new ServerThread()
+    serverThread.start()
+
+    val reader = new TextSocketContinuousReader(
+      new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost",
+        "includeTimestamp" -> "true",
+        "port" -> serverThread.port.toString).asJava))
+    reader.setStartOffset(Optional.empty())
+    val tasks = reader.planInputPartitions()
+    assert(tasks.size == 2)
+
+    val numRecords = 4
+    // inject rows, read and check the data and offsets
+    for (i <- 0 until numRecords) {
+      serverThread.enqueue(i.toString)
+    }
+    tasks.asScala.foreach {
+      case t: TextSocketContinuousInputPartition =>
+        val r = 
t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader]
+        for (i <- 0 until numRecords / 2) {
+          r.next()
+          assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP)
+            .isInstanceOf[(String, Timestamp)])
+        }
+      case _ => throw new IllegalStateException("Unexpected task type")
+    }
+  }
+
   /**
    * This class tries to mimic the behavior of netcat, so that we can ensure
    * TextSocketStream supports netcat, which only accepts the first connection

