http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala deleted file mode 100644 index 5884380..0000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.io.ByteArrayOutputStream - -import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.streaming.{StreamTest, Trigger} - -class ConsoleWriteSupportSuite extends StreamTest { - import testImplicits._ - - test("microbatch - default") { - val input = MemoryStream[Int] - - val captured = new ByteArrayOutputStream() - Console.withOut(captured) { - val query = input.toDF().writeStream.format("console").start() - try { - input.addData(1, 2, 3) - query.processAllAvailable() - input.addData(4, 5, 6) - query.processAllAvailable() - input.addData() - query.processAllAvailable() - } finally { - query.stop() - } - } - - assert(captured.toString() == - """------------------------------------------- - |Batch: 0 - |------------------------------------------- - |+-----+ - ||value| - |+-----+ - || 1| - || 2| - || 3| - |+-----+ - | - |------------------------------------------- - |Batch: 1 - |------------------------------------------- - |+-----+ - ||value| - |+-----+ - || 4| - || 5| - || 6| - |+-----+ - | - |------------------------------------------- - |Batch: 2 - |------------------------------------------- - |+-----+ - ||value| - |+-----+ - |+-----+ - | - |""".stripMargin) - } - - test("microbatch - with numRows") { - val input = MemoryStream[Int] - - val captured = new ByteArrayOutputStream() - Console.withOut(captured) { - val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start() - try { - input.addData(1, 2, 3) - query.processAllAvailable() - } finally { - query.stop() - } - } - - assert(captured.toString() == - """------------------------------------------- - |Batch: 0 - |------------------------------------------- - |+-----+ - ||value| - |+-----+ - || 1| - || 2| - |+-----+ - |only showing top 2 rows - | - |""".stripMargin) - } - - test("microbatch - truncation") { - val input = MemoryStream[String] - - val captured = new ByteArrayOutputStream() - Console.withOut(captured) { - val query = 
input.toDF().writeStream.format("console").option("TRUNCATE", true).start() - try { - input.addData("123456789012345678901234567890") - query.processAllAvailable() - } finally { - query.stop() - } - } - - assert(captured.toString() == - """------------------------------------------- - |Batch: 0 - |------------------------------------------- - |+--------------------+ - || value| - |+--------------------+ - ||12345678901234567...| - |+--------------------+ - | - |""".stripMargin) - } - - test("continuous - default") { - val captured = new ByteArrayOutputStream() - Console.withOut(captured) { - val input = spark.readStream - .format("rate") - .option("numPartitions", "1") - .option("rowsPerSecond", "5") - .load() - .select('value) - - val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start() - assert(query.isActive) - query.stop() - } - } -}
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala new file mode 100644 index 0000000..55acf2b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.ByteArrayOutputStream + +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming.{StreamTest, Trigger} + +class ConsoleWriterSuite extends StreamTest { + import testImplicits._ + + test("microbatch - default") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + input.addData(4, 5, 6) + query.processAllAvailable() + input.addData() + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + || 3| + |+-----+ + | + |------------------------------------------- + |Batch: 1 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 4| + || 5| + || 6| + |+-----+ + | + |------------------------------------------- + |Batch: 2 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + |+-----+ + | + |""".stripMargin) + } + + test("microbatch - with numRows") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + |+-----+ + |only showing top 2 rows + | + |""".stripMargin) + } + + test("microbatch - truncation") { + val input = MemoryStream[String] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = 
input.toDF().writeStream.format("console").option("TRUNCATE", true).start() + try { + input.addData("123456789012345678901234567890") + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+--------------------+ + || value| + |+--------------------+ + ||12345678901234567...| + |+--------------------+ + | + |""".stripMargin) + } + + test("continuous - default") { + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val input = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "5") + .load() + .select('value) + + val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start() + assert(query.isActive) + query.stop() + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index dd74af8..5ca13b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -17,18 +17,20 @@ package org.apache.spark.sql.execution.streaming.sources +import java.util.Optional import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock @@ -41,7 +43,7 @@ class RateSourceSuite extends StreamTest { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReadSupport, _) => source + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source }.head rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) @@ -54,10 +56,10 @@ class RateSourceSuite extends StreamTest { test("microbatch in registry") { withTempDir { temp => DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupportProvider => - val readSupport = ds.createMicroBatchReadSupport( - temp.getCanonicalPath, DataSourceOptions.empty()) - assert(readSupport.isInstanceOf[RateStreamMicroBatchReadSupport]) + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader( + Optional.empty(), temp.getCanonicalPath, 
DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) case _ => throw new IllegalStateException("Could not find read support for rate") } @@ -67,7 +69,7 @@ class RateSourceSuite extends StreamTest { test("compatible with old path in registry") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupportProvider => + case ds: MicroBatchReadSupport => assert(ds.isInstanceOf[RateStreamProvider]) case _ => throw new IllegalStateException("Could not find read support for rate") @@ -139,19 +141,30 @@ class RateSourceSuite extends StreamTest { ) } + test("microbatch - set offset") { + withTempDir { temp => + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp.getCanonicalPath) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + } + test("microbatch - infer offsets") { withTempDir { temp => - val readSupport = new RateStreamMicroBatchReadSupport( + val reader = new RateStreamMicroBatchReader( new DataSourceOptions( Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), temp.getCanonicalPath) - readSupport.clock.asInstanceOf[ManualClock].advance(100000) - val startOffset = readSupport.initialOffset() - startOffset match { + reader.clock.asInstanceOf[ManualClock].advance(100000) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { case r: LongOffset => assert(r.offset === 0L) case _ => throw new IllegalStateException("unexpected offset type") } - readSupport.latestOffset() match { + reader.getEndOffset() match { case r: LongOffset => assert(r.offset >= 100) case _ => throw new IllegalStateException("unexpected offset type") } @@ -160,16 +173,15 @@ class RateSourceSuite extends StreamTest { test("microbatch - predetermined batch size") { withTempDir { temp => - val readSupport = new RateStreamMicroBatchReadSupport( + val reader = new RateStreamMicroBatchReader( new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp.getCanonicalPath) val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) - val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() - val tasks = readSupport.planInputPartitions(config) - val readerFactory = readSupport.createReaderFactory(config) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.planInputPartitions() assert(tasks.size == 1) - val dataReader = readerFactory.createReader(tasks(0)) + val dataReader = tasks.get(0).createPartitionReader() val data = ArrayBuffer[InternalRow]() while (dataReader.next()) { data.append(dataReader.get()) @@ -180,25 +192,24 @@ class RateSourceSuite extends StreamTest { test("microbatch - data read") { withTempDir { temp => - val readSupport = new RateStreamMicroBatchReadSupport( + val reader = new RateStreamMicroBatchReader( new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp.getCanonicalPath) val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) - val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() - val tasks = readSupport.planInputPartitions(config) - val readerFactory = readSupport.createReaderFactory(config) + 
reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.planInputPartitions() assert(tasks.size == 11) - val readData = tasks - .map(readerFactory.createReader) + val readData = tasks.asScala + .map(_.createPartitionReader()) .flatMap { reader => val buf = scala.collection.mutable.ListBuffer[InternalRow]() while (reader.next()) buf.append(reader.get()) buf } - assert(readData.map(_.getLong(1)).sorted === 0.until(33).toArray) + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) } } @@ -309,44 +320,41 @@ class RateSourceSuite extends StreamTest { } test("user-specified schema given") { - val exception = intercept[UnsupportedOperationException] { + val exception = intercept[AnalysisException] { spark.readStream .format("rate") .schema(spark.range(1).schema) .load() } assert(exception.getMessage.contains( - "rate source does not support user-specified schema")) + "rate source does not support a user-specified schema")) } test("continuous in registry") { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupportProvider => - val readSupport = ds.createContinuousReadSupport( - "", DataSourceOptions.empty()) - assert(readSupport.isInstanceOf[RateStreamContinuousReadSupport]) + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) case _ => throw new IllegalStateException("Could not find read support for continuous rate") } } test("continuous data") { - val readSupport = new RateStreamContinuousReadSupport( + val reader = new RateStreamContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - val config = readSupport.newScanConfigBuilder(readSupport.initialOffset).build() - val tasks = readSupport.planInputPartitions(config) - val readerFactory = readSupport.createContinuousReaderFactory(config) + reader.setStartOffset(Optional.empty()) + val tasks = reader.planInputPartitions() assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[InternalRow]() - tasks.foreach { + tasks.asScala.foreach { case t: RateStreamContinuousInputPartition => - val startTimeMs = readSupport.initialOffset() + val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = readerFactory.createReader(t) - .asInstanceOf[RateStreamContinuousPartitionReader] + val r = t.createPartitionReader().asInstanceOf[RateStreamContinuousInputPartitionReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 409156e..48e5cf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -21,6 +21,7 @@ import java.net.{InetSocketAddress, SocketException} import java.nio.ByteBuffer import java.nio.channels.ServerSocketChannel import java.sql.Timestamp +import 
java.util.Optional import java.util.concurrent.LinkedBlockingQueue import scala.collection.JavaConverters._ @@ -33,8 +34,8 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupportProvider} -import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -48,9 +49,14 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread.join() serverThread = null } + if (batchReader != null) { + batchReader.stop() + batchReader = null + } } private var serverThread: ServerThread = null + private var batchReader: MicroBatchReader = null case class AddSocketData(data: String*) extends AddData { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { @@ -59,7 +65,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before "Cannot add data when there is no query for finding the active socket source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: TextSocketMicroBatchReadSupport, _) => source + case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source } if (sources.isEmpty) { throw new Exception( @@ -85,7 +91,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("backward compatibility with old path") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupportProvider => + case ds: MicroBatchReadSupport => assert(ds.isInstanceOf[TextSocketSourceProvider]) case _ => throw new IllegalStateException("Could not find socket source") @@ -175,16 +181,16 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("params not given") { val provider = new TextSocketSourceProvider intercept[AnalysisException] { - provider.createMicroBatchReadSupport( - "", new DataSourceOptions(Map.empty[String, String].asJava)) + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map.empty[String, String].asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReadSupport( - "", new DataSourceOptions(Map("host" -> "localhost").asJava)) + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("host" -> "localhost").asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReadSupport( - "", new DataSourceOptions(Map("port" -> "1234").asJava)) + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("port" -> "1234").asJava)) } } @@ -193,7 +199,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") intercept[AnalysisException] { val a = new DataSourceOptions(params.asJava) - provider.createMicroBatchReadSupport("", a) + provider.createMicroBatchReader(Optional.empty(), "", a) } } @@ -203,12 +209,12 @@ class 
TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before StructField("name", StringType) :: StructField("area", StringType) :: Nil) val params = Map("host" -> "localhost", "port" -> "1234") - val exception = intercept[UnsupportedOperationException] { - provider.createMicroBatchReadSupport( - userSpecifiedSchema, "", new DataSourceOptions(params.asJava)) + val exception = intercept[AnalysisException] { + provider.createMicroBatchReader( + Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava)) } assert(exception.getMessage.contains( - "socket source does not support user-specified schema")) + "socket source does not support a user-specified schema")) } test("input row metrics") { @@ -299,27 +305,25 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val readSupport = new TextSocketContinuousReadSupport( + val reader = new TextSocketContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "port" -> serverThread.port.toString).asJava)) - - val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() - val tasks = readSupport.planInputPartitions(scanConfig) + reader.setStartOffset(Optional.empty()) + val tasks = reader.planInputPartitions() assert(tasks.size == 2) val numRecords = 10 val data = scala.collection.mutable.ListBuffer[Int]() val offsets = scala.collection.mutable.ListBuffer[Int]() - val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) import org.scalatest.time.SpanSugar._ failAfter(5 seconds) { // inject rows, read and check the data and offsets for (i <- 0 until numRecords) { serverThread.enqueue(i.toString) } - tasks.foreach { + tasks.asScala.foreach { case t: TextSocketContinuousInputPartition => - val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] + val r = t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader] for (i <- 0 until numRecords / 2) { r.next() offsets.append(r.getOffset().asInstanceOf[ContinuousRecordPartitionOffset].offset) @@ -335,15 +339,16 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before data.clear() case _ => throw new IllegalStateException("Unexpected task type") } - assert(readSupport.startOffset.offsets == List(3, 3)) - readSupport.commit(TextSocketOffset(List(5, 5))) - assert(readSupport.startOffset.offsets == List(5, 5)) + assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(3, 3)) + reader.commit(TextSocketOffset(List(5, 5))) + assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(5, 5)) } def commitOffset(partition: Int, offset: Int): Unit = { - val offsetsToCommit = readSupport.startOffset.offsets.updated(partition, offset) - readSupport.commit(TextSocketOffset(offsetsToCommit)) - assert(readSupport.startOffset.offsets == offsetsToCommit) + val offsetsToCommit = reader.getStartOffset.asInstanceOf[TextSocketOffset] + .offsets.updated(partition, offset) + reader.commit(TextSocketOffset(offsetsToCommit)) + assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == offsetsToCommit) } } @@ -351,13 +356,14 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val readSupport = new TextSocketContinuousReadSupport( + val reader = new TextSocketContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> 
"localhost", "port" -> serverThread.port.toString).asJava)) - - readSupport.startOffset = TextSocketOffset(List(5, 5)) + reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5)))) + // ok to commit same offset + reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5)))) assertThrows[IllegalStateException] { - readSupport.commit(TextSocketOffset(List(6, 6))) + reader.commit(TextSocketOffset(List(6, 6))) } } @@ -365,12 +371,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val readSupport = new TextSocketContinuousReadSupport( + val reader = new TextSocketContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "includeTimestamp" -> "true", "port" -> serverThread.port.toString).asJava)) - val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() - val tasks = readSupport.planInputPartitions(scanConfig) + reader.setStartOffset(Optional.empty()) + val tasks = reader.planInputPartitions() assert(tasks.size == 2) val numRecords = 4 @@ -378,10 +384,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before for (i <- 0 until numRecords) { serverThread.enqueue(i.toString) } - val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) - tasks.foreach { + tasks.asScala.foreach { case t: TextSocketContinuousInputPartition => - val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] + val r = t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader] for (i <- 0 until numRecords / 2) { r.next() assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP) http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index f6c3e0c..12beca2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources.v2 +import java.util.{ArrayList, List => JList} + import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException @@ -36,21 +38,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ - private def getScanConfig(query: DataFrame): AdvancedScanConfigBuilder = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => - d.scanConfig.asInstanceOf[AdvancedScanConfigBuilder] - }.head - } - - private def getJavaScanConfig( - query: DataFrame): JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => - d.scanConfig.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder] - }.head - } - test("simplest implementation") { Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -63,6 +50,18 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("advanced implementation") { + def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + 
case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] + }.head + } + + def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader] + }.head + } + Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() @@ -71,58 +70,58 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val q1 = df.select('j) checkAnswer(q1, (0 until 10).map(i => Row(-i))) if (cls == classOf[AdvancedDataSourceV2]) { - val config = getScanConfig(q1) - assert(config.filters.isEmpty) - assert(config.requiredSchema.fieldNames === Seq("j")) + val reader = getReader(q1) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) } else { - val config = getJavaScanConfig(q1) - assert(config.filters.isEmpty) - assert(config.requiredSchema.fieldNames === Seq("j")) + val reader = getJavaReader(q1) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) } val q2 = df.filter('i > 3) checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) if (cls == classOf[AdvancedDataSourceV2]) { - val config = getScanConfig(q2) - assert(config.filters.flatMap(_.references).toSet == Set("i")) - assert(config.requiredSchema.fieldNames === Seq("i", "j")) + val reader = getReader(q2) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i", "j")) } else { - val config = getJavaScanConfig(q2) - assert(config.filters.flatMap(_.references).toSet == Set("i")) - assert(config.requiredSchema.fieldNames === Seq("i", "j")) + val reader = getJavaReader(q2) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i", "j")) } val q3 = df.select('i).filter('i > 6) checkAnswer(q3, (7 until 10).map(i => Row(i))) if (cls == classOf[AdvancedDataSourceV2]) { - val config = getScanConfig(q3) - assert(config.filters.flatMap(_.references).toSet == Set("i")) - assert(config.requiredSchema.fieldNames === Seq("i")) + val reader = getReader(q3) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i")) } else { - val config = getJavaScanConfig(q3) - assert(config.filters.flatMap(_.references).toSet == Set("i")) - assert(config.requiredSchema.fieldNames === Seq("i")) + val reader = getJavaReader(q3) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i")) } val q4 = df.select('j).filter('j < -10) checkAnswer(q4, Nil) if (cls == classOf[AdvancedDataSourceV2]) { - val config = getScanConfig(q4) + val reader = getReader(q4) // 'j < 10 is not supported by the testing data source. - assert(config.filters.isEmpty) - assert(config.requiredSchema.fieldNames === Seq("j")) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) } else { - val config = getJavaScanConfig(q4) + val reader = getJavaReader(q4) // 'j < 10 is not supported by the testing data source. 
- assert(config.filters.isEmpty) - assert(config.requiredSchema.fieldNames === Seq("j")) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) } } } } test("columnar batch scan implementation") { - Seq(classOf[ColumnarDataSourceV2], classOf[JavaColumnarDataSourceV2]).foreach { cls => + Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 90).map(i => Row(i, -i))) @@ -154,25 +153,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val df = spark.read.format(cls.getName).load() checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) - val groupByColA = df.groupBy('i).agg(sum('j)) + val groupByColA = df.groupBy('a).agg(sum('b)) checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) assert(groupByColA.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColAB = df.groupBy('i, 'j).agg(count("*")) + val groupByColAB = df.groupBy('a, 'b).agg(count("*")) checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) assert(groupByColAB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColB = df.groupBy('j).agg(sum('i)) + val groupByColB = df.groupBy('b).agg(sum('a)) checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) assert(groupByColB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isDefined) - val groupByAPlusB = df.groupBy('i + 'j).agg(count("*")) + val groupByAPlusB = df.groupBy('a + 'b).agg(count("*")) checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) assert(groupByAPlusB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e @@ -273,30 +272,36 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("SPARK-23301: column pruning with arbitrary expressions") { + def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] + }.head + } + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() val q1 = df.select('i + 1) checkAnswer(q1, (1 until 11).map(i => Row(i))) - val config1 = getScanConfig(q1) - assert(config1.requiredSchema.fieldNames === Seq("i")) + val reader1 = getReader(q1) + assert(reader1.requiredSchema.fieldNames === Seq("i")) val q2 = df.select(lit(1)) checkAnswer(q2, (0 until 10).map(i => Row(1))) - val config2 = getScanConfig(q2) - assert(config2.requiredSchema.isEmpty) + val reader2 = getReader(q2) + assert(reader2.requiredSchema.isEmpty) // 'j === 1 can't be pushed down, but we should still be able do column pruning val q3 = df.filter('j === -1).select('j * 2) checkAnswer(q3, Row(-2)) - val config3 = getScanConfig(q3) - assert(config3.filters.isEmpty) - assert(config3.requiredSchema.fieldNames === Seq("j")) + val reader3 = getReader(q3) + assert(reader3.filters.isEmpty) + assert(reader3.requiredSchema.fieldNames === Seq("j")) // column pruning should work with other operators. 
val q4 = df.sort('i).limit(1).select('i + 1) checkAnswer(q4, Row(1)) - val config4 = getScanConfig(q4) - assert(config4.requiredSchema.fieldNames === Seq("i")) + val reader4 = getReader(q4) + assert(reader4.requiredSchema.fieldNames === Seq("i")) } test("SPARK-23315: get output from canonicalized data source v2 related plans") { @@ -319,291 +324,240 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } +class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { -case class RangeInputPartition(start: Int, end: Int) extends InputPartition - -case class NoopScanConfigBuilder(readSchema: StructType) extends ScanConfigBuilder with ScanConfig { - override def build(): ScanConfig = this -} - -object SimpleReaderFactory extends PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val RangeInputPartition(start, end) = partition - new PartitionReader[InternalRow] { - private var current = start - 1 - - override def next(): Boolean = { - current += 1 - current < end - } - - override def get(): InternalRow = InternalRow(current, -current) + class Reader extends DataSourceReader { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def close(): Unit = {} + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + java.util.Arrays.asList(new SimpleInputPartition(0, 5)) } } + + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -abstract class SimpleReadSupport extends BatchReadSupport { - override def fullSchema(): StructType = new StructType().add("i", "int").add("j", "int") +// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark +// tests still pass. +class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { - override def newScanConfigBuilder(): ScanConfigBuilder = { - NoopScanConfigBuilder(fullSchema()) - } + class Reader extends DataSourceReader { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - SimpleReaderFactory + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10)) + } } + + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } +class SimpleInputPartition(start: Int, end: Int) + extends InputPartition[InternalRow] + with InputPartitionReader[InternalRow] { + private var current = start - 1 -class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProvider { + override def createPartitionReader(): InputPartitionReader[InternalRow] = + new SimpleInputPartition(start, end) - class ReadSupport extends SimpleReadSupport { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - Array(RangeInputPartition(0, 5)) - } + override def next(): Boolean = { + current += 1 + current < end } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } + override def get(): InternalRow = InternalRow(current, -current) + + override def close(): Unit = {} } -// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark -// tests still pass. 
-class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { +class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { - class ReadSupport extends SimpleReadSupport { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) - } - } + class Reader extends DataSourceReader + with SupportsPushDownRequiredColumns with SupportsPushDownFilters { - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } -} + var requiredSchema = new StructType().add("i", "int").add("j", "int") + var filters = Array.empty[Filter] + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema + } -class AdvancedDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => false + } + this.filters = supported + unsupported + } - class ReadSupport extends SimpleReadSupport { - override def newScanConfigBuilder(): ScanConfigBuilder = new AdvancedScanConfigBuilder() + override def pushedFilters(): Array[Filter] = filters - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val filters = config.asInstanceOf[AdvancedScanConfigBuilder].filters + override def readSchema(): StructType = { + requiredSchema + } + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { val lowerBound = filters.collectFirst { case GreaterThan("i", v: Int) => v } - val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] + val res = new ArrayList[InputPartition[InternalRow]] if (lowerBound.isEmpty) { - res.append(RangeInputPartition(0, 5)) - res.append(RangeInputPartition(5, 10)) + res.add(new AdvancedInputPartition(0, 5, requiredSchema)) + res.add(new AdvancedInputPartition(5, 10, requiredSchema)) } else if (lowerBound.get < 4) { - res.append(RangeInputPartition(lowerBound.get + 1, 5)) - res.append(RangeInputPartition(5, 10)) + res.add(new AdvancedInputPartition(lowerBound.get + 1, 5, requiredSchema)) + res.add(new AdvancedInputPartition(5, 10, requiredSchema)) } else if (lowerBound.get < 9) { - res.append(RangeInputPartition(lowerBound.get + 1, 10)) + res.add(new AdvancedInputPartition(lowerBound.get + 1, 10, requiredSchema)) } - res.toArray - } - - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - val requiredSchema = config.asInstanceOf[AdvancedScanConfigBuilder].requiredSchema - new AdvancedReaderFactory(requiredSchema) + res } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig - with SupportsPushDownRequiredColumns with SupportsPushDownFilters { +class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) + extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { - var requiredSchema = new StructType().add("i", "int").add("j", "int") - var filters = Array.empty[Filter] + private var current = start - 1 - override def pruneColumns(requiredSchema: StructType): Unit = { - this.requiredSchema = requiredSchema + override def createPartitionReader(): InputPartitionReader[InternalRow] = { + new 
AdvancedInputPartition(start, end, requiredSchema) } - override def readSchema(): StructType = requiredSchema + override def close(): Unit = {} - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - val (supported, unsupported) = filters.partition { - case GreaterThan("i", _: Int) => true - case _ => false - } - this.filters = supported - unsupported + override def next(): Boolean = { + current += 1 + current < end } - override def pushedFilters(): Array[Filter] = filters - - override def build(): ScanConfig = this -} - -class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val RangeInputPartition(start, end) = partition - new PartitionReader[InternalRow] { - private var current = start - 1 - - override def next(): Boolean = { - current += 1 - current < end - } - - override def get(): InternalRow = { - val values = requiredSchema.map(_.name).map { - case "i" => current - case "j" => -current - } - InternalRow.fromSeq(values) - } - - override def close(): Unit = {} + override def get(): InternalRow = { + val values = requiredSchema.map(_.name).map { + case "i" => current + case "j" => -current } + InternalRow.fromSeq(values) } } -class SchemaRequiredDataSource extends DataSourceV2 with BatchReadSupportProvider { +class SchemaRequiredDataSource extends DataSourceV2 with ReadSupport { - class ReadSupport(val schema: StructType) extends SimpleReadSupport { - override def fullSchema(): StructType = schema - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = - Array.empty + class Reader(val readSchema: StructType) extends DataSourceReader { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = + java.util.Collections.emptyList() } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + override def createReader(options: DataSourceOptions): DataSourceReader = { throw new IllegalArgumentException("requires a user-supplied schema") } - override def createBatchReadSupport( - schema: StructType, options: DataSourceOptions): BatchReadSupport = { - new ReadSupport(schema) + override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = { + new Reader(schema) } } -class ColumnarDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { +class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { - class ReadSupport extends SimpleReadSupport { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - Array(RangeInputPartition(0, 50), RangeInputPartition(50, 90)) - } + class Reader extends DataSourceReader with SupportsScanColumnarBatch { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - ColumnarReaderFactory + override def planBatchInputPartitions(): JList[InputPartition[ColumnarBatch]] = { + java.util.Arrays.asList( + new BatchInputPartitionReader(0, 50), new BatchInputPartitionReader(50, 90)) } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -object ColumnarReaderFactory extends PartitionReaderFactory { - private final val BATCH_SIZE = 20 +class BatchInputPartitionReader(start: Int, end: Int) + extends InputPartition[ColumnarBatch] with 
InputPartitionReader[ColumnarBatch] { - override def supportColumnarReads(partition: InputPartition): Boolean = true + private final val BATCH_SIZE = 20 + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val batch = new ColumnarBatch(Array(i, j)) - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - throw new UnsupportedOperationException - } + private var current = start - override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { - val RangeInputPartition(start, end) = partition - new PartitionReader[ColumnarBatch] { - private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val batch = new ColumnarBatch(Array(i, j)) - - private var current = start - - override def next(): Boolean = { - i.reset() - j.reset() - - var count = 0 - while (current < end && count < BATCH_SIZE) { - i.putInt(count, current) - j.putInt(count, -current) - current += 1 - count += 1 - } + override def createPartitionReader(): InputPartitionReader[ColumnarBatch] = this - if (count == 0) { - false - } else { - batch.setNumRows(count) - true - } - } + override def next(): Boolean = { + i.reset() + j.reset() - override def get(): ColumnarBatch = batch + var count = 0 + while (current < end && count < BATCH_SIZE) { + i.putInt(count, current) + j.putInt(count, -current) + current += 1 + count += 1 + } - override def close(): Unit = batch.close() + if (count == 0) { + false + } else { + batch.setNumRows(count) + true } } + + override def get(): ColumnarBatch = { + batch + } + + override def close(): Unit = batch.close() } +class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { -class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvider { + class Reader extends DataSourceReader with SupportsReportPartitioning { + override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") - class ReadSupport extends SimpleReadSupport with SupportsReportPartitioning { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { // Note that we don't have same value of column `a` across partitions. 
- Array( - SpecificInputPartition(Array(1, 1, 3), Array(4, 4, 6)), - SpecificInputPartition(Array(2, 4, 4), Array(6, 2, 2))) - } - - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - SpecificReaderFactory + java.util.Arrays.asList( + new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificInputPartitionReader(Array(2, 4, 4), Array(6, 2, 2))) } - override def outputPartitioning(config: ScanConfig): Partitioning = new MyPartitioning + override def outputPartitioning(): Partitioning = new MyPartitioning } class MyPartitioning extends Partitioning { override def numPartitions(): Int = 2 override def satisfy(distribution: Distribution): Boolean = distribution match { - case c: ClusteredDistribution => c.clusteredColumns.contains("i") + case c: ClusteredDistribution => c.clusteredColumns.contains("a") case _ => false } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition +class SpecificInputPartitionReader(i: Array[Int], j: Array[Int]) + extends InputPartition[InternalRow] + with InputPartitionReader[InternalRow] { + assert(i.length == j.length) -object SpecificReaderFactory extends PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val p = partition.asInstanceOf[SpecificInputPartition] - new PartitionReader[InternalRow] { - private var current = -1 + private var current = -1 - override def next(): Boolean = { - current += 1 - current < p.i.length - } + override def createPartitionReader(): InputPartitionReader[InternalRow] = this - override def get(): InternalRow = InternalRow(p.i(current), p.j(current)) - - override def close(): Unit = {} - } + override def next(): Boolean = { + current += 1 + current < i.length } + + override def get(): InternalRow = InternalRow(i(current), j(current)) + + override def close(): Unit = {} } http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 952241b..e1b8e9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -18,36 +18,34 @@ package org.apache.spark.sql.sources.v2 import java.io.{BufferedReader, InputStreamReader, IOException} -import java.util.Optional +import java.util.{Collections, List => JList, Optional} import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import 
org.apache.spark.util.SerializableConfiguration /** * A HDFS based transactional writable data source. - * Each task writes data to `target/_temporary/queryId/$jobId-$partitionId-$attemptNumber`. - * Each job moves files from `target/_temporary/queryId/` to `target`. + * Each task writes data to `target/_temporary/jobId/$jobId-$partitionId-$attemptNumber`. + * Each job moves files from `target/_temporary/jobId/` to `target`. */ -class SimpleWritableDataSource extends DataSourceV2 - with BatchReadSupportProvider with BatchWriteSupportProvider { +class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteSupport { private val schema = new StructType().add("i", "long").add("j", "long") - class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport { + class Reader(path: String, conf: Configuration) extends DataSourceReader { + override def readSchema(): StructType = schema - override def fullSchema(): StructType = schema - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -55,23 +53,21 @@ class SimpleWritableDataSource extends DataSourceV2 val name = status.getPath.getName name.startsWith("_") || name.startsWith(".") }.map { f => - CSVInputPartitionReader(f.getPath.toUri.toString) - }.toArray + val serializableConf = new SerializableConfiguration(conf) + new SimpleCSVInputPartitionReader( + f.getPath.toUri.toString, + serializableConf): InputPartition[InternalRow] + }.toList.asJava } else { - Array.empty + Collections.emptyList() } } - - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - val serializableConf = new SerializableConfiguration(conf) - new CSVReaderFactory(serializableConf) - } } - class WritSupport(queryId: String, path: String, conf: Configuration) extends BatchWriteSupport { - override def createBatchWriterFactory(): DataWriterFactory = { + class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { + override def createWriterFactory(): DataWriterFactory[InternalRow] = { SimpleCounter.resetCounter - new CSVDataWriterFactory(path, queryId, new SerializableConfiguration(conf)) + new CSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } override def onDataWriterCommit(message: WriterCommitMessage): Unit = { @@ -80,7 +76,7 @@ class SimpleWritableDataSource extends DataSourceV2 override def commit(messages: Array[WriterCommitMessage]): Unit = { val finalPath = new Path(path) - val jobPath = new Path(new Path(finalPath, "_temporary"), queryId) + val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) val fs = jobPath.getFileSystem(conf) try { for (file <- fs.listStatus(jobPath).map(_.getPath)) { @@ -95,23 +91,23 @@ class SimpleWritableDataSource extends DataSourceV2 } override def abort(messages: Array[WriterCommitMessage]): Unit = { - val jobPath = new Path(new Path(path, "_temporary"), queryId) + val jobPath = new Path(new Path(path, "_temporary"), jobId) val fs = jobPath.getFileSystem(conf) fs.delete(jobPath, true) } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + override def createReader(options: DataSourceOptions): DataSourceReader = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration - new ReadSupport(path.toUri.toString, conf) + new 
Reader(path.toUri.toString, conf) } - override def createBatchWriteSupport( - queryId: String, + override def createWriter( + jobId: String, schema: StructType, mode: SaveMode, - options: DataSourceOptions): Optional[BatchWriteSupport] = { + options: DataSourceOptions): Optional[DataSourceWriter] = { assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) @@ -134,42 +130,39 @@ class SimpleWritableDataSource extends DataSourceV2 } val pathStr = path.toUri.toString - Optional.of(new WritSupport(queryId, pathStr, conf)) + Optional.of(new Writer(jobId, pathStr, conf)) } } -case class CSVInputPartitionReader(path: String) extends InputPartition +class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguration) + extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { -class CSVReaderFactory(conf: SerializableConfiguration) - extends PartitionReaderFactory { + @transient private var lines: Iterator[String] = _ + @transient private var currentLine: String = _ + @transient private var inputStream: FSDataInputStream = _ - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val path = partition.asInstanceOf[CSVInputPartitionReader].path + override def createPartitionReader(): InputPartitionReader[InternalRow] = { val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) + inputStream = fs.open(filePath) + lines = new BufferedReader(new InputStreamReader(inputStream)) + .lines().iterator().asScala + this + } - new PartitionReader[InternalRow] { - private val inputStream = fs.open(filePath) - private val lines = new BufferedReader(new InputStreamReader(inputStream)) - .lines().iterator().asScala - - private var currentLine: String = _ - - override def next(): Boolean = { - if (lines.hasNext) { - currentLine = lines.next() - true - } else { - false - } - } + override def next(): Boolean = { + if (lines.hasNext) { + currentLine = lines.next() + true + } else { + false + } + } - override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*) + override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*) - override def close(): Unit = { - inputStream.close() - } - } + override def close(): Unit = { + inputStream.close() } } @@ -190,11 +183,12 @@ private[v2] object SimpleCounter { } class CSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) - extends DataWriterFactory { + extends DataWriterFactory[InternalRow] { - override def createWriter( + override def createDataWriter( partitionId: Int, - taskId: Long): DataWriter[InternalRow] = { + taskId: Long, + epochId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 491dc34..35644c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -686,7 +686,7 @@ trait 
StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case r: StreamingExecutionRelation => r.source - case r: StreamingDataSourceV2Relation => r.readSupport + case r: StreamingDataSourceV2Relation => r.reader } .zipWithIndex .find(_._1 == source) http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index fe77a1b..0f15cd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -299,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { try { val input = new MemoryStream[Int](0, sqlContext) { @volatile var numTriggers = 0 - override def latestOffset(): OffsetV2 = { + override def getEndOffset: OffsetV2 = { numTriggers += 1 - super.latestOffset() + super.getEndOffset } } val clock = new StreamManualClock() http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 1dd8175..0278e2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.{util => ju} +import java.util.Optional import java.util.concurrent.CountDownLatch import scala.collection.mutable @@ -30,12 +32,13 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig} +import org.apache.spark.sql.sources.v2.reader.InputPartition import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -212,17 +215,25 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi private def dataAdded: Boolean = currentOffset.offset != -1 - // latestOffset should take 50 ms the first time it is called after data is added - override def latestOffset(): OffsetV2 = synchronized { - if (dataAdded) clock.waitTillTime(1050) - super.latestOffset() + // setOffsetRange should take 50 ms the first time it is called after data is added + override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { + synchronized { + if (dataAdded) 
clock.waitTillTime(1050) + super.setOffsetRange(start, end) + } + } + + // getEndOffset should take 100 ms the first time it is called after data is added + override def getEndOffset(): OffsetV2 = synchronized { + if (dataAdded) clock.waitTillTime(1150) + super.getEndOffset() } // getBatch should take 100 ms the first time it is called - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { synchronized { - clock.waitTillTime(1150) - super.planInputPartitions(config) + clock.waitTillTime(1350) + super.planInputPartitions() } } } @@ -263,26 +274,34 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Waiting for next trigger"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress when `latestOffset` is being called + // Test status and progress when setOffsetRange is being called AddData(inputData, 1, 2), - AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on `latestOffset` + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on setOffsetRange AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(50), // time = 1050 to unblock `latestOffset` + AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange AssertClockTime(1050), - // will block on `planInputPartitions` that needs 1350 - AssertStreamExecThreadIsWaitingForTime(1150), + AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150 + AssertOnQuery(_.status.isDataAvailable === false), + AssertOnQuery(_.status.isTriggerActive === true), + AssertOnQuery(_.status.message.startsWith("Getting offsets from")), + AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), + + AdvanceManualClock(100), // time = 1150 to unblock getEndOffset + AssertClockTime(1150), + // will block on planInputPartitions that needs 1350 + AssertStreamExecThreadIsWaitingForTime(1350), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(100), // time = 1150 to unblock `planInputPartitions` - AssertClockTime(1150), + AdvanceManualClock(200), // time = 1350 to unblock planInputPartitions + AssertClockTime(1350), AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), @@ -290,7 +309,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AdvanceManualClock(350), // time = 1500 to unblock map task + AdvanceManualClock(150), // time = 1500 to unblock map task AssertClockTime(1500), CheckAnswer(2), AssertStreamExecThreadIsWaitingForTime(2000), // will block until the next trigger @@ -310,10 +329,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.numInputRows === 2) 
assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("latestOffset") === 50) - assert(progress.durationMs.get("queryPlanning") === 100) + assert(progress.durationMs.get("setOffsetRange") === 50) + assert(progress.durationMs.get("getEndOffset") === 100) + assert(progress.durationMs.get("queryPlanning") === 200) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("addBatch") === 350) + assert(progress.durationMs.get("addBatch") === 150) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index d6819ea..4f19881 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -22,15 +22,16 @@ import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousReadSupport, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType} class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { case class LongPartitionOffset(offset: Long) extends PartitionOffset @@ -43,8 +44,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { override def beforeEach(): Unit = { super.beforeEach() epochEndpoint = EpochCoordinatorRef.create( - mock[StreamingWriteSupport], - mock[ContinuousReadSupport], + mock[StreamWriter], + mock[ContinuousReader], mock[ContinuousExecution], coordinatorId, startEpoch, @@ -72,26 +73,26 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { */ private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { val queue = new ArrayBlockingQueue[UnsafeRow](1024) - val partitionReader = new ContinuousPartitionReader[InternalRow] { - var index = -1 - var curr: UnsafeRow = _ - - override def next() = { - curr = queue.take() - index += 1 - true - } + val factory = new InputPartition[InternalRow] { + override def createPartitionReader() = new 
ContinuousInputPartitionReader[InternalRow] { + var index = -1 + var curr: UnsafeRow = _ + + override def next() = { + curr = queue.take() + index += 1 + true + } - override def get = curr + override def get = curr - override def getOffset = LongPartitionOffset(index) + override def getOffset = LongPartitionOffset(index) - override def close() = {} + override def close() = {} + } } val reader = new ContinuousQueuedDataReader( - 0, - partitionReader, - new StructType().add("i", "int"), + new ContinuousDataSourceRDDPartition(0, factory), mockContext, dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 3d21bc6..4980b0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r + case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index 3c973d8..82836dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -27,9 +27,9 @@ import org.apache.spark._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.test.TestSparkSession class EpochCoordinatorSuite @@ -40,20 +40,20 @@ class EpochCoordinatorSuite private var epochCoordinator: RpcEndpointRef = _ - private var writeSupport: StreamingWriteSupport = _ + private var writer: StreamWriter = _ private var query: ContinuousExecution = _ private var orderVerifier: InOrder = _ override def beforeEach(): Unit = { - val reader = mock[ContinuousReadSupport] - writeSupport = mock[StreamingWriteSupport] + val reader = mock[ContinuousReader] + 
writer = mock[StreamWriter] query = mock[ContinuousExecution] - orderVerifier = inOrder(writeSupport, query) + orderVerifier = inOrder(writer, query) spark = new TestSparkSession() epochCoordinator - = EpochCoordinatorRef.create(writeSupport, reader, query, "test", 1, spark, SparkEnv.get) + = EpochCoordinatorRef.create(writer, reader, query, "test", 1, spark, SparkEnv.get) } test("single epoch") { @@ -209,12 +209,12 @@ class EpochCoordinatorSuite } private def verifyCommit(epoch: Long): Unit = { - orderVerifier.verify(writeSupport).commit(eqTo(epoch), any()) + orderVerifier.verify(writer).commit(eqTo(epoch), any()) orderVerifier.verify(query).commit(epoch) } private def verifyNoCommitFor(epoch: Long): Unit = { - verify(writeSupport, never()).commit(eqTo(epoch), any()) + verify(writer, never()).commit(eqTo(epoch), any()) verify(query, never()).commit(epoch) }
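A note on the API shape restored by the diffs above: the test sources move off the `BatchReadSupport`/`ScanConfig`/`PartitionReaderFactory` interfaces and back onto the 2.3-era `DataSourceReader`/`InputPartition`/`InputPartitionReader` read path. Below is a minimal, self-contained sketch of a batch source written against those restored interfaces. The class names (TinyRangeSource, TinyRangePartition) and the two-partition range are illustrative only and not part of this patch; the method signatures follow the ones visible in the reverted SimpleWritableDataSource above.

import java.util.{List => JList}

import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
import org.apache.spark.sql.types.StructType

// Illustrative sketch (not part of the patch) of the restored 2.3/2.4-style read path.
class TinyRangeSource extends DataSourceV2 with ReadSupport {

  override def createReader(options: DataSourceOptions): DataSourceReader = new DataSourceReader {
    override def readSchema(): StructType = new StructType().add("i", "long")

    // The driver plans one InputPartition per split; each partition is shipped to an executor.
    override def planInputPartitions(): JList[InputPartition[InternalRow]] = {
      Seq[InputPartition[InternalRow]](
        new TinyRangePartition(0L, 5L),
        new TinyRangePartition(5L, 10L)).asJava
    }
  }
}

// The partition is serializable; the per-task reader is created on the executor side.
class TinyRangePartition(start: Long, end: Long) extends InputPartition[InternalRow] {

  override def createPartitionReader(): InputPartitionReader[InternalRow] =
    new InputPartitionReader[InternalRow] {
      private var current = start - 1

      override def next(): Boolean = { current += 1; current < end }

      override def get(): InternalRow = InternalRow(current)

      override def close(): Unit = ()
    }
}

Registering a class like this under its fully-qualified name and loading it with spark.read.format(...).load() would exercise the same createReader -> planInputPartitions -> createPartitionReader sequence that the reverted SimpleWritableDataSource test relies on, with the write side following the analogous DataSourceWriter/DataWriterFactory/DataWriter chain shown in the same file.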