hvanhovell commented on a change in pull request #29598:
URL: https://github.com/apache/spark/pull/29598#discussion_r480205472
##########
File path:
sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
##########
@@ -53,46 +57,102 @@ object ExpressionSet {
* This is consistent with how we define `semanticEquals` between two
expressions.
*/
class ExpressionSet protected(
- protected val baseSet: mutable.Set[Expression] = new mutable.HashSet,
- protected val originals: mutable.Buffer[Expression] = new ArrayBuffer)
- extends Set[Expression] {
+ private val baseSet: mutable.Set[Expression] = new mutable.HashSet,
+ private val originals: mutable.Buffer[Expression] = new ArrayBuffer)
+ extends Iterable[Expression] {
// Note: this class supports Scala 2.12. A parallel source tree has a 2.13
implementation.
protected def add(e: Expression): Unit = {
if (!e.deterministic) {
originals += e
- } else if (!baseSet.contains(e.canonicalized) ) {
+ } else if (!baseSet.contains(e.canonicalized)) {
baseSet.add(e.canonicalized)
originals += e
}
}
- override def contains(elem: Expression): Boolean =
baseSet.contains(elem.canonicalized)
+ protected def remove(e: Expression): Unit = {
+ if (e.deterministic) {
+ baseSet --= baseSet.filter(_ == e.canonicalized)
+ originals --= originals.filter(_.canonicalized == e.canonicalized)
+ }
+ }
+
+ def contains(elem: Expression): Boolean =
baseSet.contains(elem.canonicalized)
+
+ override def filter(p: Expression => Boolean): ExpressionSet = {
+ val newBaseSet = baseSet.filter(e => p(e.canonicalized))
+ val newOriginals = originals.filter(e => p(e.canonicalized))
+ new ExpressionSet(newBaseSet, newOriginals)
+ }
+
+ override def filterNot(p: Expression => Boolean): ExpressionSet = {
+ val newBaseSet = baseSet.filterNot(e => p(e.canonicalized))
+ val newOriginals = originals.filterNot(e => p(e.canonicalized))
+ new ExpressionSet(newBaseSet, newOriginals)
+ }
- override def +(elem: Expression): ExpressionSet = {
- val newSet = new ExpressionSet(baseSet.clone(), originals.clone())
+ def +(elem: Expression): ExpressionSet = {
+ val newSet = clone()
newSet.add(elem)
newSet
}
- override def ++(elems: GenTraversableOnce[Expression]): ExpressionSet = {
- val newSet = new ExpressionSet(baseSet.clone(), originals.clone())
+ def ++(elems: GenTraversableOnce[Expression]): ExpressionSet = {
+ val newSet = clone()
elems.foreach(newSet.add)
newSet
}
- override def -(elem: Expression): ExpressionSet = {
- if (elem.deterministic) {
- val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized)
- val newOriginals = originals.clone().filterNot(_.canonicalized ==
elem.canonicalized)
- new ExpressionSet(newBaseSet, newOriginals)
- } else {
- new ExpressionSet(baseSet.clone(), originals.clone())
- }
+ def -(elem: Expression): ExpressionSet = {
+ val newSet = clone()
+ newSet.remove(elem)
Review comment:
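Isn't this more efficient? It builds the filtered collections directly instead of cloning both of them first and then removing from the copies:
```scala
ExpressionSet(baseSet.filter(_ != e.canonicalized),
originals.filter(_.canonicalized != e.canonicalized))
```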
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]