GitHub user xuanyuanking commented on a diff in the pull request:
https://github.com/apache/spark/pull/21194#discussion_r185252544
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
---
@@ -101,25 +101,10 @@ object RateStreamProvider {
/** Calculate the end value we will emit at the time `seconds`. */
def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds:
Long): Long = {
- // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10
- // Then speedDeltaPerSecond = 2
- //
- // seconds = 0 1 2 3 4 5 6
- // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds)
- // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) *
(seconds + 1) / 2
- val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
- if (seconds <= rampUpTimeSeconds) {
- // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) /
2" in a special way to
- // avoid overflow
- if (seconds % 2 == 1) {
- (seconds + 1) / 2 * speedDeltaPerSecond * seconds
- } else {
- seconds / 2 * speedDeltaPerSecond * (seconds + 1)
- }
- } else {
- // rampUpPart is just a special case of the above formula:
rampUpTimeSeconds == seconds
- val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond,
rampUpTimeSeconds)
- rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond
- }
+ val delta = rowsPerSecond.toDouble / rampUpTimeSeconds
+ val rampUpSeconds = if (seconds <= rampUpTimeSeconds) seconds else
rampUpTimeSeconds
+ val afterRampUpSeconds = if (seconds > rampUpTimeSeconds ) seconds -
rampUpTimeSeconds else 0
+ // Use classic distance formula based on accelaration: ut + ½at2
+ Math.floor(rampUpSeconds * rampUpSeconds * delta / 2).toLong +
afterRampUpSeconds * rowsPerSecond
--- End diff --
nit: this line is longer than the 100-character line-length limit (>100 characters).
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]