Alain Bryden created SPARK-36844:
------------------------------------

             Summary: "first" Window function is significantly slower than 
"last" in identical circumstances
                 Key: SPARK-36844
                 URL: https://issues.apache.org/jira/browse/SPARK-36844
             Project: Spark
          Issue Type: Bug
          Components: PySpark, Windows
    Affects Versions: 3.1.1
            Reporter: Alain Bryden


I originally posted a question on SO because I thought perhaps I was doing 
something wrong:

[https://stackoverflow.com/questions/69308560|https://stackoverflow.com/questions/69308560/spark-first-window-function-is-taking-much-longer-than-last?noredirect=1#comment122505685_69308560]

Perhaps I am, but I'm now fairly convinced that there's something wonky with
the implementation of {{first}} that unnecessarily gives it much worse
complexity than {{last}}.

 

More or less copy-pasted from the above post:

I was working on a PySpark routine to interpolate the missing values in a
configuration table.

Imagine a table of configuration values that go from 0 to 50,000. The user
specifies a few data points in between (say at 0, 50, 100, 500, 2000, 500000)
and we interpolate the remainder. My solution follows
[this blog post|https://walkenho.github.io/interpolating-time-series-p2-spark/]
quite closely, except that I'm not using any UDFs.
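To make the interpolation step concrete, here is a minimal pure-Python sketch of the same idea on a single partition, with no Spark involved. The helper name {{interpolate_config}} is hypothetical, and it assumes the user-supplied points include both the first and the last rank, so that every rank has a previous and a next known point.

{code:python}
from bisect import bisect_left
from typing import Dict, List


def interpolate_config(known: Dict[int, float], max_rank: int) -> List[float]:
    """Hypothetical helper: fill ranks 0..max_rank by linear interpolation.

    Assumes `known` contains rank 0 and rank `max_rank`, so every rank
    has a previous and a next user-supplied point.
    """
    ranks = sorted(known)
    values = []
    for x in range(max_rank + 1):
        i = bisect_left(ranks, x)  # index of the first known rank >= x
        if ranks[i] == x:
            values.append(known[x])  # user-supplied point: keep as-is
            continue
        x1, x2 = ranks[i - 1], ranks[i]  # previous and next known ranks
        y1, y2 = known[x1], known[x2]
        # y = y1 + ((y2 - y1) / (x2 - x1)) * (x - x1)
        values.append(y1 + (y2 - y1) / (x2 - x1) * (x - x1))
    return values


# Known points at 0, 50 and 100; everything in between is interpolated.
filled = interpolate_config({0: 1.0, 50: 2.0, 100: 4.0}, 100)
{code}

With known points at 0, 50 and 100, rank 25 comes out at 1.5 and rank 75 at 3.0 (up to floating-point rounding), while the user-supplied ranks keep their original values.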

While troubleshooting the performance of this routine (it takes ~3 minutes), I
found that one particular window function accounts for almost all of the time;
everything else I'm doing takes mere seconds.

Here is the main area of interest, where I use window functions to fill in the
previous and next user-supplied configuration values:
{code:python}
from pyspark.sql import Window, functions as F

# Create partition windows that are required to generate new rows from the ones provided.
# win_last: from the start of the partition up to the current row.
win_last = (Window.partitionBy('PORT_TYPE', 'loss_process')
  .orderBy('rank').rowsBetween(Window.unboundedPreceding, 0))
# win_next: from the current row to the end of the partition.
win_next = (Window.partitionBy('PORT_TYPE', 'loss_process')
  .orderBy('rank').rowsBetween(0, Window.unboundedFollowing))

# Join back in the provided config table to populate the "known" scale factors
df_part1 = (df_scale_factors_template
  .join(df_users_config, ['PORT_TYPE', 'loss_process', 'rank'], 'leftouter')
  # Add computed columns that can look up the prior config and next config for each missing value
  .withColumn('last_rank', F.last(F.col('rank'), ignorenulls=True).over(win_last))
  .withColumn('last_sf',   F.last(F.col('scale_factor'), ignorenulls=True).over(win_last))
).cache()
# debug_log_dataframe is a local helper (not shown); it forces a .count() and times Part 1
debug_log_dataframe(df_part1, 'df_part1')

df_part2 = (df_part1
  .withColumn('next_rank', F.first(F.col('rank'), ignorenulls=True).over(win_next))
  .withColumn('next_sf',   F.first(F.col('scale_factor'), ignorenulls=True).over(win_next))
).cache()
debug_log_dataframe(df_part2, 'df_part2')  # Force a .count() and time Part 2

df_part3 = (df_part2
  # Implements standard linear interpolation: y = y1 + ((y2-y1)/(x2-x1)) * (x-x1)
  .withColumn('scale_factor',
              F.when(F.col('last_rank') == F.col('next_rank'), F.col('last_sf'))  # Handle div/0 case
              .otherwise(F.col('last_sf')
                         + ((F.col('next_sf') - F.col('last_sf'))
                            / (F.col('next_rank') - F.col('last_rank')))
                         * (F.col('rank') - F.col('last_rank'))))
  .select('PORT_TYPE', 'loss_process', 'rank', 'scale_factor')
).cache()
debug_log_dataframe(df_part3, 'df_part3', explain=True)
{code}
 

The above used to be a single chained dataframe statement, but I've since split 
it into 3 parts so that I could isolate the part that's taking so long. The 
results are:
 * {{Part 1: Generated 8 columns and 300006 rows in 0.65 seconds}}
 * {{Part 2: Generated 10 columns and 300006 rows in 189.55 seconds}}
 * {{Part 3: Generated 4 columns and 300006 rows in 0.24 seconds}}
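For reference, the two frames used in Part 1 and Part 2 compute a forward fill and a backward fill respectively. The pure-Python sketch below spells out those semantics on a single partition that is assumed to be already sorted by {{rank}}, with {{None}} standing in for SQL NULL. The function names are hypothetical and this is not Spark's implementation, only a statement of what each frame sees.

{code:python}
from typing import List, Optional

Value = Optional[float]


def last_over_growing_frame(values: List[Value]) -> List[Value]:
    """last(col, ignorenulls=True) over rowsBetween(unboundedPreceding, 0).

    Row i sees the frame values[0..i] and takes its last non-null value.
    """
    result = []
    for i in range(len(values)):
        frame = values[: i + 1]
        result.append(next((v for v in reversed(frame) if v is not None), None))
    return result


def first_over_shrinking_frame(values: List[Value]) -> List[Value]:
    """first(col, ignorenulls=True) over rowsBetween(0, unboundedFollowing).

    Row i sees the frame values[i..end] and takes its first non-null value.
    """
    result = []
    for i in range(len(values)):
        frame = values[i:]
        result.append(next((v for v in frame if v is not None), None))
    return result


scale_factor = [1.0, None, None, 4.0, None, 6.0]
print(last_over_growing_frame(scale_factor))     # [1.0, 1.0, 1.0, 4.0, 4.0, 6.0]  (like last_sf)
print(first_over_shrinking_frame(scale_factor))  # [1.0, 4.0, 4.0, 4.0, 6.0, 6.0]  (like next_sf)
{code}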

 

In trying various things to speed up my routine, it occurred to me to try
rewriting my usages of {{first()}} as usages of {{last()}} with a
reversed sort order.

So rewriting this:

{code:python}
win_next = (Window.partitionBy('PORT_TYPE', 'loss_process')
  .orderBy('rank').rowsBetween(0, Window.unboundedFollowing))

df_part2 = (df_part1
  .withColumn('next_rank', F.first(F.col('rank'), ignorenulls=True).over(win_next))
  .withColumn('next_sf',   F.first(F.col('scale_factor'), ignorenulls=True).over(win_next))
)
{code}
 
As this:

{code:python}
win_next = (Window.partitionBy('PORT_TYPE', 'loss_process')
  .orderBy(F.desc('rank')).rowsBetween(Window.unboundedPreceding, 0))

df_part2 = (df_part1
  .withColumn('next_rank', F.last(F.col('rank'), ignorenulls=True).over(win_next))
  .withColumn('next_sf',   F.last(F.col('scale_factor'), ignorenulls=True).over(win_next))
)
{code}
 
Much to my amazement, this actually solved the performance problem, and now the 
entire dataframe is generated in just 3 seconds.
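The rewrite gives the same answers because, for every row, taking the first non-null value from the current row onward in ascending {{rank}} order is the same as taking the last non-null value up to the current row in descending order. A small pure-Python check of that equivalence on one partition follows. It assumes {{rank}} is unique within each partition (otherwise ties could be ordered differently under the two sorts), and the function names are hypothetical.

{code:python}
import random
from typing import List, Optional

Value = Optional[float]


def next_via_first(values: List[Value]) -> List[Value]:
    """Original form: ascending order, frame = current row .. unboundedFollowing, first()."""
    return [next((v for v in values[i:] if v is not None), None) for i in range(len(values))]


def next_via_reversed_last(values: List[Value]) -> List[Value]:
    """Rewritten form: descending order, frame = unboundedPreceding .. current row, last()."""
    desc = values[::-1]  # rows as sorted by desc('rank')
    per_row = [next((v for v in reversed(desc[: j + 1]) if v is not None), None)
               for j in range(len(desc))]
    return per_row[::-1]  # map results back to ascending row order


# Randomised check on partitions with roughly 70% missing values.
for _ in range(200):
    values = [random.random() if random.random() < 0.3 else None
              for _ in range(random.randint(0, 30))]
    assert next_via_first(values) == next_via_reversed_last(values)
{code}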

I don't know anything about the internals, but conceptually I feel as though
the initial solution should be faster, because all 4 columns should be able to
take advantage of the same window and sort order by merely looking forwards or
backwards along the window; re-sorting like this shouldn't be necessary.
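One plausible explanation for the gap (an assumption, not something confirmed in this report) lies in how the two frames are maintained. A frame from {{unboundedPreceding}} to the current row only grows, so an aggregate can be updated incrementally with one new row per step. A frame from the current row to {{unboundedFollowing}} shrinks, and an aggregate buffer cannot easily drop the row that leaves it, so an engine may recompute the aggregate over the whole remaining frame for every row. That costs n(n+1)/2 row visits for a partition of n rows instead of n. The hypothetical sketch below counts row visits under those two strategies; it is a cost model, not Spark's code.

{code:python}
from typing import List, Optional, Tuple

Value = Optional[float]


def growing_frame_last(values: List[Value]) -> Tuple[List[Value], int]:
    """Incremental strategy: each row is added to a running aggregate once."""
    result, latest, visits = [], None, 0
    for v in values:
        visits += 1
        if v is not None:
            latest = v
        result.append(latest)
    return result, visits


def shrinking_frame_first(values: List[Value]) -> Tuple[List[Value], int]:
    """Recompute strategy: rescan the whole remaining frame for every row.

    Every row of the frame is fed to the aggregate (no early exit),
    as a generic aggregate would be.
    """
    result, visits = [], 0
    for i in range(len(values)):
        first = None
        for v in values[i:]:
            visits += 1
            if first is None and v is not None:
                first = v
        result.append(first)
    return result, visits


n = 1_000
values = [float(k) if k % 500 == 0 else None for k in range(n)]
print(growing_frame_last(values)[1])     # 1000 row visits
print(shrinking_frame_first(values)[1])  # 500500 row visits, i.e. n * (n + 1) // 2
{code}

Under this model, sorting by {{desc('rank')}} turns the shrinking frame into a growing one, which would be consistent with the speed-up observed above.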



--
This message was sent by Atlassian Jira
(v8.3.4#803005)
