Repository: spark Updated Branches: refs/heads/master 5f4deff19 -> 566321852
[SPARK-23691][PYTHON] Use sql_conf util in PySpark tests where possible ## What changes were proposed in this pull request? https://github.com/apache/spark/commit/d6632d185e147fcbe6724545488ad80dce20277e added a useful util ```python @contextmanager def sql_conf(self, pairs): ... ``` to allow configuration set/unset within a block: ```python with self.sql_conf({"spark.blah.blah.blah": "blah"}): # test code ``` This PR proposes to use this util where possible in PySpark tests. Note that there already appear to be a few places in the unittest classes that affect other tests by not restoring the original configuration value. ## How was this patch tested? Manually tested via: ``` ./run-tests --modules=pyspark-sql --python-executables=python2 ./run-tests --modules=pyspark-sql --python-executables=python3 ``` Author: hyukjinkwon <[email protected]> Closes #20830 from HyukjinKwon/cleanup-sql-conf. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/56632185 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/56632185 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/56632185 Branch: refs/heads/master Commit: 566321852b2d60641fe86acbc8914b4a7063b58e Parents: 5f4deff Author: hyukjinkwon <[email protected]> Authored: Mon Mar 19 21:25:37 2018 -0700 Committer: Bryan Cutler <[email protected]> Committed: Mon Mar 19 21:25:37 2018 -0700 ---------------------------------------------------------------------- python/pyspark/sql/tests.py | 130 +++++++++++++++------------------------ 1 file changed, 50 insertions(+), 80 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/56632185/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a0d547a..39d6c52 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2461,17 
+2461,13 @@ class SQLTests(ReusedSQLTestCase): df1 = self.spark.range(1).toDF("a") df2 = self.spark.range(1).toDF("b") - try: - self.spark.conf.set("spark.sql.crossJoin.enabled", "false") + with self.sql_conf({"spark.sql.crossJoin.enabled": False}): self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect()) - self.spark.conf.set("spark.sql.crossJoin.enabled", "true") + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): actual = df1.join(df2, how="inner").collect() expected = [Row(a=0, b=0)] self.assertEqual(actual, expected) - finally: - # We should unset this. Otherwise, other tests are affected. - self.spark.conf.unset("spark.sql.crossJoin.enabled") # Regression test for invalid join methods when on is None, Spark-14761 def test_invalid_join_method(self): @@ -2943,21 +2939,18 @@ class SQLTests(ReusedSQLTestCase): self.assertPandasEqual(pdf, df.toPandas()) orig_env_tz = os.environ.get('TZ', None) - orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone') try: tz = 'America/Los_Angeles' os.environ['TZ'] = tz time.tzset() - self.spark.conf.set('spark.sql.session.timeZone', tz) - - df = self.spark.createDataFrame(pdf) - self.assertPandasEqual(pdf, df.toPandas()) + with self.sql_conf({'spark.sql.session.timeZone': tz}): + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) finally: del os.environ['TZ'] if orig_env_tz is not None: os.environ['TZ'] = orig_env_tz time.tzset() - self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz) class HiveSparkSubmitTests(SparkSubmitTests): @@ -3562,12 +3555,11 @@ class ArrowTests(ReusedSQLTestCase): self.assertTrue(all([c == 1 for c in null_counts])) def _toPandas_arrow_toggle(self, df): - self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - try: + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): pdf = df.toPandas() - finally: - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + pdf_arrow = 
df.toPandas() + return pdf, pdf_arrow def test_toPandas_arrow_toggle(self): @@ -3579,16 +3571,17 @@ class ArrowTests(ReusedSQLTestCase): def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf_arrow_la, pdf_la) - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf_arrow_la, pdf_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) self.assertPandasEqual(pdf_arrow_ny, pdf_ny) @@ -3601,8 +3594,6 @@ class ArrowTests(ReusedSQLTestCase): pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( pdf_la_corrected[field.name], timezone) self.assertPandasEqual(pdf_ny, pdf_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() @@ -3618,12 +3609,11 @@ class ArrowTests(ReusedSQLTestCase): self.assertTrue(pdf.empty) def _createDataFrame_toggle(self, pdf, schema=None): - self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - try: + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): df_no_arrow = self.spark.createDataFrame(pdf, schema=schema) - finally: - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + df_arrow 
= self.spark.createDataFrame(pdf, schema=schema) + return df_no_arrow, df_arrow def test_createDataFrame_toggle(self): @@ -3634,18 +3624,18 @@ class ArrowTests(ReusedSQLTestCase): def test_createDataFrame_respect_session_timezone(self): from datetime import timedelta pdf = self.create_pandas_data_frame() - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) - result_la = df_no_arrow_la.collect() - result_arrow_la = df_arrow_la.collect() - self.assertEqual(result_la, result_arrow_la) - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) + result_la = df_no_arrow_la.collect() + result_arrow_la = df_arrow_la.collect() + self.assertEqual(result_la, result_arrow_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema) result_ny = df_no_arrow_ny.collect() result_arrow_ny = df_arrow_ny.collect() @@ -3658,8 +3648,6 @@ class ArrowTests(ReusedSQLTestCase): for k, v in row.asDict().items()}) for row in result_la] self.assertEqual(result_ny, result_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_createDataFrame_with_schema(self): pdf = self.create_pandas_data_frame() @@ -4336,9 +4324,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): def test_vectorized_udf_check_config(self): from pyspark.sql.functions import 
pandas_udf, col import pandas as pd - orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None) - self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3) - try: + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}): df = self.spark.range(10, numPartitions=1) @pandas_udf(returnType=LongType()) @@ -4348,11 +4334,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): result = df.select(check_records_per_batch(col("id"))).collect() for (r,) in result: self.assertTrue(r <= 3) - finally: - if orig_value is None: - self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") - else: - self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value) def test_vectorized_udf_timestamps_respect_session_timezone(self): from pyspark.sql.functions import pandas_udf, col @@ -4371,30 +4352,27 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): internal_value = pandas_udf( lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType()) - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ - .withColumn("internal_value", internal_value(col("timestamp"))) - result_la = df_la.select(col("idx"), col("internal_value")).collect() - # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - diff = 3 * 60 * 60 * 1000 * 1000 * 1000 - result_la_corrected = \ - df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_la 
= df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ + .withColumn("internal_value", internal_value(col("timestamp"))) + result_la = df_la.select(col("idx"), col("internal_value")).collect() + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York + diff = 3 * 60 * 60 * 1000 * 1000 * 1000 + result_la_corrected = \ + df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ .withColumn("internal_value", internal_value(col("timestamp"))) result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect() self.assertNotEqual(result_ny, result_la) self.assertEqual(result_ny, result_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_nondeterministic_vectorized_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations @@ -5170,9 +5148,7 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase): def test_retain_group_columns(self): from pyspark.sql.functions import sum, lit, col - orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None) - self.spark.conf.set("spark.sql.retainGroupColumns", False) - try: + with self.sql_conf({"spark.sql.retainGroupColumns": False}): df = self.data sum_udf = self.pandas_agg_sum_udf @@ -5180,12 +5156,6 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase): expected1 = df.groupby(df.id).agg(sum(df.v)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - finally: - if orig_value is None: - self.spark.conf.unset("spark.sql.retainGroupColumns") - else: - self.spark.conf.set("spark.sql.retainGroupColumns", orig_value) - def test_invalid_args(self): from pyspark.sql.functions import mean --------------------------------------------------------------------- To unsubscribe, 
e-mail: [email protected] For additional commands, e-mail: [email protected]
