cloud-fan commented on code in PR #55828:
URL: https://github.com/apache/spark/pull/55828#discussion_r3240427223
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala:
##########
@@ -247,6 +247,34 @@ case class DecimalAddNoOverflowCheck(
copy(left = newLeft, right = newRight)
}
+/**
+ * A subtract expression for decimal values which is only used internally.
+ *
+ * Note that, this expression does not check overflow which is different from `Subtract`.
+ */
Review Comment:
(Late catch.) Compared to the sibling `DecimalAddNoOverflowCheck` Scaladoc
just above, this one is missing the rationale for *why* skipping the overflow
check is safe. The Add version explains the UnsafeRowWriter safety net; the new
Subtract class has a different safety argument (callers pre-filter `left >=
right >= 0`), and right now that argument only lives at the call site in
`CounterDiff.scala`. A reader landing on this class in isolation can't tell why
it's sound. Worth expanding, e.g.:
```suggestion
/**
 * A subtract expression for decimal values which is only used internally.
 *
 * Note that, this expression does not check overflow which is different from `Subtract`.
 * It is the caller's responsibility to ensure that the result fits in the declared
 * precision and scale. For example, `counter_diff` only invokes this on operands that
 * have already been validated to satisfy `left >= right >= 0`, so the result is
 * non-negative and bounded above by `left`.
 */
```
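To make the `left >= right >= 0` argument concrete, here is a minimal sketch in
plain Python (not Spark code). It models DECIMAL(precision, scale) with the
standard `decimal` module; the helper name `fits` is hypothetical and exists
only for this illustration.

```python
from decimal import Decimal


def fits(value: Decimal, precision: int, scale: int) -> bool:
    """Return True if `value` is representable as DECIMAL(precision, scale)."""
    quantum = Decimal(1).scaleb(-scale)  # smallest step, e.g. 0.01 for scale 2
    max_value = Decimal(10) ** (precision - scale) - quantum  # e.g. 999.99
    # In range, and no digits beyond the declared scale.
    return abs(value) <= max_value and value == value.quantize(quantum)


# With the precondition left >= right >= 0, the difference lies in [0, left],
# so it fits whenever `left` itself fits.
left, right = Decimal("999.99"), Decimal("0.00")
safe_diff = left - right

# Without the precondition (negative right operand), the same unchecked
# subtraction can leave the declared type.
unsafe_diff = Decimal("999.99") - Decimal("-999.99")
```

Under these assumptions `safe_diff` stays within DECIMAL(5, 2) while
`unsafe_diff` (1999.98) does not, which is why the precondition belongs in the
class-level Scaladoc and not only at the call site.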
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CounterDiff.scala:
##########
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter}
+import org.apache.spark.sql.errors.QueryErrorsBase
+import org.apache.spark.sql.types._
+
+/**
+ * The counter_diff window function computes the differences between consecutive cumulative counter
+ * values in a time series, thereby converting the counter from the cumulative to the delta format.
+ *
+ * This class serves as the base class for the two versions of the counter_diff function:
+ * - counter_diff(counter) -> CounterDiff(counter)
+ * - counter_diff(counter, start_time) -> CounterDiffWithStartTime(counter, startTime)
+ */
+abstract class CounterDiffBase(val counter: Expression)
+ extends AggregateWindowFunction
+ with QueryErrorsBase {
+
+ override def prettyName: String = "counter_diff"
+
+ override def dataType: DataType = counter.dataType
+
+ /**
+ * Last non-NULL counter value from a previous row.
+ */
+ protected lazy val prevCounter: AttributeReference =
+ AttributeReference("prevCounter", counter.dataType, nullable = true)()
+
+ /**
+ * Counter value from the current row.
+ */
+ protected lazy val currCounter: AttributeReference =
+ AttributeReference("currCounter", counter.dataType, nullable = true)()
+
+ /**
+ * Null literal used as a counter_diff result, when appropriate.
+ */
+ protected lazy val nullResult: Expression = Literal.create(null, counter.dataType)
+
+ /**
+ * Difference between the current and previous counter values.
+ */
+ protected lazy val diff: Expression = {
+ counter.dataType match {
+ // For DECIMAL, subtraction typically widens the result type to handle possible overflow.
+ // For counter_diff, since counters cannot be negative, there is no risk of overflow, and no
+ // need to widen the result type, so we subtract directly in the input type.
+ case dt: DecimalType => DecimalSubtractNoOverflowCheck(currCounter, prevCounter, dt)
+ case _ => currCounter - prevCounter
+ }
+ }
+
+ /**
+ * Returns the difference, unless the counter has decreased, which is treated as a counter reset.
+ * In this case, NULL is returned.
+ */
+ protected lazy val diffWithCounterDecreaseCheck: Expression =
+ If(currCounter < prevCounter, nullResult, diff)
+
+ /**
+ * Error raised when the counter is negative.
+ */
+ protected lazy val negativeCounterError: Expression = RaiseError(
+ Literal("COUNTER_DIFF_NEGATIVE_COUNTER_VALUE"),
+ CreateMap(
+ Seq(
+ Literal("value"),
+ Cast(currCounter, StringType),
+ Literal("function"),
+ Literal(toSQLId("counter_diff"))
+ )
+ ),
+ counter.dataType
+ )
+
+ /**
+ * Wraps `inner` with the "skip row on NULL counter" and "raise error on negative counter" checks.
+ */
+ protected def withCounterNullAndNegativeChecks(inner: Expression): Expression = {
+ If(IsNull(currCounter),
+ nullResult,
+ If(currCounter < Literal.default(counter.dataType),
+ negativeCounterError,
+ inner
+ )
+ )
+ }
+}
+
+/**
+ * The single-parameter form of `counter_diff`: `counter_diff(value)`.
+ * Detects counter resets only when the counter value decreases.
+ */
+case class CounterDiff(override val counter: Expression)
+ extends CounterDiffBase(counter)
+ with ExpectsInputTypes {
+
+ override def children: Seq[Expression] = Seq(counter)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+
+ /**
+ * The aggregation state attributes for the counter_diff function.
+ * In the single-parameter form, there are two attributes:
+ * - prevCounter: The last non-NULL counter value from a previous row.
+ * - currCounter: The counter value from the current row.
+ */
+ override lazy val aggBufferAttributes: Seq[AttributeReference] =
+ Seq(prevCounter, currCounter)
+
+ /**
+ * The initial aggregation state for the counter_diff function. Initial values are NULL.
+ */
+ override lazy val initialValues: Seq[Expression] = Seq(
+ Literal.create(null, counter.dataType),
+ Literal.create(null, counter.dataType)
+ )
+
+ /**
+ * The update expressions for the counter_diff function's aggregation state.
+ *
+ * Fundamentally, the current value becomes the previous value, and the new value becomes the
+ * current value.
+ *
+ * Rows with NULL counter values should be skipped. As a result, the previous counter value
+ * should not be updated in the aggregation state.
+ */
+ override lazy val updateExpressions: Seq[Expression] = Seq(
+ If(IsNotNull(currCounter), currCounter, prevCounter),
+ counter
+ )
+
+ /**
+ * The evaluation expression for the counter_diff function.
+ *
+ * Checks for edge cases first: NULL counter value, negative counter value and counter reset.
+ * Otherwise, returns the difference between the current and previous counter values.
+ */
+ override lazy val evaluateExpression: Expression =
+ withCounterNullAndNegativeChecks(diffWithCounterDecreaseCheck)
+
+ /**
+ * The SQL representation of the single-parameter form of the counter_diff function.
+ */
+ override def sql: String = s"${prettyName}(${counter.sql})"
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): CounterDiff =
+ copy(counter = newChildren.head)
+}
+
+/**
+ * The two-parameter form of `counter_diff`: `counter_diff(value, start_time)`.
+ * Additionally checks for counter resets when `start_time` increases, which signals a new start.
+ * Requires that the start time doesn't decrease, which would indicate moving backwards in time.
+ */
+case class CounterDiffWithStartTime(
+ override val counter: Expression,
+ startTime: Expression,
+ timeZoneId: Option[String] = None)
+ extends CounterDiffBase(counter)
+ with ExpectsInputTypes
+ with TimeZoneAwareExpression {
+
+ override def withTimeZone(timeZoneId: String): CounterDiffWithStartTime =
+ copy(timeZoneId = Some(timeZoneId))
+
+ override def children: Seq[Expression] = Seq(counter, startTime)
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(NumericType, TypeCollection(TimestampType, TimestampNTZType))
+
+ /**
+ * The start time from a previous row.
+ */
+ protected lazy val prevStartTime: AttributeReference =
+ AttributeReference("prevStartTime", startTime.dataType, nullable = true)()
+
+ /**
+ * The start time from the current row.
+ */
+ protected lazy val currStartTime: AttributeReference =
+ AttributeReference("currStartTime", startTime.dataType, nullable = true)()
+
+ /**
+ * The aggregation state attributes for the counter_diff function.
+ * In the two-parameter form, there are four attributes:
+ * - prevCounter: The last non-NULL counter value from a previous row.
+ * - currCounter: The counter value from the current row.
+ * - prevStartTime: The start time from a previous row.
+ * - currStartTime: The start time from the current row.
+ */
+ override lazy val aggBufferAttributes: Seq[AttributeReference] =
+ Seq(prevCounter, currCounter, prevStartTime, currStartTime)
+
+ /**
+ * The initial aggregation state for the counter_diff function. Initial values are NULL.
+ */
+ override lazy val initialValues: Seq[Expression] = Seq(
+ Literal.create(null, counter.dataType),
+ Literal.create(null, counter.dataType),
+ Literal.create(null, startTime.dataType),
+ Literal.create(null, startTime.dataType)
+ )
+
+ /**
+ * The update expressions for the counter_diff function's aggregation state.
+ *
+ * Fundamentally, the current value becomes the previous value, and the new value becomes the
+ * current value. The same applies to the start time.
+ *
+ * Rows with NULL counter values should be skipped. As a result, the previous values for both
+ * the counter and start time should not be updated in the aggregation state.
+ */
+ override lazy val updateExpressions: Seq[Expression] = Seq(
+ If(IsNotNull(currCounter), currCounter, prevCounter),
+ counter,
+ If(IsNotNull(currCounter), currStartTime, prevStartTime),
+ startTime
+ )
+
+ /**
+ * Error raised when the start time decreases.
+ */
+ protected lazy val decreasedStartTimeError: Expression = RaiseError(
+ Literal("COUNTER_DIFF_START_TIME_DECREASED"),
+ CreateMap(
+ Seq(
+ Literal("function"),
+ Literal(toSQLId("counter_diff")),
+ Literal("previousStartTime"),
+ Cast(prevStartTime, StringType, timeZoneId),
+ Literal("currentStartTime"),
+ Cast(currStartTime, StringType, timeZoneId)
+ )
+ ),
+ counter.dataType
+ )
+
+ /**
+ * The evaluation expression for the counter_diff function.
+ *
+ * Checks for edge cases first: NULL counter value, negative counter value, start time decrease
+ * and counter resets.
+ *
+ * Otherwise, returns the difference between the current and previous counter values.
+ */
+ override lazy val evaluateExpression: Expression = withCounterNullAndNegativeChecks {
+ If(currStartTime < prevStartTime,
+ decreasedStartTimeError,
+ If(prevStartTime < currStartTime,
+ nullResult,
+ diffWithCounterDecreaseCheck
+ )
+ )
+ }
+
+ /**
+ * The SQL representation of the two-parameter form of the counter_diff function.
+ */
+ override def sql: String =
+ s"${prettyName}(${counter.sql}, ${startTime.sql})"
Review Comment:
Same as the 1-arg form above: `children = Seq(counter, startTime)`, so the
default `Expression.sql` already produces `counter_diff(<counter.sql>,
<startTime.sql>)`. `timeZoneId` is not a child, but this override does not
include it in the SQL string either, so the override is redundant.
```suggestion
```
##########
python/pyspark/sql/tests/test_functions.py:
##########
@@ -1954,6 +1954,35 @@ def test_window_functions(self):
         for r, ex in zip(rs, expected):
             self.assertEqual(tuple(r), ex[: len(r)])
+    def test_counter_diff_window_function(self):
+        df = self.spark.createDataFrame(
+            [
+                (1, datetime.datetime(2024, 1, 1), 100),
+                (2, datetime.datetime(2024, 1, 1), 200),
+                (3, datetime.datetime(2024, 1, 1), 400),
+                (4, datetime.datetime(2024, 1, 2), 50),
+                (5, datetime.datetime(2024, 1, 2), 150),
+            ],
+            ["t", "st", "c"],
+        )
+        w = Window.orderBy("t")
+
+        rows = df.select("t", F.counter_diff("c").over(w).alias("d")).orderBy("t").collect()
+        self.assertEqual(
+            [(r.t, r.d) for r in rows],
+            [(1, None), (2, 100), (3, 200), (4, None), (5, 100)],
+        )
+
+        rows = (
+            df.select("t", F.counter_diff("c", startTime="st").over(w).alias("d"))
+            .orderBy("t")
+            .collect()
+        )
+        self.assertEqual(
+            [(r.t, r.d) for r in rows],
+            [(1, None), (2, 100), (3, 200), (4, None), (5, 100)],
+        )
Review Comment:
(Late catch.) This 2-arg assertion doesn't actually exercise the 2-arg form. At
row 4 (`2024-01-02`, `c=50`) the start_time advances AND the counter decreases,
so both reset paths fire together: the 2-arg form returns NULL at row 4 under
either condition, and the expected output is identical to the 1-arg form's
expected output above. If you removed all start_time logic from
`CounterDiffWithStartTime`, this test would still pass.
Compare to `DataFrameWindowFunctionsSuite.scala:835-857`, where the two
forms produce different outputs (`Row(2, null)` only in the 2-arg form because
the start_time advances between rows 1 and 2 while the counter is still
increasing); that test does discriminate.
Suggestion: shift the start_time advance to a row where the counter does not
decrease, so the 2-arg form returns NULL while the 1-arg form returns a
positive diff:
```suggestion
        rows = (
            df.select("t", F.counter_diff("c", startTime="st").over(w).alias("d"))
            .orderBy("t")
            .collect()
        )
        self.assertEqual(
            [(r.t, r.d) for r in rows],
            [(1, None), (2, 100), (3, None), (4, None), (5, 100)],
        )
```
...combined with shifting row 3's `st` to `datetime.datetime(2024, 1, 2)` in
the data fixture so the start_time advance happens between rows 2 and 3 (while
the counter is still increasing from 200 to 400).
`test_connect_function.py::test_window_functions` separately covers
Connect↔Spark equivalence on a fuller dataset, but the unit test for the
function itself should still be the place that distinguishes the two forms.
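For illustration, the point can be checked without a Spark session. The
following is a rough pure-Python model of the evaluation order in
`CounterDiff.scala` as quoted in this thread; `counter_diff_model` and the
fixture names are hypothetical, and SQL NULL comparisons are approximated with
explicit `None` checks.

```python
import datetime


def counter_diff_model(rows, use_start_time=False):
    """Toy model of counter_diff over (counter, start_time) rows, already ordered."""
    prev_c = curr_c = None
    prev_st = curr_st = None
    out = []
    for counter, start_time in rows:
        # Update step: a row whose counter was NULL is skipped, so the
        # "previous" slots only advance past non-NULL counters.
        if curr_c is not None:
            prev_c, prev_st = curr_c, curr_st
        curr_c, curr_st = counter, start_time

        # A comparison against a NULL start time is never true in SQL.
        compare_st = use_start_time and prev_st is not None and curr_st is not None
        if curr_c is None:
            out.append(None)
        elif curr_c < 0:
            raise ValueError("negative counter value")
        elif compare_st and curr_st < prev_st:
            raise ValueError("start time decreased")
        elif compare_st and prev_st < curr_st:
            out.append(None)  # start_time advanced: treated as a reset
        elif prev_c is None or curr_c < prev_c:
            out.append(None)  # no previous value yet, or counter decreased
        else:
            out.append(curr_c - prev_c)
    return out


JAN1 = datetime.datetime(2024, 1, 1)
JAN2 = datetime.datetime(2024, 1, 2)

# Fixture as written in the PR: start_time advances at row 4, where c also drops.
original = [(100, JAN1), (200, JAN1), (400, JAN1), (50, JAN2), (150, JAN2)]
# Proposed fixture: start_time advances at row 3, where c is still increasing.
shifted = [(100, JAN1), (200, JAN1), (400, JAN2), (50, JAN2), (150, JAN2)]
```

In this model the original fixture yields `[None, 100, 200, None, 100]` for
both forms, while the shifted fixture yields `[None, 100, None, None, 100]` for
the 2-arg form only, so only the shifted fixture would notice the start_time
logic being removed.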
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CounterDiff.scala:
##########
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter}
+import org.apache.spark.sql.errors.QueryErrorsBase
+import org.apache.spark.sql.types._
+
+/**
+ * The counter_diff window function computes the differences between consecutive cumulative counter
+ * values in a time series, thereby converting the counter from the cumulative to the delta format.
+ *
+ * This class serves as the base class for the two versions of the counter_diff function:
+ * - counter_diff(counter) -> CounterDiff(counter)
+ * - counter_diff(counter, start_time) -> CounterDiffWithStartTime(counter, startTime)
+ */
+abstract class CounterDiffBase(val counter: Expression)
+ extends AggregateWindowFunction
+ with QueryErrorsBase {
+
+ override def prettyName: String = "counter_diff"
+
+ override def dataType: DataType = counter.dataType
+
+ /**
+ * Last non-NULL counter value from a previous row.
+ */
+ protected lazy val prevCounter: AttributeReference =
+ AttributeReference("prevCounter", counter.dataType, nullable = true)()
+
+ /**
+ * Counter value from the current row.
+ */
+ protected lazy val currCounter: AttributeReference =
+ AttributeReference("currCounter", counter.dataType, nullable = true)()
+
+ /**
+ * Null literal used as a counter_diff result, when appropriate.
+ */
+ protected lazy val nullResult: Expression = Literal.create(null, counter.dataType)
+
+ /**
+ * Difference between the current and previous counter values.
+ */
+ protected lazy val diff: Expression = {
+ counter.dataType match {
+ // For DECIMAL, subtraction typically widens the result type to handle possible overflow.
+ // For counter_diff, since counters cannot be negative, there is no risk of overflow, and no
+ // need to widen the result type, so we subtract directly in the input type.
+ case dt: DecimalType => DecimalSubtractNoOverflowCheck(currCounter, prevCounter, dt)
+ case _ => currCounter - prevCounter
+ }
+ }
+
+ /**
+ * Returns the difference, unless the counter has decreased, which is treated as a counter reset.
+ * In this case, NULL is returned.
+ */
+ protected lazy val diffWithCounterDecreaseCheck: Expression =
+ If(currCounter < prevCounter, nullResult, diff)
+
+ /**
+ * Error raised when the counter is negative.
+ */
+ protected lazy val negativeCounterError: Expression = RaiseError(
+ Literal("COUNTER_DIFF_NEGATIVE_COUNTER_VALUE"),
+ CreateMap(
+ Seq(
+ Literal("value"),
+ Cast(currCounter, StringType),
+ Literal("function"),
+ Literal(toSQLId("counter_diff"))
+ )
+ ),
+ counter.dataType
+ )
+
+ /**
+ * Wraps `inner` with the "skip row on NULL counter" and "raise error on negative counter" checks.
+ */
+ protected def withCounterNullAndNegativeChecks(inner: Expression): Expression = {
+ If(IsNull(currCounter),
+ nullResult,
+ If(currCounter < Literal.default(counter.dataType),
+ negativeCounterError,
+ inner
+ )
+ )
+ }
+}
+
+/**
+ * The single-parameter form of `counter_diff`: `counter_diff(value)`.
+ * Detects counter resets only when the counter value decreases.
+ */
+case class CounterDiff(override val counter: Expression)
+ extends CounterDiffBase(counter)
+ with ExpectsInputTypes {
+
+ override def children: Seq[Expression] = Seq(counter)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+
+ /**
+ * The aggregation state attributes for the counter_diff function.
+ * In the single-parameter form, there are two attributes:
+ * - prevCounter: The last non-NULL counter value from a previous row.
+ * - currCounter: The counter value from the current row.
+ */
+ override lazy val aggBufferAttributes: Seq[AttributeReference] =
+ Seq(prevCounter, currCounter)
+
+ /**
+ * The initial aggregation state for the counter_diff function. Initial values are NULL.
+ */
+ override lazy val initialValues: Seq[Expression] = Seq(
+ Literal.create(null, counter.dataType),
+ Literal.create(null, counter.dataType)
+ )
+
+ /**
+ * The update expressions for the counter_diff function's aggregation state.
+ *
+ * Fundamentally, the current value becomes the previous value, and the new value becomes the
+ * current value.
+ *
+ * Rows with NULL counter values should be skipped. As a result, the previous counter value
+ * should not be updated in the aggregation state.
+ */
+ override lazy val updateExpressions: Seq[Expression] = Seq(
+ If(IsNotNull(currCounter), currCounter, prevCounter),
+ counter
+ )
+
+ /**
+ * The evaluation expression for the counter_diff function.
+ *
+ * Checks for edge cases first: NULL counter value, negative counter value and counter reset.
+ * Otherwise, returns the difference between the current and previous counter values.
+ */
+ override lazy val evaluateExpression: Expression =
+ withCounterNullAndNegativeChecks(diffWithCounterDecreaseCheck)
+
+ /**
+ * The SQL representation of the single-parameter form of the counter_diff function.
+ */
+ override def sql: String = s"${prettyName}(${counter.sql})"
Review Comment:
(Late catch; apologies for not flagging this last round.) This override
produces the same string as the default `Expression.sql` (which returns
`${prettyName}(${children.map(_.sql).mkString(", ")})`). Since `children =
Seq(counter)`, the default already yields `counter_diff(<counter.sql>)`. Other
window functions only override `sql` when they need to inject non-child state,
e.g. `EWM` exposes `$alpha`/`$ignoreNA`, `NthValue` appends `" ignore nulls"`.
counter_diff has no such extras, so this line can be dropped.
```suggestion
```
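As a rough illustration of why the override adds nothing, here is a toy Python
model of the default described above. It is not Spark's implementation; the
class names merely mirror the Scala ones, and `Column` is a hypothetical leaf
expression added for the example.

```python
class Expression:
    """Toy stand-in: the default sql() is pretty_name(child1.sql, child2.sql, ...)."""

    pretty_name = "expr"

    def __init__(self, *children):
        self.children = list(children)

    def sql(self) -> str:
        return f"{self.pretty_name}({', '.join(c.sql() for c in self.children)})"


class Column(Expression):
    """Leaf expression that renders as its own name."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def sql(self) -> str:
        return self.name


class CounterDiff(Expression):
    # No sql() override: every argument is a child, so the default suffices.
    pretty_name = "counter_diff"


one_arg = CounterDiff(Column("c")).sql()
two_arg = CounterDiff(Column("c"), Column("st")).sql()
```

In this model `one_arg` is `counter_diff(c)` and `two_arg` is
`counter_diff(c, st)`, the same strings the hand-written overrides build. An
override is only needed when the rendered SQL must carry state that is not a
child.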
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]