zhengruifeng commented on code in PR #36599:
URL: https://github.com/apache/spark/pull/36599#discussion_r880050644
##########
python/pyspark/pandas/series.py:
##########
@@ -6255,36 +6261,47 @@ def argmax(self) -> int:
--------
Consider dataset containing cereal calories
- >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0,
+ >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0,
'Unknown': np.nan,
... 'Cinnamon Toast Crunch': 120.0, 'Cocoa Puff':
110.0})
- >>> s # doctest: +SKIP
+ >>> s
Corn Flakes 100.0
Almond Delight 110.0
+ Unknown NaN
Cinnamon Toast Crunch 120.0
Cocoa Puff 110.0
dtype: float64
- >>> s.argmax() # doctest: +SKIP
- 2
+ >>> s.argmax()
+ 3
+
+ >>> s.argmax(skipna=False)
+ -1
"""
sdf = self._internal.spark_frame.select(self.spark.column,
NATURAL_ORDER_COLUMN_NAME)
+ seq_col_name = verify_temp_column_name(sdf,
"__distributed_sequence_column__")
+ sdf = InternalFrame.attach_distributed_sequence_column(
+ sdf,
+ seq_col_name,
+ )
+ scol = scol_for(sdf, self._internal.data_spark_column_names[0])
+
+ if skipna:
+ sdf = sdf.orderBy(scol.desc_nulls_last(),
NATURAL_ORDER_COLUMN_NAME)
+ else:
+ sdf = sdf.orderBy(scol.desc_nulls_first(),
NATURAL_ORDER_COLUMN_NAME)
+
max_value = sdf.select(
- F.max(scol_for(sdf, self._internal.data_spark_column_names[0])),
+ F.first(scol),
F.first(NATURAL_ORDER_COLUMN_NAME),
).head()
+
if max_value[1] is None:
raise ValueError("attempt to get argmax of an empty sequence")
elif max_value[0] is None:
return -1
- # We should remember the natural sequence started from 0
- seq_col_name = verify_temp_column_name(sdf,
"__distributed_sequence_column__")
- sdf = InternalFrame.attach_distributed_sequence_column(
- sdf.drop(NATURAL_ORDER_COLUMN_NAME), seq_col_name
- )
+
# If the maximum is achieved in multiple locations, the first row
position is returned.
- return sdf.filter(
- scol_for(sdf, self._internal.data_spark_column_names[0]) ==
max_value[0]
- ).head()[0]
+ return sdf.filter(scol == max_value[0]).head()[0]
Review Comment:
I tried applying `max_by` here, but found that it cannot guarantee the
documented behavior: `If the maximum is achieved in multiple locations, the
first row position is returned.`
Let's keep the current code for now; I'll take another look at `max_by` later.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]