stefankandic commented on code in PR #48598:
URL: https://github.com/apache/spark/pull/48598#discussion_r1812258235
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala:
##########
@@ -655,6 +657,14 @@ case class InSet(child: Expression, hset: Set[Any])
extends UnaryExpression with
}
@transient lazy val set: Set[Any] = child.dataType match {
+ case st: StringType =>
+ if (st.supportsBinaryEquality) {
Review Comment:
```suggestion
if (st.isUTF8BinaryCollation) {
```
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala:
##########
@@ -655,6 +657,14 @@ case class InSet(child: Expression, hset: Set[Any])
extends UnaryExpression with
}
@transient lazy val set: Set[Any] = child.dataType match {
+ case st: StringType =>
+ if (st.supportsBinaryEquality) {
+ hset
+ } else if (st.supportsLowercaseEquality) {
Review Comment:
```suggestion
} else if (st.isUTF8LcaseCollation) {
```
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala:
##########
@@ -767,6 +777,35 @@ case class InSet(child: Expression, hset: Set[Any])
extends UnaryExpression with
override protected def withNewChildInternal(newChild: Expression): InSet =
copy(child = newChild)
}
+object InSet {
+ class LCaseSet(inputSet: Set[Any]) extends immutable.Set[Any] with
Serializable {
+ private val strSet = inputSet.map { s =>
+ if (s == null) null
+ else
CollationAwareUTF8String.lowerCaseCodePoints(s.asInstanceOf[UTF8String])
+ }
+ override def incl(elem: Any): Set[Any] = inputSet.incl(elem)
+ override def excl(elem: Any): Set[Any] = inputSet.excl(elem)
+ override def iterator: Iterator[Any] = inputSet.iterator
+ override def contains(elem: Any): Boolean = {
+ assert(elem != null, "InSet guarantees non-null input")
+
strSet.contains(CollationAwareUTF8String.lowerCaseCodePoints(elem.asInstanceOf[UTF8String]))
+ }
+ }
+ class CollationSet(inputSet: Set[Any], collationId: Int)
+ extends immutable.Set[Any] with Serializable {
+ override def incl(elem: Any): Set[Any] = inputSet.incl(elem)
+ override def excl(elem: Any): Set[Any] = inputSet.excl(elem)
+ override def iterator: Iterator[Any] = inputSet.iterator
+ override def contains(elem: Any): Boolean = {
+ assert(elem != null, "InSet guarantees non-null input")
+ val collation = CollationFactory.fetchCollation(collationId)
+ inputSet.exists { p =>
Review Comment:
This seems pretty inefficient, since `inputSet.exists` scans every element on
each lookup. Can we do something similar to what we did for the lowercase set?
Preprocess all elements by computing the hashCode of each string's collationKey
and build a Map[Int, Seq[String]]. Then, in `contains`, we can look up the
bucket for the element's collation-key hash and check only those candidates.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala:
##########
@@ -767,6 +777,35 @@ case class InSet(child: Expression, hset: Set[Any])
extends UnaryExpression with
override protected def withNewChildInternal(newChild: Expression): InSet =
copy(child = newChild)
}
+object InSet {
+ class LCaseSet(inputSet: Set[Any]) extends immutable.Set[Any] with
Serializable {
+ private val strSet = inputSet.map { s =>
+ if (s == null) null
+ else
CollationAwareUTF8String.lowerCaseCodePoints(s.asInstanceOf[UTF8String])
+ }
+ override def incl(elem: Any): Set[Any] = inputSet.incl(elem)
+ override def excl(elem: Any): Set[Any] = inputSet.excl(elem)
+ override def iterator: Iterator[Any] = inputSet.iterator
+ override def contains(elem: Any): Boolean = {
+ assert(elem != null, "InSet guarantees non-null input")
+
strSet.contains(CollationAwareUTF8String.lowerCaseCodePoints(elem.asInstanceOf[UTF8String]))
+ }
+ }
+ class CollationSet(inputSet: Set[Any], collationId: Int)
+ extends immutable.Set[Any] with Serializable {
+ override def incl(elem: Any): Set[Any] = inputSet.incl(elem)
+ override def excl(elem: Any): Set[Any] = inputSet.excl(elem)
+ override def iterator: Iterator[Any] = inputSet.iterator
+ override def contains(elem: Any): Boolean = {
+ assert(elem != null, "InSet guarantees non-null input")
+ val collation = CollationFactory.fetchCollation(collationId)
Review Comment:
Can we create this variable outside of the function, so that we fetch the
collation only once?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala:
##########
@@ -767,6 +777,35 @@ case class InSet(child: Expression, hset: Set[Any])
extends UnaryExpression with
override protected def withNewChildInternal(newChild: Expression): InSet =
copy(child = newChild)
}
+object InSet {
+ class LCaseSet(inputSet: Set[Any]) extends immutable.Set[Any] with
Serializable {
+ private val strSet = inputSet.map { s =>
+ if (s == null) null
+ else
CollationAwareUTF8String.lowerCaseCodePoints(s.asInstanceOf[UTF8String])
+ }
+ override def incl(elem: Any): Set[Any] = inputSet.incl(elem)
+ override def excl(elem: Any): Set[Any] = inputSet.excl(elem)
+ override def iterator: Iterator[Any] = inputSet.iterator
+ override def contains(elem: Any): Boolean = {
+ assert(elem != null, "InSet guarantees non-null input")
+
strSet.contains(CollationAwareUTF8String.lowerCaseCodePoints(elem.asInstanceOf[UTF8String]))
+ }
+ }
+ class CollationSet(inputSet: Set[Any], collationId: Int)
+ extends immutable.Set[Any] with Serializable {
+ override def incl(elem: Any): Set[Any] = inputSet.incl(elem)
+ override def excl(elem: Any): Set[Any] = inputSet.excl(elem)
+ override def iterator: Iterator[Any] = inputSet.iterator
+ override def contains(elem: Any): Boolean = {
+ assert(elem != null, "InSet guarantees non-null input")
+ val collation = CollationFactory.fetchCollation(collationId)
+ inputSet.exists { p =>
Review Comment:
cc: @uros-db, who did some work with collation keys, to check whether this makes sense.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala:
##########
@@ -767,6 +777,35 @@ case class InSet(child: Expression, hset: Set[Any])
extends UnaryExpression with
override protected def withNewChildInternal(newChild: Expression): InSet =
copy(child = newChild)
}
+object InSet {
+ class LCaseSet(inputSet: Set[Any]) extends immutable.Set[Any] with
Serializable {
+ private val strSet = inputSet.map { s =>
+ if (s == null) null
+ else
CollationAwareUTF8String.lowerCaseCodePoints(s.asInstanceOf[UTF8String])
+ }
+ override def incl(elem: Any): Set[Any] = inputSet.incl(elem)
Review Comment:
Shouldn't this method also return an `LCaseSet` (same for `excl`, and for both
methods in the other class)?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala:
##########
@@ -767,6 +777,35 @@ case class InSet(child: Expression, hset: Set[Any])
extends UnaryExpression with
override protected def withNewChildInternal(newChild: Expression): InSet =
copy(child = newChild)
}
+object InSet {
+ class LCaseSet(inputSet: Set[Any]) extends immutable.Set[Any] with
Serializable {
+ private val strSet = inputSet.map { s =>
+ if (s == null) null
+ else
CollationAwareUTF8String.lowerCaseCodePoints(s.asInstanceOf[UTF8String])
+ }
+ override def incl(elem: Any): Set[Any] = inputSet.incl(elem)
+ override def excl(elem: Any): Set[Any] = inputSet.excl(elem)
+ override def iterator: Iterator[Any] = inputSet.iterator
+ override def contains(elem: Any): Boolean = {
+ assert(elem != null, "InSet guarantees non-null input")
+
strSet.contains(CollationAwareUTF8String.lowerCaseCodePoints(elem.asInstanceOf[UTF8String]))
Review Comment:
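This seems dangerous: the unchecked `asInstanceOf[UTF8String]` cast throws a
ClassCastException if anything other than a UTF8String reaches `contains`. Can we
implement Set[UTF8String] directly (and also take a Set[UTF8String] in the constructor)?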
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala:
##########
@@ -767,6 +777,35 @@ case class InSet(child: Expression, hset: Set[Any])
extends UnaryExpression with
override protected def withNewChildInternal(newChild: Expression): InSet =
copy(child = newChild)
}
+object InSet {
Review Comment:
Since these classes seem fairly general, consider moving them somewhere that
other parts of the code can use them as well, e.g. something like
CollationUtils?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala:
##########
@@ -767,6 +777,35 @@ case class InSet(child: Expression, hset: Set[Any])
extends UnaryExpression with
override protected def withNewChildInternal(newChild: Expression): InSet =
copy(child = newChild)
}
+object InSet {
+ class LCaseSet(inputSet: Set[Any]) extends immutable.Set[Any] with
Serializable {
Review Comment:
```suggestion
class LowercaseSet(inputSet: Set[Any]) extends immutable.Set[Any] with
Serializable {
```
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala:
##########
@@ -767,6 +777,35 @@ case class InSet(child: Expression, hset: Set[Any])
extends UnaryExpression with
override protected def withNewChildInternal(newChild: Expression): InSet =
copy(child = newChild)
}
+object InSet {
+ class LCaseSet(inputSet: Set[Any]) extends immutable.Set[Any] with
Serializable {
+ private val strSet = inputSet.map { s =>
+ if (s == null) null
+ else
CollationAwareUTF8String.lowerCaseCodePoints(s.asInstanceOf[UTF8String])
+ }
+ override def incl(elem: Any): Set[Any] = inputSet.incl(elem)
+ override def excl(elem: Any): Set[Any] = inputSet.excl(elem)
+ override def iterator: Iterator[Any] = inputSet.iterator
+ override def contains(elem: Any): Boolean = {
+ assert(elem != null, "InSet guarantees non-null input")
+
strSet.contains(CollationAwareUTF8String.lowerCaseCodePoints(elem.asInstanceOf[UTF8String]))
+ }
+ }
+ class CollationSet(inputSet: Set[Any], collationId: Int)
Review Comment:
```suggestion
class CollationAwareSet(inputSet: Set[Any], collationId: Int)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]