mikhailnik-db commented on code in PR #54297:
URL: https://github.com/apache/spark/pull/54297#discussion_r2822984368
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala:
##########
@@ -564,6 +565,126 @@ case class ListAgg(
false
}
+ /**
+ * Validates that the ordering expression is compatible with DISTINCT
deduplication.
+ *
+ * When LISTAGG(DISTINCT col) WITHIN GROUP (ORDER BY col) is used on a
non-string column,
+ * the child is implicitly cast to string (with UTF8_BINARY collation). The
DISTINCT rewrite
+ * (see [[RewriteDistinctAggregates]]) uses GROUP BY on both the original
and cast columns,
+ * so the cast must preserve equality semantics: values that are GROUP
BY-equal must cast to
+ * equal strings, and vice versa. Types like Float/Double violate this
because IEEE 754
+ * negative zero (-0.0) and positive zero (0.0) are equal but produce
different strings.
+ *
+ * This method is a no-op when the order expression matches the child (i.e.,
+ * [[needSaveOrderValue]] is false). Otherwise, the behavior depends on the
+ * [[SQLConf.LISTAGG_ALLOW_DISTINCT_CAST_WITH_ORDER]] config:
+ * - If enabled, delegates to [[orderMismatchCastSafety]] to determine
whether the
+ * mismatch is due to a safe cast, an unsafe cast, or not a cast at all.
+ * - If disabled, rejects any mismatch.
+ *
+ * @throws AnalysisException if the ordering is incompatible with DISTINCT
+ */
+ def validateDistinctOrderCompatibility(): Unit = {
+ if (needSaveOrderValue) {
+ if (SQLConf.get.listaggAllowDistinctCastWithOrder) {
+ orderMismatchCastSafety match {
+ case CastSafetyResult.SafeCast => // safe cast, allow
+ case CastSafetyResult.UnsafeCast(inputType, castType) =>
+ throwFunctionAndOrderExpressionUnsafeCastError(inputType, castType)
+ case CastSafetyResult.NotACast =>
+ throwFunctionAndOrderExpressionMismatchError()
+ }
+ } else {
+ throwFunctionAndOrderExpressionMismatchError()
+ }
+ }
+ }
+
+ private def throwFunctionAndOrderExpressionMismatchError() = {
+ throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
+ prettyName, child, orderExpressions)
+ }
+
+ private def throwFunctionAndOrderExpressionUnsafeCastError(
+ inputType: DataType, castType: DataType) = {
+ throw QueryCompilationErrors.functionAndOrderExpressionUnsafeCastError(
+ prettyName, inputType, castType)
+ }
+
+ /**
+ * Classifies the order-expression mismatch as a safe cast, unsafe cast, or
not a cast.
+ *
+ * @see [[validateDistinctOrderCompatibility]] for the full invariant this
enforces
+ */
+ private def orderMismatchCastSafety: CastSafetyResult = {
+ if (orderExpressions.size != 1) return CastSafetyResult.NotACast
+ child match {
+ case Cast(castChild, castType, _, _)
+ if orderExpressions.head.child.semanticEquals(castChild) =>
+ if (isCastSafeForDistinct(castChild.dataType) &&
+ isCastTargetSafeForDistinct(castType)) {
+ CastSafetyResult.SafeCast
+ } else {
+ CastSafetyResult.UnsafeCast(castChild.dataType, castType)
+ }
+ case _ => CastSafetyResult.NotACast
+ }
+ }
+
+ /**
+ * Returns true if casting `dt` to string is injective for DISTINCT
deduplication.
+ *
+ * @see [[validateDistinctOrderCompatibility]]
+ */
+ private def isCastSafeForDistinct(dt: DataType): Boolean = dt match {
+ case _: IntegerType | LongType | ShortType | ByteType => true
+ case _: DecimalType => true
+ case _: DateType | TimestampNTZType => true
+ case _: TimeType => true
+ case _: CalendarIntervalType => true
+ case _: YearMonthIntervalType => true
+ case _: DayTimeIntervalType => true
+ case BooleanType => true
+ case BinaryType => true
+ case st: StringType if st.isUTF8BinaryCollation => true
Review Comment:
nit:
```suggestion
case st: StringType => st.isUTF8BinaryCollation
```
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -6889,6 +6889,16 @@ object SQLConf {
.booleanConf
.createWithDefault(Utils.isTesting)
+ val LISTAGG_ALLOW_DISTINCT_CAST_WITH_ORDER =
+ buildConf("spark.sql.listagg.allowDistinctCastWithOrder.enabled")
+ .internal()
+ .doc("When true, LISTAGG(DISTINCT expr) WITHIN GROUP (ORDER BY expr) is
allowed on " +
+ "non-string expr whose cast to string is injective. When false,
DISTINCT requires " +
+ "expr and ORDER BY to reference the same expression with no cast.")
Review Comment:
For internal conf, we can give more info so other developers have context on
what it actually does. I would rephrase it a bit: `When true, LISTAGG(DISTINCT
expr) WITHIN GROUP (ORDER BY expr) is allowed on non-string expr when the
implicit cast to string preserves equality (e.g., integer, decimal, date). When
false, the function argument and ORDER BY expression must have the exact same
type, which requires explicit casts`
##########
sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql:
##########
@@ -12,3 +12,4 @@ WITH t(c1) AS (SELECT listagg(col1) WITHIN GROUP (ORDER BY
col1 COLLATE unicode_
-- Error case with collations
SELECT listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1
COLLATE utf8_binary) FROM (VALUES ('a'), ('b'), ('A'), ('B')) AS t(c1);
+SELECT listagg(DISTINCT CAST(col AS STRING COLLATE UTF8_LCASE)) WITHIN GROUP
(ORDER BY col) FROM VALUES (X'414243'), (X'616263'), (X'414243') AS t(col)
Review Comment:
Let's actually cover all 6 cases (each cast kind below X each safety outcome):
* string -> string cast
* binary -> string cast
* string -> binary cast
X
* safe
* unsafe
##########
common/utils/src/main/resources/error/error-conditions.json:
##########
@@ -4284,6 +4284,11 @@
"The function is invoked with DISTINCT and WITHIN GROUP but
expressions <funcArg> and <orderingExpr> do not match. The WITHIN GROUP
ordering expression must be picked from the function inputs."
]
},
+ "MISMATCH_WITH_DISTINCT_INPUT_UNSAFE_CAST" : {
+ "message" : [
+ "The function <funcName> with DISTINCT requires a cast from
<inputType> to <castType>, but this cast may not preserve equality semantics
for the input type (e.g., floating-point -0.0 and 0.0 are treated as equal
during GROUP BY but cast to different strings, leading to incorrect
deduplication)."
+ ]
+ },
Review Comment:
IMHO, the message goes too far into the implementation details. I'd do something
simpler and more straightforward, like:
```
<funcName> with DISTINCT and WITHIN GROUP (ORDER BY) is not supported for
<inputType> input. Cast the input to <castType> explicitly before passing it to
the function argument and
ORDER BY expression.
```
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateExpressionResolver.scala:
##########
@@ -116,8 +118,7 @@ class AggregateExpressionResolver(
private def validateResolvedAggregateExpression(aggregateExpression:
AggregateExpression): Unit =
aggregateExpression match {
case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
- if agg.isDistinct && listAgg.needSaveOrderValue =>
- throwFunctionAndOrderExpressionMismatchError(listAgg)
+ if agg.isDistinct => listAgg.validateDistinctOrderCompatibility()
Review Comment:
Sorry for ping-ponging, but I've just realized that this approach is not
correct. The logic here should follow the rule "if we go to a non-default
branch, that means we found an error and must throw". Currently, we can
successfully execute `validateDistinctOrderCompatibility()` and skip the general
check from `case _ =>`. The same applies to `CheckAnalysis` as well.
Not sure how to better structure the code here. It's probably okay to have a
method with very similar logic to `validateDistinctOrderCompatibility`, but
returning a `bool` that indicates whether we should throw. But it's still code
duplication... Open to suggestions :)
##########
sql/core/src/test/resources/sql-tests/results/listagg.sql.out:
##########
@@ -189,6 +189,144 @@ struct<len(col1):int,regexp_count(col1,
1):int,regexp_count(col1, 2):int,regexp_
3 1 1 1 16 1 0
+-- !query
+SELECT listagg(DISTINCT col, ',') WITHIN GROUP (ORDER BY col) FROM VALUES (1),
(2), (2), (3) AS t(col)
+-- !query schema
+struct<listagg(DISTINCT col, ,) WITHIN GROUP (ORDER BY col ASC NULLS
FIRST):string>
+-- !query output
+1,2,3
+
+
+-- !query
+SELECT listagg(DISTINCT col, ',') WITHIN GROUP (ORDER BY col) FROM VALUES
(cast(1 as bigint)), (cast(2 as bigint)), (cast(2 as bigint)), (cast(3 as
bigint)) AS t(col)
+-- !query schema
+struct<listagg(DISTINCT col, ,) WITHIN GROUP (ORDER BY col ASC NULLS
FIRST):string>
+-- !query output
+1,2,3
+
+
+-- !query
+SELECT listagg(DISTINCT col, ',') WITHIN GROUP (ORDER BY col) FROM VALUES
(cast(1 as smallint)), (cast(2 as smallint)), (cast(2 as smallint)), (cast(3 as
smallint)) AS t(col)
+-- !query schema
+struct<listagg(DISTINCT col, ,) WITHIN GROUP (ORDER BY col ASC NULLS
FIRST):string>
+-- !query output
+1,2,3
+
+
+-- !query
+SELECT listagg(DISTINCT col, ',') WITHIN GROUP (ORDER BY col) FROM VALUES
(cast(1 as tinyint)), (cast(2 as tinyint)), (cast(2 as tinyint)), (cast(3 as
tinyint)) AS t(col)
+-- !query schema
+struct<listagg(DISTINCT col, ,) WITHIN GROUP (ORDER BY col ASC NULLS
FIRST):string>
+-- !query output
+1,2,3
+
+
+-- !query
+SELECT listagg(DISTINCT col, ',') WITHIN GROUP (ORDER BY col) FROM VALUES
(cast(1.10 as decimal(10,2))), (cast(2.20 as decimal(10,2))), (cast(2.20 as
decimal(10,2))) AS t(col)
+-- !query schema
+struct<listagg(DISTINCT col, ,) WITHIN GROUP (ORDER BY col ASC NULLS
FIRST):string>
+-- !query output
+1.10,2.20
+
+
+-- !query
+SELECT listagg(DISTINCT col, ',') WITHIN GROUP (ORDER BY col) FROM VALUES
(DATE'2024-01-01'), (DATE'2024-01-02'), (DATE'2024-01-01') AS t(col)
+-- !query schema
+struct<listagg(DISTINCT col, ,) WITHIN GROUP (ORDER BY col ASC NULLS
FIRST):string>
+-- !query output
+2024-01-01,2024-01-02
+
+
+-- !query
+SELECT listagg(DISTINCT col, ',') WITHIN GROUP (ORDER BY col) FROM VALUES
(TIMESTAMP_NTZ'2024-01-01 10:00:00'), (TIMESTAMP_NTZ'2024-01-02 12:00:00'),
(TIMESTAMP_NTZ'2024-01-01 10:00:00') AS t(col)
+-- !query schema
+struct<listagg(DISTINCT col, ,) WITHIN GROUP (ORDER BY col ASC NULLS
FIRST):string>
+-- !query output
+2024-01-01 10:00:00,2024-01-02 12:00:00
+
+
+-- !query
+SELECT listagg(DISTINCT col, ',') WITHIN GROUP (ORDER BY col) FROM VALUES
(true), (false), (true) AS t(col)
+-- !query schema
+struct<listagg(DISTINCT col, ,) WITHIN GROUP (ORDER BY col ASC NULLS
FIRST):string>
+-- !query output
+false,true
+
+
+-- !query
+SELECT listagg(DISTINCT col, ',') WITHIN GROUP (ORDER BY col) FROM VALUES
(INTERVAL '1' MONTH), (INTERVAL '2' MONTH), (INTERVAL '1' MONTH) AS t(col)
+-- !query schema
+struct<listagg(DISTINCT col, ,) WITHIN GROUP (ORDER BY col ASC NULLS
FIRST):string>
+-- !query output
+INTERVAL '1' MONTH,INTERVAL '2' MONTH
+
+
+-- !query
+SELECT listagg(DISTINCT cast(col as string), ',') WITHIN GROUP (ORDER BY col)
FROM VALUES (10), (1), (2), (20), (2) AS t(col)
+-- !query schema
+struct<listagg(DISTINCT CAST(col AS STRING), ,) WITHIN GROUP (ORDER BY col ASC
NULLS FIRST):string>
+-- !query output
+1,2,10,20
+
+
+-- !query
+SELECT listagg(DISTINCT col, ',') WITHIN GROUP (ORDER BY col DESC) FROM VALUES
(1), (10), (2), (20), (2) AS t(col)
+-- !query schema
+struct<listagg(DISTINCT col, ,) WITHIN GROUP (ORDER BY col DESC NULLS
LAST):string>
+-- !query output
+20,10,2,1
+
+
+-- !query
+SELECT listagg(DISTINCT col, ',') WITHIN GROUP (ORDER BY col) FROM VALUES (1),
(2), (null), (2), (3) AS t(col)
+-- !query schema
+struct<listagg(DISTINCT col, ,) WITHIN GROUP (ORDER BY col ASC NULLS
FIRST):string>
+-- !query output
+1,2,3
+
+
+-- !query
+SELECT listagg(DISTINCT col, ',') WITHIN GROUP (ORDER BY col NULLS FIRST) FROM
VALUES (1), (null), (2), (null) AS t(col)
+-- !query schema
+struct<listagg(DISTINCT col, ,) WITHIN GROUP (ORDER BY col ASC NULLS
FIRST):string>
+-- !query output
+1,2
+
+
+-- !query
+SELECT grp, listagg(DISTINCT col) WITHIN GROUP (ORDER BY col) FROM VALUES (1,
'a'), (1, 'b'), (2, 'a'), (2, 'a'), (1, 'b') AS t(grp, col) GROUP BY grp
+-- !query schema
+struct<grp:int,listagg(DISTINCT col, NULL) WITHIN GROUP (ORDER BY col ASC
NULLS FIRST):string>
+-- !query output
+1 ab
+2 a
+
+
+-- !query
+WITH t(col) AS (SELECT listagg(DISTINCT col1, X'2C') WITHIN GROUP (ORDER BY
col1) FROM (VALUES (X'DEAD'), (X'BEEF'), (X'DEAD'), (X'CAFE'))) SELECT
len(col), regexp_count(col, X'DEAD'), regexp_count(col, X'BEEF'),
regexp_count(col, X'CAFE') FROM t
+-- !query schema
+struct<len(col):int,regexp_count(col, X'DEAD'):int,regexp_count(col,
X'BEEF'):int,regexp_count(col, X'CAFE'):int>
+-- !query output
+8 1 2 2
Review Comment:
Is it correct? Shouldn't it be `8 1 1 1`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]