xuanyuanking commented on a change in pull request #34333: URL: https://github.com/apache/spark/pull/34333#discussion_r736719970
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchProvider.scala ########## @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} +import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} +import org.apache.spark.sql.internal.connector.SimpleTableProvider +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerMicroBatch` (e.g. 100): How many rows should be generated per micro-batch. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. + * - `startTimestamp` (e.g. 1000, default: 0): starting value of generated time + * - `advanceMillisPerMicroBatch` (e.g. 1000, default: 1000): the amount of time being advanced in + * generated time on each micro-batch. + * + * Unlike `rate` data source, this data source provides a consistent set of input rows per + * micro-batch regardless of query execution (configuration of trigger, query being lagging, etc.), + * say, batch 0 will produce 0~999 and batch 1 will produce 1000~1999, and so on. Same applies to + * the generated time. + * + * As the name represents, this data source only supports micro-batch read. + */ +class RatePerMicroBatchProvider extends SimpleTableProvider with DataSourceRegister { + import RatePerMicroBatchProvider._ + + override def getTable(options: CaseInsensitiveStringMap): Table = { + val rowsPerBatch = options.getLong(ROWS_PER_BATCH, 0) + if (rowsPerBatch <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$rowsPerBatch'. The option 'rowsPerBatch' must be positive") + } + + val numPartitions = options.getInt( + NUM_PARTITIONS, SparkSession.active.sparkContext.defaultParallelism) + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") + } + + val startTimestamp = options.getLong(START_TIMESTAMP, 0) + if (startTimestamp < 0) { + throw new IllegalArgumentException( + s"Invalid value '$startTimestamp'. The option 'startTimestamp' must be non-negative") + } + + val advanceMillisPerBatch = options.getInt(ADVANCE_MILLIS_PER_BATCH, 1000) + if (advanceMillisPerBatch < 0) { + throw new IllegalArgumentException( + s"Invalid value '$advanceMillisPerBatch'. The option 'advanceMillisPerBatch' " + + "must be non-negative") + } + + new RatePerMicroBatchTable(rowsPerBatch, numPartitions, startTimestamp, + advanceMillisPerBatch) + } + + override def shortName(): String = "rate-micro-batch" +} + +class RatePerMicroBatchTable( + rowsPerBatch: Long, + numPartitions: Int, + startTimestamp: Long, + advanceMillisPerBatch: Int) extends Table with SupportsRead { + override def name(): String = { + s"RatePerEpoch(rowsPerBatch=$rowsPerBatch, numPartitions=$numPartitions," + Review comment: nit: RatePerMicroBatch? ########## File path: sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchProviderSuite.scala ########## @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.functions.spark_partition_id +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.streaming.util.StreamManualClock + +class RatePerMicroBatchProviderSuite extends StreamTest { + + import testImplicits._ + + test("RatePerMicroBatchProvider in registry") { + val ds = DataSource.lookupDataSource("rate-micro-batch", spark.sqlContext.conf).newInstance() + assert(ds.isInstanceOf[RatePerMicroBatchProvider], "Could not find rate-micro-batch source") + } + + test("basic") { + val input = spark.readStream + .format("rate-micro-batch") + .option("rowsPerBatch", "10") + .option("startTimestamp", "1000") + .option("advanceMillisPerBatch", "50") + .load() + val clock = new StreamManualClock + testStream(input)( + StartStream(trigger = Trigger.ProcessingTime(10), triggerClock = clock), + AdvanceManualClock(10), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(1000L) -> v): _*), + AdvanceManualClock(10), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(1050L) -> v): _*), + AdvanceManualClock(10), + CheckLastBatch((20 until 30).map(v => new java.sql.Timestamp(1100L) -> v): _*) + ) + } + + test("restart") { + withTempDir { dir => + val input = spark.readStream + .format("rate-micro-batch") + .option("rowsPerBatch", "10") + .load() + .select('value) + + testStream(input)( + StartStream(checkpointLocation = dir.getAbsolutePath), + Execute(_.awaitOffset(0, RatePerMicroBatchStreamOffset(20, 2000), + streamingTimeout.toMillis)), + CheckAnswer(0 until 20: _*), + StopStream + ) + + testStream(input)( + StartStream(checkpointLocation = dir.getAbsolutePath), + Execute(_.awaitOffset(0, RatePerMicroBatchStreamOffset(40, 4000), + streamingTimeout.toMillis)), + CheckAnswer(20 until 40: _*) + ) + } + } + + test("numPartitions") { + val input = spark.readStream + .format("rate-micro-batch") + .option("rowsPerBatch", "10") + .option("numPartitions", "6") + .load() + .select(spark_partition_id()) + .distinct() + val clock = new StreamManualClock + testStream(input)( + StartStream(trigger = Trigger.ProcessingTime(10), triggerClock = clock), + AdvanceManualClock(10), + CheckLastBatch(0 until 6: _*) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[IllegalArgumentException] { + var stream = spark.readStream + .format("rate-micro-batch") + .option(option, value) + + if (option != "rowsPerBatch") { + stream = stream.option("rowsPerBatch", "1") + } + + stream.load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + for (msg <- expectedMessages) { + assert(e.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerBatch", "-1", Seq("-1", "rowsPerBatch", "positive")) + testIllegalOptionValue("rowsPerBatch", "0", Seq("0", "rowsPerBatch", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + testIllegalOptionValue("numPartitions", "0", Seq("0", "numPartitions", "positive")) + + // RatePerEpochProvider allows setting below options to 0 Review comment: ditto -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
