HyukjinKwon closed pull request #23432: [SPARK-25591][PySpark][SQL][BRANCH-2.3]
Avoid overwriting deserialized accumulator
URL: https://github.com/apache/spark/pull/23432
This is a PR merged from a forked repository.
Because GitHub hides the original diff once such a PR is merged, the diff
is reproduced below for the sake of provenance:
Because this is a foreign pull request (from a fork), the diff is supplied
below; GitHub would not otherwise display it:
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index bc0be07bfb36e..5d46b92e27b5c 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -110,10 +110,14 @@
def _deserialize_accumulator(aid, zero_value, accum_param):
from pyspark.accumulators import _accumulatorRegistry
- accum = Accumulator(aid, zero_value, accum_param)
- accum._deserialized = True
- _accumulatorRegistry[aid] = accum
- return accum
+ # If this certain accumulator was deserialized, don't overwrite it.
+ if aid in _accumulatorRegistry:
+ return _accumulatorRegistry[aid]
+ else:
+ accum = Accumulator(aid, zero_value, accum_param)
+ accum._deserialized = True
+ _accumulatorRegistry[aid] = accum
+ return accum
class Accumulator(object):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 3c5fc97c921bc..dc5ed198f4c50 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2932,6 +2932,31 @@ def test_create_dateframe_from_pandas_with_dst(self):
os.environ['TZ'] = orig_env_tz
time.tzset()
+ # SPARK-25591
+ def test_same_accumulator_in_udfs(self):
+ from pyspark.sql.functions import udf
+
+ data_schema = StructType([StructField("a", IntegerType(), True),
+ StructField("b", IntegerType(), True)])
+ data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
+
+ test_accum = self.sc.accumulator(0)
+
+ def first_udf(x):
+ test_accum.add(1)
+ return x
+
+ def second_udf(x):
+ test_accum.add(100)
+ return x
+
+ func_udf = udf(first_udf, IntegerType())
+ func_udf2 = udf(second_udf, IntegerType())
+ data = data.withColumn("out1", func_udf(data["a"]))
+ data = data.withColumn("out2", func_udf2(data["b"]))
+ data.collect()
+ self.assertEqual(test_accum.value, 101)
+
class HiveSparkSubmitTests(SparkSubmitTests):
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]