xuanyuanking commented on a change in pull request #34333:
URL: https://github.com/apache/spark/pull/34333#discussion_r736719970



##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchProvider.scala
##########
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.util
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
+import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
+import org.apache.spark.sql.internal.connector.SimpleTableProvider
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ *  A source that generates increment long values with timestamps. Each generated row has two
+ *  columns: a timestamp column for the generated time and an auto increment long column starting
+ *  with 0L.
+ *
+ *  This source supports the following options:
+ *  - `rowsPerMicroBatch` (e.g. 100): How many rows should be generated per micro-batch.
+ *  - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the
+ *    generated rows.
+ *  - `startTimestamp` (e.g. 1000, default: 0): starting value of generated time
+ *  - `advanceMillisPerMicroBatch` (e.g. 1000, default: 1000): the amount of time being advanced in
+ *    generated time on each micro-batch.
+ *
+ *  Unlike `rate` data source, this data source provides a consistent set of input rows per
+ *  micro-batch regardless of query execution (configuration of trigger, query being lagging, etc.),
+ *  say, batch 0 will produce 0~999 and batch 1 will produce 1000~1999, and so on. Same applies to
+ *  the generated time.
+ *
+ *  As the name represents, this data source only supports micro-batch read.
+ */
+class RatePerMicroBatchProvider extends SimpleTableProvider with DataSourceRegister {
+  import RatePerMicroBatchProvider._
+
+  override def getTable(options: CaseInsensitiveStringMap): Table = {
+    val rowsPerBatch = options.getLong(ROWS_PER_BATCH, 0)
+    if (rowsPerBatch <= 0) {
+      throw new IllegalArgumentException(
+        s"Invalid value '$rowsPerBatch'. The option 'rowsPerBatch' must be 
positive")
+    }
+
+    val numPartitions = options.getInt(
+      NUM_PARTITIONS, SparkSession.active.sparkContext.defaultParallelism)
+    if (numPartitions <= 0) {
+      throw new IllegalArgumentException(
+        s"Invalid value '$numPartitions'. The option 'numPartitions' must be 
positive")
+    }
+
+    val startTimestamp = options.getLong(START_TIMESTAMP, 0)
+    if (startTimestamp < 0) {
+      throw new IllegalArgumentException(
+        s"Invalid value '$startTimestamp'. The option 'startTimestamp' must be 
non-negative")
+    }
+
+    val advanceMillisPerBatch = options.getInt(ADVANCE_MILLIS_PER_BATCH, 1000)
+    if (advanceMillisPerBatch < 0) {
+      throw new IllegalArgumentException(
+        s"Invalid value '$advanceMillisPerBatch'. The option 
'advanceMillisPerBatch' " +
+          "must be non-negative")
+    }
+
+    new RatePerMicroBatchTable(rowsPerBatch, numPartitions, startTimestamp,
+      advanceMillisPerBatch)
+  }
+
+  override def shortName(): String = "rate-micro-batch"
+}
+
+class RatePerMicroBatchTable(
+    rowsPerBatch: Long,
+    numPartitions: Int,
+    startTimestamp: Long,
+    advanceMillisPerBatch: Int) extends Table with SupportsRead {
+  override def name(): String = {
+    s"RatePerEpoch(rowsPerBatch=$rowsPerBatch, numPartitions=$numPartitions," +

Review comment:
       nit: RatePerMicroBatch?
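
For context, a minimal, self-contained sketch of how the source described in the scaladoc above would be consumed. It is an illustration only, not part of the PR; the option names follow the code and tests in this change ("rowsPerBatch", "startTimestamp", "advanceMillisPerBatch"), and the app name, master, and checkpoint path are placeholders.

import org.apache.spark.sql.SparkSession

object RateMicroBatchExample {
  def main(args: Array[String]): Unit = {
    // Local session just for illustration.
    val spark = SparkSession.builder()
      .appName("rate-micro-batch-example")
      .master("local[2]")
      .getOrCreate()

    // Each micro-batch emits exactly 10 rows; the generated timestamp starts at
    // 1000 ms and advances by 50 ms per micro-batch, independent of wall-clock time.
    val input = spark.readStream
      .format("rate-micro-batch")
      .option("rowsPerBatch", "10")
      .option("startTimestamp", "1000")
      .option("advanceMillisPerBatch", "50")
      .load()

    // Print each micro-batch to the console; the checkpoint path is a placeholder.
    val query = input.writeStream
      .format("console")
      .option("checkpointLocation", "/tmp/rate-micro-batch-example")
      .start()

    query.awaitTermination()
  }
}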

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RatePerMicroBatchProviderSuite.scala
##########
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.functions.spark_partition_id
+import org.apache.spark.sql.streaming.{StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class RatePerMicroBatchProviderSuite extends StreamTest {
+
+  import testImplicits._
+
+  test("RatePerMicroBatchProvider in registry") {
+    val ds = DataSource.lookupDataSource("rate-micro-batch", 
spark.sqlContext.conf).newInstance()
+    assert(ds.isInstanceOf[RatePerMicroBatchProvider], "Could not find 
rate-micro-batch source")
+  }
+
+  test("basic") {
+    val input = spark.readStream
+      .format("rate-micro-batch")
+      .option("rowsPerBatch", "10")
+      .option("startTimestamp", "1000")
+      .option("advanceMillisPerBatch", "50")
+      .load()
+    val clock = new StreamManualClock
+    testStream(input)(
+      StartStream(trigger = Trigger.ProcessingTime(10), triggerClock = clock),
+      AdvanceManualClock(10),
+      CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(1000L) -> v): _*),
+      AdvanceManualClock(10),
+      CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(1050L) -> v): _*),
+      AdvanceManualClock(10),
+      CheckLastBatch((20 until 30).map(v => new java.sql.Timestamp(1100L) -> v): _*)
+    )
+  }
+
+  test("restart") {
+    withTempDir { dir =>
+      val input = spark.readStream
+        .format("rate-micro-batch")
+        .option("rowsPerBatch", "10")
+        .load()
+        .select('value)
+
+      testStream(input)(
+        StartStream(checkpointLocation = dir.getAbsolutePath),
+        Execute(_.awaitOffset(0, RatePerMicroBatchStreamOffset(20, 2000),
+          streamingTimeout.toMillis)),
+        CheckAnswer(0 until 20: _*),
+        StopStream
+      )
+
+      testStream(input)(
+        StartStream(checkpointLocation = dir.getAbsolutePath),
+        Execute(_.awaitOffset(0, RatePerMicroBatchStreamOffset(40, 4000),
+          streamingTimeout.toMillis)),
+        CheckAnswer(20 until 40: _*)
+      )
+    }
+  }
+
+  test("numPartitions") {
+    val input = spark.readStream
+      .format("rate-micro-batch")
+      .option("rowsPerBatch", "10")
+      .option("numPartitions", "6")
+      .load()
+      .select(spark_partition_id())
+      .distinct()
+    val clock = new StreamManualClock
+    testStream(input)(
+      StartStream(trigger = Trigger.ProcessingTime(10), triggerClock = clock),
+      AdvanceManualClock(10),
+      CheckLastBatch(0 until 6: _*)
+    )
+  }
+
+  testQuietly("illegal option values") {
+    def testIllegalOptionValue(
+        option: String,
+        value: String,
+        expectedMessages: Seq[String]): Unit = {
+      val e = intercept[IllegalArgumentException] {
+        var stream = spark.readStream
+          .format("rate-micro-batch")
+          .option(option, value)
+
+        if (option != "rowsPerBatch") {
+          stream = stream.option("rowsPerBatch", "1")
+        }
+
+        stream.load()
+          .writeStream
+          .format("console")
+          .start()
+          .awaitTermination()
+      }
+      for (msg <- expectedMessages) {
+        assert(e.getMessage.contains(msg))
+      }
+    }
+
+    testIllegalOptionValue("rowsPerBatch", "-1", Seq("-1", "rowsPerBatch", 
"positive"))
+    testIllegalOptionValue("rowsPerBatch", "0", Seq("0", "rowsPerBatch", 
"positive"))
+    testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", 
"positive"))
+    testIllegalOptionValue("numPartitions", "0", Seq("0", "numPartitions", 
"positive"))
+
+    // RatePerEpochProvider allows setting below options to 0

Review comment:
       ditto
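
Tying the two nits together: the first points at the string returned by name() in RatePerMicroBatchTable, the second at the "// RatePerEpochProvider ..." comment in the suite; both still use the old "RatePerEpoch" naming. Below is a minimal, standalone sketch of what the rename could look like. The stub class and the fields shown past the diff's cut-off point are assumptions for illustration, not the PR's actual code.

// A stub with the same constructor parameters as RatePerMicroBatchTable,
// showing the corrected name() string; not the PR's class.
class RatePerMicroBatchTableStub(
    rowsPerBatch: Long,
    numPartitions: Int,
    startTimestamp: Long,
    advanceMillisPerBatch: Int) {
  def name(): String = {
    s"RatePerMicroBatch(rowsPerBatch=$rowsPerBatch, numPartitions=$numPartitions," +
      s" startTimestamp=$startTimestamp, advanceMillisPerBatch=$advanceMillisPerBatch)"
  }
}

object RenameSketch {
  def main(args: Array[String]): Unit = {
    // Prints: RatePerMicroBatch(rowsPerBatch=10, numPartitions=2, startTimestamp=0, advanceMillisPerBatch=1000)
    println(new RatePerMicroBatchTableStub(10L, 2, 0L, 1000).name())
  }
}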



