Hello! I have a dataset that I'm trying to process in PySpark. The data (stored on disk as Parquet) contains user IDs, session IDs, and metadata related to each session. I'm adding a number of columns to my dataframe that are the result of aggregating over a window. The issue I'm running into is that all but 4-6 executors finish quickly, while the rest run forever without completing. My code sample is below this message.
In my logs, I see this over and over:

```
INFO ExternalAppendOnlyUnsafeRowArray: Reached spill threshold of 4096 rows, switching to org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
INFO UnsafeExternalSorter: Thread 92 spilling sort data of 9.2 GB to disk (2 times so far)
INFO UnsafeExternalSorter: Thread 91 spilling sort data of 19.3 GB to disk (0 time so far)
```

This suggests that Spark can't hold all of the windowed data in memory. I tried increasing the internal settings `spark.sql.windowExec.buffer.in.memory.threshold` and `spark.sql.windowExec.buffer.spill.threshold`, which helped a little, but some executors still never complete.

I believe this is all caused by skew in the data. Grouping by both user_id and session_id, there are 5 groups with a count >= 10,000, 100 groups with a count between 1,000 and 10,000, and 150,000 groups with a count less than 1,000 (usually a count of exactly 1).

Thanks in advance!

Michael

Code:

```
import pyspark.sql.functions as f
from pyspark.sql.window import Window

# Rows where col_A is null or an empty string
empty_col_a_cond = (f.col("col_A").isNull()) | (f.col("col_A") == "")

# Window over each (user_id, session_id) pair, ordered by step
session_window = Window.partitionBy("user_id", "session_id") \
    .orderBy(f.col("step_id").asc())

output_df = (
    input_df
    .withColumn("col_A_val", f
                .when(empty_col_a_cond, f.lit("NA"))
                .otherwise(f.col("col_A")))
    # ... 10 more added columns replacing nulls/empty strings
    .repartition("user_id", "session_id")
    .withColumn("s_user_id", f.first("user_id", True).over(session_window))
    .withColumn("s_col_B", f.collect_list("col_B").over(session_window))
    .withColumn("s_col_C", f.min("col_C").over(session_window))
    .withColumn("s_col_D", f.max("col_D").over(session_window))
    # ... 16 more added columns aggregating over session_window
    .where(f.col("session_flag") == 1)
    .where(f.array_contains(f.col("s_col_B"), "some_val"))
)
```
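For reference, this is roughly how I raised the two internal thresholds mentioned above; the values here are illustrative, not necessarily the exact ones I used:

```
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Row-count thresholds used by the window operator's buffer
# (illustrative values only; both are internal Spark SQL configs)
spark.conf.set("spark.sql.windowExec.buffer.in.memory.threshold", 1000000)
spark.conf.set("spark.sql.windowExec.buffer.spill.threshold", 1000000)
```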
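And for completeness, a minimal sketch of how I arrived at the skew numbers above (not the exact query I ran), grouping by the same keys the window partitions on:

```
import pyspark.sql.functions as f

# Row count per (user_id, session_id), i.e. per window partition
group_sizes = input_df.groupBy("user_id", "session_id").count()

group_sizes.where(f.col("count") >= 10000).count()                              # 5 groups
group_sizes.where((f.col("count") >= 1000) & (f.col("count") < 10000)).count()  # 100 groups
group_sizes.where(f.col("count") < 1000).count()                                # 150,000 groups, mostly count == 1
```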