Github user tdas commented on a diff in the pull request: https://github.com/apache/spark/pull/20688#discussion_r172729894 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateSourceProvider.scala --- @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset} +import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} +import org.apache.spark.util.{ManualClock, SystemClock} + +object RateSourceProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 + + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + val RAMP_UP_TIME = "rampUpTime" + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} + +class RateSourceProvider extends DataSourceV2 + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + import RateSourceProvider._ + + private def checkParameters(options: DataSourceOptions): Unit = { + if (options.get(ROWS_PER_SECOND).isPresent) { + val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") + } + } + + if (options.get(RAMP_UP_TIME).isPresent) { + val rampUpTimeSeconds = + JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") + } + } + + if (options.get(NUM_PARTITIONS).isPresent) { + val numPartitions = options.get(NUM_PARTITIONS).get().toInt + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") + } + } + } + + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + checkParameters(options) + if (schema.isPresent) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + new RateStreamMicroBatchReader(options, checkpointLocation) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + + override def shortName(): String = "rate" +} + +class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) --- End diff -- split this into a separate file.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org