This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 31e7c37 [SPARK-28185][PYTHON][SQL] Closes the generator when Python UDFs stop early
31e7c37 is described below
commit 31e7c37354132545da59bff176af1613bd09447c
Author: WeichenXu <[email protected]>
AuthorDate: Fri Jun 28 17:10:25 2019 +0900
[SPARK-28185][PYTHON][SQL] Closes the generator when Python UDFs stop early
## What changes were proposed in this pull request?
Closes the generator when Python UDFs stop early.
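The fix relies on standard Python generator semantics: calling `close()` on a suspended generator raises `GeneratorExit` at the paused `yield`, giving the generator's `except`/`finally` blocks a chance to run and release resources. A minimal plain-Python sketch of that mechanism (independent of Spark; the names `gen` and `g` are invented for illustration):
```python
# Plain-Python sketch: generator.close() raises GeneratorExit at the
# suspended yield, so the generator's finally block runs.
def gen():
    try:
        for i in range(10):
            yield i
    finally:
        print("cleanup ran")  # resource release happens here

g = gen()
print(next(g))  # consume only the first element, then stop early
g.close()       # raises GeneratorExit inside gen(); prints "cleanup ran"
```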
### Manual verification on pandas iterator UDF and mapPartitions
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.functions import col, udf
from pyspark.taskcontext import TaskContext
import time
import os
spark.conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', '1')
spark.conf.set('spark.sql.pandas.udf.buffer.size', '4')
@pandas_udf("int", PandasUDFType.SCALAR_ITER)
def fi1(it):
    try:
        for batch in it:
            yield batch + 100
            time.sleep(1.0)
    except BaseException as be:
        print("Debug: exception raised: " + str(type(be)))
        raise be
    finally:
        open("/tmp/000001.tmp", "a").close()
df1 = spark.range(10).select(col('id').alias('a')).repartition(1)
# will see log Debug: exception raised: <class 'GeneratorExit'>
# and file "/tmp/000001.tmp" generated.
df1.select(col('a'), fi1('a')).limit(2).collect()
def mapper(it):
    try:
        for batch in it:
            yield batch
    except BaseException as be:
        print("Debug: exception raised: " + str(type(be)))
        raise be
    finally:
        open("/tmp/000002.tmp", "a").close()
df2 = spark.range(10000000).repartition(1)
# will see log Debug: exception raised: <class 'GeneratorExit'>
# and file "/tmp/000002.tmp" generated.
df2.rdd.mapPartitions(mapper).take(2)
```
## How was this patch tested?
Unit test added.
Please review https://spark.apache.org/contributing.html before opening a pull request.
Closes #24986 from WeichenXu123/pandas_iter_udf_limit.
Authored-by: WeichenXu <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
---
 python/pyspark/sql/tests/test_pandas_udf_scalar.py | 37 ++++++++++++++++++++++
 python/pyspark/worker.py                           |  7 +++-
 2 files changed, 43 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index c291d42..d254508 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -850,6 +850,43 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         with self.assertRaisesRegexp(Exception, "reached finally block"):
             self.spark.range(1).select(test_close(col("id"))).collect()
 
+    def test_scalar_iter_udf_close_early(self):
+        tmp_dir = tempfile.mkdtemp()
+        try:
+            tmp_file = tmp_dir + '/reach_finally_block'
+
+            @pandas_udf('int', PandasUDFType.SCALAR_ITER)
+            def test_close(batch_iter):
+                generator_exit_caught = False
+                try:
+                    for batch in batch_iter:
+                        yield batch
+                        time.sleep(1.0)  # avoid the function finishing too fast.
+                except GeneratorExit as ge:
+                    generator_exit_caught = True
+                    raise ge
+                finally:
+                    assert generator_exit_caught, "Generator exit exception was not caught."
+                    open(tmp_file, 'a').close()
+
+            with QuietTest(self.sc):
+                with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 1,
+                                    "spark.sql.pandas.udf.buffer.size": 4}):
+                    self.spark.range(10).repartition(1) \
+                        .select(test_close(col("id"))).limit(2).collect()
+                    # wait here because the python udf worker takes some time to detect
+                    # that the jvm side socket is closed, which triggers `GeneratorExit`.
+                    # wait timeout is 10s.
+                    for i in range(100):
+                        time.sleep(0.1)
+                        if os.path.exists(tmp_file):
+                            break
+
+                    assert os.path.exists(tmp_file), "finally block not reached."
+
+        finally:
+            shutil.rmtree(tmp_dir)
+
     # Regression test for SPARK-23314
     def test_timestamp_dst(self):
         # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index ee46bb6..04376c9 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -481,7 +481,12 @@ def main(infile, outfile):
 
         def process():
             iterator = deserializer.load_stream(infile)
-            serializer.dump_stream(func(split_index, iterator), outfile)
+            out_iter = func(split_index, iterator)
+            try:
+                serializer.dump_stream(out_iter, outfile)
+            finally:
+                if hasattr(out_iter, 'close'):
+                    out_iter.close()
 
         if profiler:
             profiler.profile(process)
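For illustration, a hypothetical standalone sketch of the guarded-close pattern above (not Spark code; `run` and the lambdas are invented for the example). The `hasattr` guard matters because `func` may return a plain list, which has no `close()`, while generator output does:
```python
# Hypothetical sketch of the guarded-close pattern; not Spark code.
def run(func, iterator):
    out_iter = func(iterator)
    try:
        for item in out_iter:  # stand-in for serializer.dump_stream
            print(item)
    finally:
        # generators define close(); plain lists do not, hence the guard
        if hasattr(out_iter, 'close'):
            out_iter.close()

run(lambda it: (x * 2 for x in it), iter([1, 2, 3]))  # generator: close() called
run(lambda it: [x * 2 for x in it], iter([1, 2, 3]))  # list: guard skips close()
```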
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]