This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 8c84d2c9349 [SPARK-44018][SQL] Improve the hashCode and toString for some DS V2 Expression 8c84d2c9349 is described below commit 8c84d2c9349d7b607db949c2e114df781f23e438 Author: Jiaan Geng <belie...@163.com> AuthorDate: Mon Jun 19 15:55:06 2023 +0800 [SPARK-44018][SQL] Improve the hashCode and toString for some DS V2 Expression ### What changes were proposed in this pull request? The `hashCode()` of `UserDefinedScalarFunc` and `GeneralScalarExpression` is not good enough. For example, `GeneralScalarExpression` uses `Objects.hash(name, children)`, which combines the hash code of `name` with the identity-based hash code of the `children` array reference, not with the hash codes of its elements. Because `equals()` compares `children` element by element (with `Arrays.equals`), two equal expressions can produce different hash codes. Instead, the hash code should be derived from each element of `children` (via `Arrays.hashCode`). Because `UserDefinedAggregateFunc` and `GeneralAggregateFunc` are missing `equals()` and `hashCode()`, this PR also adds them. In the new `equals()` of `UserDefinedAggregateFunc` and `GeneralAggregateFunc`, the `isDistinct` flags are compared with a primitive boolean comparison instead of `Objects.equals`, because a primitive comparison is cheaper than `Objects.equals` (it avoids autoboxing). ### Why are the changes needed? To improve the hash codes of some DS V2 expressions. ### Does this PR introduce _any_ user-facing change? 'Yes'. ### How was this patch tested? N/A Closes #41543 from beliefer/SPARK-44018. 
Authored-by: Jiaan Geng <belie...@163.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../expressions/GeneralScalarExpression.java | 10 ++++++--- .../expressions/UserDefinedScalarFunc.java | 13 ++++++++---- .../aggregate/GeneralAggregateFunc.java | 22 ++++++++++++++++++++ .../aggregate/UserDefinedAggregateFunc.java | 24 ++++++++++++++++++++++ 4 files changed, 62 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java index cb9bf6d69e2..85966060021 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.connector.expressions; import java.util.Arrays; -import java.util.Objects; import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.filter.Predicate; @@ -441,12 +440,17 @@ public class GeneralScalarExpression extends ExpressionWithToString { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; + GeneralScalarExpression that = (GeneralScalarExpression) o; - return Objects.equals(name, that.name) && Arrays.equals(children, that.children); + + if (!name.equals(that.name)) return false; + return Arrays.equals(children, that.children); } @Override public int hashCode() { - return Objects.hash(name, children); + int result = name.hashCode(); + result = 31 * result + Arrays.hashCode(children); + return result; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java index b7f603cd431..cbf3941d77d 100644 --- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.connector.expressions; import java.util.Arrays; -import java.util.Objects; import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.internal.connector.ExpressionWithToString; @@ -51,13 +50,19 @@ public class UserDefinedScalarFunc extends ExpressionWithToString { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; + UserDefinedScalarFunc that = (UserDefinedScalarFunc) o; - return Objects.equals(name, that.name) && Objects.equals(canonicalName, that.canonicalName) && - Arrays.equals(children, that.children); + + if (!name.equals(that.name)) return false; + if (!canonicalName.equals(that.canonicalName)) return false; + return Arrays.equals(children, that.children); } @Override public int hashCode() { - return Objects.hash(name, canonicalName, children); + int result = name.hashCode(); + result = 31 * result + canonicalName.hashCode(); + result = 31 * result + Arrays.hashCode(children); + return result; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java index 1abf3865659..4ef5b7f97e9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector.expressions.aggregate; +import java.util.Arrays; + import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Expression; import 
org.apache.spark.sql.internal.connector.ExpressionWithToString; @@ -60,4 +62,24 @@ public final class GeneralAggregateFunc extends ExpressionWithToString implement @Override public Expression[] children() { return children; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GeneralAggregateFunc that = (GeneralAggregateFunc) o; + + if (isDistinct != that.isDistinct) return false; + if (!name.equals(that.name)) return false; + return Arrays.equals(children, that.children); + } + + @Override + public int hashCode() { + int result = name.hashCode(); + result = 31 * result + (isDistinct ? 1 : 0); + result = 31 * result + Arrays.hashCode(children); + return result; + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java index d166ba16ba5..10a62d0478b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector.expressions.aggregate; +import java.util.Arrays; + import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Expression; import org.apache.spark.sql.internal.connector.ExpressionWithToString; @@ -50,4 +52,26 @@ public class UserDefinedAggregateFunc extends ExpressionWithToString implements @Override public Expression[] children() { return children; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + UserDefinedAggregateFunc that = (UserDefinedAggregateFunc) o; + + if (isDistinct != that.isDistinct) return false; + if 
(!name.equals(that.name)) return false; + if (!canonicalName.equals(that.canonicalName)) return false; + return Arrays.equals(children, that.children); + } + + @Override + public int hashCode() { + int result = name.hashCode(); + result = 31 * result + canonicalName.hashCode(); + result = 31 * result + (isDistinct ? 1 : 0); + result = 31 * result + Arrays.hashCode(children); + return result; + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org