pnikic-db commented on code in PR #55828: URL: https://github.com/apache/spark/pull/55828#discussion_r3241431525
########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CounterDiff.scala: ########## @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter} +import org.apache.spark.sql.errors.QueryErrorsBase +import org.apache.spark.sql.types._ + +/** + * The counter_diff window function computes the differences between consecutive cumulative counter + * values in a time series, thereby converting the counter from the cumulative to the delta format. + * + * This class serves as the base class for the two versions of the counter_diff function: + * - counter_diff(counter) -> CounterDiff(counter) + * - counter_diff(counter, start_time) -> CounterDiffWithStartTime(counter, startTime) + */ +abstract class CounterDiffBase(val counter: Expression) + extends AggregateWindowFunction + with QueryErrorsBase { + + override def prettyName: String = "counter_diff" + + override def dataType: DataType = counter.dataType + + /** + * Last non-NULL counter value from a previous row. + */ + protected lazy val prevCounter: AttributeReference = + AttributeReference("prevCounter", counter.dataType, nullable = true)() + + /** + * Counter value from the current row. + */ + protected lazy val currCounter: AttributeReference = + AttributeReference("currCounter", counter.dataType, nullable = true)() + + /** + * Null literal used as a counter_diff result, when appropriate. + */ + protected lazy val nullResult: Expression = Literal.create(null, counter.dataType) + + /** + * Difference between the current and previous counter values. + */ + protected lazy val diff: Expression = { + counter.dataType match { + // For DECIMAL, subtraction typically widens the result type to handle possible overflow. + // For counter_diff, since counters cannot be negative, there is no risk of overflow, and no + // need to widen the result type, so we subtract directly in the input type. + case dt: DecimalType => DecimalSubtractNoOverflowCheck(currCounter, prevCounter, dt) + case _ => currCounter - prevCounter + } + } + + /** + * Returns the difference, unless the counter has decreased, which is treated as a counter reset. + * In this case, NULL is returned. + */ + protected lazy val diffWithCounterDecreaseCheck: Expression = + If(currCounter < prevCounter, nullResult, diff) + + /** + * Error raised when the counter is negative. + */ + protected lazy val negativeCounterError: Expression = RaiseError( + Literal("COUNTER_DIFF_NEGATIVE_COUNTER_VALUE"), + CreateMap( + Seq( + Literal("value"), + Cast(currCounter, StringType), + Literal("function"), + Literal(toSQLId("counter_diff")) + ) + ), + counter.dataType + ) + + /** + * Wraps `inner` with the "skip row on NULL counter" and "raise error on negative counter" checks. + */ + protected def withCounterNullAndNegativeChecks(inner: Expression): Expression = { + If(IsNull(currCounter), + nullResult, + If(currCounter < Literal.default(counter.dataType), + negativeCounterError, + inner + ) + ) + } +} + +/** + * The single-parameter form of `counter_diff`: `counter_diff(value)`. + * Detects counter resets only when the counter value decreases. + */ +case class CounterDiff(override val counter: Expression) + extends CounterDiffBase(counter) + with ExpectsInputTypes { + + override def children: Seq[Expression] = Seq(counter) + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + /** + * The aggregation state attributes for the counter_diff function. + * In the single-parameter form, there are two attributes: + * - prevCounter: The last non-NULL counter value from a previous row. + * - currCounter: The counter value from the current row. + */ + override lazy val aggBufferAttributes: Seq[AttributeReference] = + Seq(prevCounter, currCounter) + + /** + * The initial aggregation state for the counter_diff function. Initial values are NULL. + */ + override lazy val initialValues: Seq[Expression] = Seq( + Literal.create(null, counter.dataType), + Literal.create(null, counter.dataType) + ) + + /** + * The update expressions for the counter_diff function's aggregation state. + * + * Fundamentally, the current value becomes the previous value, and the new value becomes the + * current value. + * + * Rows with NULL counter values should be skipped. As a result, the previous counter value + * should not be updated in the aggregation state. + */ + override lazy val updateExpressions: Seq[Expression] = Seq( + If(IsNotNull(currCounter), currCounter, prevCounter), + counter + ) + + /** + * The evaluation expression for the counter_diff function. + * + * Checks for edge cases first: NULL counter value, negative counter value and counter reset. + * Otherwise, returns the difference between the current and previous counter values. + */ + override lazy val evaluateExpression: Expression = + withCounterNullAndNegativeChecks(diffWithCounterDecreaseCheck) + + /** + * The SQL representation of the single-parameter form of the counter_diff function. + */ + override def sql: String = s"${prettyName}(${counter.sql})" Review Comment: Removed the `sql` override from both `counter_diff` expressions. Thanks! -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
