Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20382#discussion_r171510614
  
    --- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala ---
    @@ -0,0 +1,300 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.streaming.sources
    +
    +import java.io.IOException
    +import java.net.InetSocketAddress
    +import java.nio.ByteBuffer
    +import java.nio.channels.ServerSocketChannel
    +import java.sql.Timestamp
    +import java.util.Optional
    +import java.util.concurrent.LinkedBlockingQueue
    +
    +import scala.collection.JavaConverters._
    +
    +import org.scalatest.BeforeAndAfterEach
    +
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.sql.AnalysisException
    +import org.apache.spark.sql.execution.datasources.DataSource
    +import org.apache.spark.sql.execution.streaming._
    +import org.apache.spark.sql.sources.v2.{DataSourceOptions, 
MicroBatchReadSupport}
    +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, 
Offset}
    +import org.apache.spark.sql.streaming.StreamTest
    +import org.apache.spark.sql.test.SharedSQLContext
    +import org.apache.spark.sql.types.{StringType, StructField, StructType, 
TimestampType}
    +
    +class TextSocketStreamSuite extends StreamTest with SharedSQLContext with 
BeforeAndAfterEach {
    +
    +  override def afterEach() {
    +    sqlContext.streams.active.foreach(_.stop())
    +    if (serverThread != null) {
    +      serverThread.interrupt()
    +      serverThread.join()
    +      serverThread = null
    +    }
    +    if (batchReader != null) {
    +      batchReader.stop()
    +      batchReader = null
    +    }
    +  }
    +
    +  private var serverThread: ServerThread = null
    +  private var batchReader: MicroBatchReader = null
    +
    +  case class AddSocketData(data: String*) extends AddData {
    +    override def addData(query: Option[StreamExecution]): 
(BaseStreamingSource, Offset) = {
    +      require(
    +        query.nonEmpty,
    +        "Cannot add data when there is no query for finding the active 
socket source")
    +
    +      val sources = query.get.logicalPlan.collect {
    +        case StreamingExecutionRelation(source: 
TextSocketMicroBatchReader, _) => source
    +      }
    +      if (sources.isEmpty) {
    +        throw new Exception(
    +          "Could not find socket source in the StreamExecution logical 
plan to add data to")
    +      } else if (sources.size > 1) {
    +        throw new Exception(
    +          "Could not select the socket source in the StreamExecution 
logical plan as there" +
    +            "are multiple socket sources:\n\t" + sources.mkString("\n\t"))
    +      }
    +      val socketSource = sources.head
    +
    +      assert(serverThread != null && serverThread.port != 0)
    +      val currOffset = socketSource.currentOffset
    +      data.foreach(serverThread.enqueue)
    +
    +      val newOffset = LongOffset(currOffset.offset + data.size)
    +      (socketSource, newOffset)
    +    }
    +
    +    override def toString: String = s"AddSocketData(data = $data)"
    +  }
    +
    +  test("backward compatibility with old path") {
    +    
DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider",
    +      spark.sqlContext.conf).newInstance() match {
    +      case ds: MicroBatchReadSupport =>
    +        assert(ds.isInstanceOf[TextSocketSourceProvider])
    +      case _ =>
    +        throw new IllegalStateException("Could not find socket source")
    +    }
    +  }
    +
    +  test("basic usage") {
    +    serverThread = new ServerThread()
    +    serverThread.start()
    +
    +    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> 
"false") {
    +      val ref = spark
    +      import ref.implicits._
    +
    +      val socket = spark
    +        .readStream
    +        .format("socket")
    +        .options(Map("host" -> "localhost", "port" -> 
serverThread.port.toString))
    +        .load()
    +        .as[String]
    +
    +      assert(socket.schema === StructType(StructField("value", StringType) 
:: Nil))
    +
    +      testStream(socket)(
    +        StartStream(),
    +        AddSocketData("hello"),
    +        CheckAnswer("hello"),
    +        AddSocketData("world"),
    +        CheckLastBatch("world"),
    +        CheckAnswer("hello", "world"),
    +        StopStream
    +      )
    +    }
    +  }
    +
    +  test("timestamped usage") {
    +    serverThread = new ServerThread()
    +    serverThread.start()
    +
    +    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> 
"false") {
    +      val socket = spark
    +        .readStream
    +        .format("socket")
    +        .options(Map(
    +          "host" -> "localhost",
    +          "port" -> serverThread.port.toString,
    +          "includeTimestamp" -> "true"))
    +        .load()
    +
    +      assert(socket.schema === StructType(StructField("value", StringType) 
::
    +        StructField("timestamp", TimestampType) :: Nil))
    +
    +      var batch1Stamp: Timestamp = null
    +      var batch2Stamp: Timestamp = null
    +
    +      testStream(socket)(
    +        StartStream(),
    +        AddSocketData("hello"),
    +        CheckAnswerRowsByFunc(
    +          rows => {
    +            assert(rows.size === 1)
    +            assert(rows.head.getAs[String](0) === "hello")
    +            batch1Stamp = rows.head.getAs[Timestamp](1)
    +          },
    +          true),
    +        AddSocketData("world"),
    +        CheckAnswerRowsByFunc(
    +          rows => {
    +            assert(rows.size === 1)
    +            assert(rows.head.getAs[String](0) === "world")
    +            batch2Stamp = rows.head.getAs[Timestamp](1)
    +          },
    +          true),
    +        StopStream
    +      )
    +
    +      assert(!batch2Stamp.before(batch1Stamp))
    --- End diff --
    
    val timestamp = System.currentTimeMillis
    testStream(...)(
       // get batch1Stamp
    )
    // assert batch1Stamp >= timestamp



---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to