[
https://issues.apache.org/jira/browse/SPARK-33383?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=17228483#comment-17228483
]
Felix Wollschläger commented on SPARK-33383:
--------------------------------------------
I also found out that with only 500 allowedValues elements (instead of
10000), _runWithInCollection_ performs better than both the broadcast-join
and the broadcast-variable variant.
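A minimal sketch of that variant, assuming the same test driver as in the code attached below (only the whitelist size passed to _getRandomElements_ changes):
{code:scala}
// Hypothetical variation of main() from the attached test code:
// shrink the whitelist from 10000 to 500 elements and re-run the benchmarks.
val allowedValues = getRandomElements(values, random, 500)

runWithInCollection(spark, df, allowedValues)    // now the fastest variant
runWithBroadcastDF(spark, df, allowedValues)
runWithBroadcastVariable(spark, df, allowedValues)
{code}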
> Improve performance of Column.isin Expression
> ---------------------------------------------
>
> Key: SPARK-33383
> URL: https://issues.apache.org/jira/browse/SPARK-33383
> Project: Spark
> Issue Type: Improvement
> Components: SQL
> Affects Versions: 2.4.4, 3.0.1
> Environment: macOS
> Spark(-SQL) 2.4.4 and 3.0.1
> Scala 2.12.10
> Reporter: Felix Wollschläger
> Priority: Major
>
> After asking [a question on Stack
> Overflow|https://stackoverflow.com/questions/64683189/usage-of-broadcast-variables-when-using-only-spark-sql-api]
> and running some local tests, I came across a performance bottleneck when
> using _Column.isin_ in a _where_ condition.
> I have a set of allowed values (a "whitelist") that easily fits in memory
> (about 10k values). I thought simply using the _Column.isin_ expression in
> the SQL API would be the way to go, and I assumed it would be
> runtime-equivalent to
> {code}
> df.filter(row => allowedValues.contains(row.getInt(0)))
> {code}
> However, when running a few tests locally, I realized that _Column.isin_ is
> actually about 10 times slower than an _rdd.filter_ or a broadcast inner
> join.
> Shouldn't {code}df.where(col("colname").isin(allowedValues)){code} perform
> (SQL-API overhead aside) as well as {code}df.filter(row =>
> allowedValues.contains(row.getInt(0))){code}?
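> One way to sanity-check this, assuming the same _df_ and _allowedValues_ as
> in the test code below, is to print the extended physical plan and inspect
> which filter predicate Spark actually generates for the isin condition:
> {code:scala}
> // Hypothetical check reusing df/allowedValues from the test code below.
> // The printed plan shows how the whitelist filter is compiled
> // (e.g. as a linear In expression vs. a set-based InSet lookup).
> df.where(col("value").isInCollection(allowedValues)).explain(true)
> {code}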
> I used the following dummy code for my local tests:
> {code:scala}
> package example
>
> import org.apache.spark.sql.functions.{broadcast, col, count}
> import org.apache.spark.sql.{DataFrame, SparkSession}
>
> import scala.util.Random
>
> object Test {
>
>   def main(args: Array[String]): Unit = {
>     val spark = SparkSession.builder()
>       .appName("Name")
>       .master("local[*]")
>       .config("spark.driver.host", "localhost")
>       .config("spark.ui.enabled", "false")
>       .getOrCreate()
>
>     import spark.implicits._
>
>     val _10Million = 10000000
>     val random = new Random(1048394789305L)
>
>     val values = Seq.fill(_10Million)(random.nextInt())
>     val df = values.toDF("value")
>
>     val allowedValues = getRandomElements(values, random, 10000)
>
>     println("Starting ...")
>
>     runWithInCollection(spark, df, allowedValues)
>     println("---- In Collection")
>
>     runWithBroadcastDF(spark, df, allowedValues)
>     println("---- Broadcast DF")
>
>     runWithBroadcastVariable(spark, df, allowedValues)
>     println("---- Broadcast Variable")
>   }
>
>   // Draws `size` random elements from the given sequence (duplicates
>   // collapse into the Set, so it may hold slightly fewer elements).
>   def getRandomElements[A](seq: Seq[A], random: Random, size: Int): Set[A] = {
>     val builder = Set.newBuilder[A]
>     for (_ <- 0 until size) {
>       builder += getRandomElement(seq, random)
>     }
>     builder.result()
>   }
>
>   def getRandomElement[A](seq: Seq[A], random: Random): A = {
>     seq(random.nextInt(seq.length))
>   }
>
>   // I expected this one to be almost equivalent to the one with a
>   // broadcast variable, but it's actually about 10 times slower.
>   def runWithInCollection(spark: SparkSession, df: DataFrame, allowedValues: Set[Int]): Unit = {
>     spark.time {
>       df.where(col("value").isInCollection(allowedValues)).runTestAggregation()
>     }
>   }
>
>   // A bit slower than the one with a broadcast variable.
>   def runWithBroadcastDF(spark: SparkSession, df: DataFrame, allowedValues: Set[Int]): Unit = {
>     import spark.implicits._
>
>     val allowedValuesDF = allowedValues.toSeq.toDF("allowedValue")
>
>     spark.time {
>       df.join(broadcast(allowedValuesDF), col("value") === col("allowedValue")).runTestAggregation()
>     }
>   }
>
>   // This is actually the fastest one.
>   def runWithBroadcastVariable(spark: SparkSession, df: DataFrame, allowedValues: Set[Int]): Unit = {
>     val allowedValuesBroadcast = spark.sparkContext.broadcast(allowedValues)
>
>     spark.time {
>       df.filter(row => allowedValuesBroadcast.value.contains(row.getInt(0))).runTestAggregation()
>     }
>   }
>
>   // Runs a simple count aggregation to force evaluation of the filter.
>   implicit class TestRunner(val df: DataFrame) {
>     def runTestAggregation(): Unit = {
>       df.agg(count("value")).show()
>     }
>   }
> }
> {code}