Github user HeartSaVioR commented on a diff in the pull request: https://github.com/apache/spark/pull/21199#discussion_r206386593 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala --- @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.io.{BufferedReader, InputStreamReader, IOException} +import java.net.Socket +import java.sql.Timestamp +import java.text.SimpleDateFormat +import java.util.{Calendar, List => JList, Locale} +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + +import org.json4s.{DefaultFormats, NoTypeHints} +import org.json4s.jackson.Serialization + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming.{ContinuousRecordEndpoint, ContinuousRecordPartitionOffset, GetRecord} +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsDeprecatedScanRow} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.util.RpcUtils + + +object TextSocketContinuousReader { + val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) + val SCHEMA_TIMESTAMP = StructType( + StructField("value", StringType) + :: StructField("timestamp", TimestampType) :: Nil) + val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) +} + +/** + * A ContinuousReader that reads text lines through a TCP socket, designed only for tutorials and + * debugging. This ContinuousReader will *not* work in production applications due to multiple + * reasons, including no support for fault recovery. + * + * The driver maintains a socket connection to the host-port, keeps the received messages in + * buckets and serves the messages to the executors via a RPC endpoint. + */ +class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousReader + with SupportsDeprecatedScanRow with Logging { + implicit val defaultFormats: DefaultFormats = DefaultFormats + + private val host: String = options.get("host").get() + private val port: Int = options.get("port").get().toInt + + assert(SparkSession.getActiveSession.isDefined) + private val spark = SparkSession.getActiveSession.get + private val numPartitions = spark.sparkContext.defaultParallelism + + @GuardedBy("this") + private var socket: Socket = _ + + @GuardedBy("this") + private var readThread: Thread = _ + + @GuardedBy("this") + private val buckets = Seq.fill(numPartitions)(new ListBuffer[(String, Timestamp)]) + + @GuardedBy("this") + private var currentOffset: Int = -1 + + private var startOffset: TextSocketOffset = _ + + private val recordEndpoint = new ContinuousRecordEndpoint(buckets, this) + @volatile private var endpointRef: RpcEndpointRef = _ + + initialize() + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + assert(offsets.length == numPartitions) + val offs = offsets + .map(_.asInstanceOf[ContinuousRecordPartitionOffset]) + .sortBy(_.partitionId) + .map(_.offset) + .toList + TextSocketOffset(offs) + } + + override def deserializeOffset(json: String): Offset = { + TextSocketOffset(Serialization.read[List[Int]](json)) + } + + override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { + this.startOffset = offset + .orElse(TextSocketOffset(List.fill(numPartitions)(0))) + .asInstanceOf[TextSocketOffset] + recordEndpoint.setStartOffsets(startOffset.offsets) + } + + override def getStartOffset: Offset = startOffset + + override def readSchema(): StructType = { + if (includeTimestamp) { + TextSocketContinuousReader.SCHEMA_TIMESTAMP + } else { + TextSocketContinuousReader.SCHEMA_REGULAR + } + } + + override def planRowInputPartitions(): JList[InputPartition[Row]] = { + + val endpointName = s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}" + endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) + + val offsets = startOffset match { + case off: TextSocketOffset => off.offsets + case off => + throw new IllegalArgumentException( + s"invalid offset type ${off.getClass} for TextSocketContinuousReader") + } + + if (offsets.size != numPartitions) { + throw new IllegalArgumentException( + s"The previous run contained ${offsets.size} partitions, but" + + s" $numPartitions partitions are currently configured. The numPartitions option" + + " cannot be changed.") + } + + startOffset.offsets.zipWithIndex.map { + case (offset, i) => + TextSocketContinuousInputPartition( + endpointName, i, offset, includeTimestamp): InputPartition[Row] + }.asJava + + } + + override def commit(end: Offset): Unit = synchronized { + val endOffset = end.asInstanceOf[TextSocketOffset] + endOffset.offsets.zipWithIndex.foreach { + case (offset, partition) => + val max = startOffset.offsets(partition) + buckets(partition).size + if (offset > max) { + throw new IllegalStateException("Invalid offset " + offset + " to commit" + + " for partition " + partition + ". Max valid offset: " + max) + } + val n = offset - startOffset.offsets(partition) + buckets(partition).trimStart(n) + } + startOffset = endOffset + recordEndpoint.setStartOffsets(startOffset.offsets) + } + + /** Stop this source. */ + override def stop(): Unit = synchronized { + if (socket != null) { + try { + // Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to + // stop the readThread is to close the socket. + socket.close() + } catch { + case e: IOException => + } + socket = null + } + if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) + } + + private def initialize(): Unit = synchronized { + socket = new Socket(host, port) + val reader = new BufferedReader(new InputStreamReader(socket.getInputStream)) + // Thread continuously reads from a socket and inserts data into buckets + readThread = new Thread(s"TextSocketContinuousReader($host, $port)") { + setDaemon(true) + + override def run(): Unit = { + try { + while (true) { + val line = reader.readLine() + if (line == null) { + // End of file reached + logWarning(s"Stream closed by $host:$port") + return + } + TextSocketContinuousReader.this.synchronized { + currentOffset += 1 + val newData = (line, + Timestamp.valueOf( + TextSocketContinuousReader.DATE_FORMAT.format(Calendar.getInstance().getTime())) + ) + buckets(currentOffset % numPartitions) += newData + } + } + } catch { + case e: IOException => + } + } + } + + readThread.start() + } + + override def toString: String = s"TextSocketContinuousReader[host: $host, port: $port]" + + private def includeTimestamp: Boolean = options.getBoolean("includeTimestamp", false) + +} + +/** + * Continuous text socket input partition. + */ +case class TextSocketContinuousInputPartition( + driverEndpointName: String, + partitionId: Int, + startOffset: Int, + includeTimestamp: Boolean) +extends InputPartition[Row] { + + override def createPartitionReader(): InputPartitionReader[Row] = + new TextSocketContinuousInputPartitionReader(driverEndpointName, partitionId, startOffset, + includeTimestamp) +} + +/** + * Continuous text socket input partition reader. + * + * Polls the driver endpoint for new records. + */ +class TextSocketContinuousInputPartitionReader( + driverEndpointName: String, + partitionId: Int, + startOffset: Int, + includeTimestamp: Boolean) + extends ContinuousInputPartitionReader[Row] { + + private val endpoint = RpcUtils.makeDriverRef( + driverEndpointName, + SparkEnv.get.conf, + SparkEnv.get.rpcEnv) + + private var currentOffset = startOffset + private var current: Option[Row] = None + + override def next(): Boolean = { + try { + current = getRecord + while (current.isEmpty) { + Thread.sleep(100) + current = getRecord + } + currentOffset += 1 + } catch { + case _: InterruptedException => + // Someone's trying to end the task; just let them. + return false + } + true + } + + override def get(): Row = { + current.get + } + + override def close(): Unit = {} + + override def getOffset: PartitionOffset = + ContinuousRecordPartitionOffset(partitionId, currentOffset) + + private def getRecord: Option[Row] = + endpoint.askSync[Option[Row]](GetRecord( + ContinuousRecordPartitionOffset(partitionId, currentOffset))) + .map(rec => { --- End diff -- nit: according to style guide, this may need to be written as follow ``` .map { rec => if (includeTimestamp) { ... ``` or even ``` ContinuousRecordPartitionOffset(partitionId, currentOffset))).map { rec => if (includeTimestamp) { ... ``` https://github.com/databricks/scala-style-guide#anonymous-methods
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org