HyukjinKwon closed pull request #23433: [SPARK-25591][PySpark][SQL][BRANCH-2.2]
Avoid overwriting deserialized accumulator
URL: https://github.com/apache/spark/pull/23433
This is a PR merged from a forked repository.
As GitHub hides the original diff of a foreign (forked) pull request
once it is merged, the diff is displayed below for the sake of
provenance:
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index bc0be07bfb36e..5d46b92e27b5c 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -110,10 +110,14 @@
def _deserialize_accumulator(aid, zero_value, accum_param):
from pyspark.accumulators import _accumulatorRegistry
- accum = Accumulator(aid, zero_value, accum_param)
- accum._deserialized = True
- _accumulatorRegistry[aid] = accum
- return accum
+ # If this certain accumulator was deserialized, don't overwrite it.
+ if aid in _accumulatorRegistry:
+ return _accumulatorRegistry[aid]
+ else:
+ accum = Accumulator(aid, zero_value, accum_param)
+ accum._deserialized = True
+ _accumulatorRegistry[aid] = accum
+ return accum
class Accumulator(object):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0926112baecfe..083bb199b5979 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2220,6 +2220,31 @@ def
test_create_dataframe_from_pandas_with_timestamp(self):
self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
+ # SPARK-25591
+ def test_same_accumulator_in_udfs(self):
+ from pyspark.sql.functions import udf
+
+ data_schema = StructType([StructField("a", IntegerType(), True),
+ StructField("b", IntegerType(), True)])
+ data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
+
+ test_accum = self.sc.accumulator(0)
+
+ def first_udf(x):
+ test_accum.add(1)
+ return x
+
+ def second_udf(x):
+ test_accum.add(100)
+ return x
+
+ func_udf = udf(first_udf, IntegerType())
+ func_udf2 = udf(second_udf, IntegerType())
+ data = data.withColumn("out1", func_udf(data["a"]))
+ data = data.withColumn("out2", func_udf2(data["b"]))
+ data.collect()
+ self.assertEqual(test_accum.value, 101)
+
class HiveSparkSubmitTests(SparkSubmitTests):
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]