LantaoJin commented on a change in pull request #31189:
URL: https://github.com/apache/spark/pull/31189#discussion_r633392269
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
##########
@@ -620,6 +620,37 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan]
with PredicateHelper {
}
}
+/**
+ * Remove duplicated case when branches.
+ */
+object RemoveDuplicatedBranches extends Rule[LogicalPlan] with PredicateHelper
{
+
+ /**
+ * Wrapper around a branch that provides semantic equality.
+ */
+ case class EquivalentBranch(br: (Expression, Expression)) {
+ override def equals(o: Any): Boolean = o match {
+ case other: EquivalentBranch if br._1.deterministic &&
other.br._1.deterministic =>
+ br._1.semanticEquals(other.br._1)
+ case _ => false
+ }
+
+ override def hashCode: Int = br._1.semanticHash()
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsUp {
+ case c @ CaseWhen(branches, _) if branches.length > 1 =>
+ val equivalentBranchSet = branches.map(EquivalentBranch).toSet
+ if (equivalentBranchSet.size < branches.length) {
+ val dedup = equivalentBranchSet.map(_.br).toSeq
Review comment:
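Good catch. Converting the branches to a `Set` and back to a `Seq` does not preserve their original order. That order matters because `CaseWhen` takes the value of the first branch whose condition matches. Using `distinct` instead keeps the first occurrence of each equivalent branch and preserves the order, so it should be: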
```scala
val distinctEquivalentBranches =
branches.map(EquivalentBranch).distinct
if (distinctEquivalentBranches.size < branches.length) {
val dedup = distinctEquivalentBranches.map(_.br)
c.copy(branches = dedup)
} else {
c
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]