Repository: spark Updated Branches: refs/heads/master 33d43bf1b -> 64817c423
[SPARK-22395][SQL][PYTHON] Fix the behavior of timestamp values for Pandas to respect session timezone ## What changes were proposed in this pull request? When converting a Pandas DataFrame/Series from/to a Spark DataFrame using `toPandas()` or pandas UDFs, timestamp values respect the Python system timezone instead of the session timezone. For example, let's say we use `"America/Los_Angeles"` as the session timezone and have a timestamp value `"1970-01-01 00:00:01"` in that timezone. By the way, I'm in Japan, so the Python timezone would be `"Asia/Tokyo"`. The timestamp value from the current `toPandas()` is the following: ``` >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([28801], "long").selectExpr("timestamp(value) as ts") >>> df.show() +-------------------+ | ts| +-------------------+ |1970-01-01 00:00:01| +-------------------+ >>> df.toPandas() ts 0 1970-01-01 17:00:01 ``` As you can see, the value becomes `"1970-01-01 17:00:01"` because it respects the Python timezone. As we discussed in #18664, we consider this behavior a bug, and the value should be `"1970-01-01 00:00:01"`. ## How was this patch tested? Added tests and existing tests. Author: Takuya UESHIN <ues...@databricks.com> Closes #19607 from ueshin/issues/SPARK-22395. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/64817c42 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/64817c42 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/64817c42 Branch: refs/heads/master Commit: 64817c423c0d82a805abd69a3e166e5bfd79c739 Parents: 33d43bf Author: Takuya UESHIN <ues...@databricks.com> Authored: Tue Nov 28 16:45:22 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Tue Nov 28 16:45:22 2017 +0800 ---------------------------------------------------------------------- docs/sql-programming-guide.md | 2 + python/pyspark/serializers.py | 13 +- python/pyspark/sql/dataframe.py | 24 ++- python/pyspark/sql/session.py | 52 ++++- python/pyspark/sql/tests.py | 214 +++++++++++++++++-- python/pyspark/sql/types.py | 87 +++++++- python/pyspark/worker.py | 3 +- python/setup.py | 2 +- .../org/apache/spark/sql/internal/SQLConf.scala | 11 + .../execution/python/ArrowEvalPythonExec.scala | 4 +- .../execution/python/ArrowPythonRunner.scala | 8 +- .../python/FlatMapGroupsInPandasExec.scala | 4 +- 12 files changed, 371 insertions(+), 53 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/64817c42/docs/sql-programming-guide.md ---------------------------------------------------------------------- diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 5f98213..983770d 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1716,6 +1716,8 @@ options. </table> Note that, for <b>DecimalType(38,0)*</b>, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. + - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. + - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. 
If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. ## Upgrading From Spark SQL 2.1 to 2.2 http://git-wip-us.apache.org/repos/asf/spark/blob/64817c42/python/pyspark/serializers.py ---------------------------------------------------------------------- diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index b95de2c..37e7cf3 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -206,11 +206,12 @@ class ArrowSerializer(FramedSerializer): return "ArrowSerializer" -def _create_batch(series): +def _create_batch(series, timezone): """ Create an Arrow record batch from the given pandas.Series or list of Series, with optional type. :param series: A single pandas.Series, list of Series, or list of (series, arrow_type) + :param timezone: A timezone to respect when handling timestamp values :return: Arrow RecordBatch """ @@ -227,7 +228,7 @@ def _create_batch(series): def cast_series(s, t): if type(t) == pa.TimestampType: # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680 - return _check_series_convert_timestamps_internal(s.fillna(0))\ + return _check_series_convert_timestamps_internal(s.fillna(0), timezone)\ .values.astype('datetime64[us]', copy=False) # NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1 elif t is not None and t == pa.date32(): @@ -253,6 +254,10 @@ class ArrowStreamPandasSerializer(Serializer): Serializes Pandas.Series as Arrow data with Arrow streaming format. """ + def __init__(self, timezone): + super(ArrowStreamPandasSerializer, self).__init__() + self._timezone = timezone + def dump_stream(self, iterator, stream): """ Make ArrowRecordBatches from Pandas Series and serialize. 
Input is a single series or @@ -262,7 +267,7 @@ class ArrowStreamPandasSerializer(Serializer): writer = None try: for series in iterator: - batch = _create_batch(series) + batch = _create_batch(series, self._timezone) if writer is None: write_int(SpecialLengths.START_ARROW_STREAM, stream) writer = pa.RecordBatchStreamWriter(stream, batch.schema) @@ -280,7 +285,7 @@ class ArrowStreamPandasSerializer(Serializer): reader = pa.open_stream(stream) for batch in reader: # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1 - pdf = _check_dataframe_localize_timestamps(batch.to_pandas()) + pdf = _check_dataframe_localize_timestamps(batch.to_pandas(), self._timezone) yield [c for _, c in pdf.iteritems()] def __repr__(self): http://git-wip-us.apache.org/repos/asf/spark/blob/64817c42/python/pyspark/sql/dataframe.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 406686e..9864dc9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -39,6 +39,7 @@ from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.streaming import DataStreamWriter from pyspark.sql.types import IntegralType from pyspark.sql.types import * +from pyspark.util import _exception_message __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] @@ -1881,6 +1882,13 @@ class DataFrame(object): 1 5 Bob """ import pandas as pd + + if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ + == "true": + timezone = self.sql_ctx.getConf("spark.sql.session.timeZone") + else: + timezone = None + if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: from pyspark.sql.types import _check_dataframe_localize_timestamps @@ -1889,13 +1897,13 @@ class DataFrame(object): if tables: table = pyarrow.concat_tables(tables) pdf = table.to_pandas() - return _check_dataframe_localize_timestamps(pdf) + return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) except ImportError as e: msg = "note: pyarrow must be installed and available on calling Python process " \ "if using spark.sql.execution.arrow.enabled=true" - raise ImportError("%s\n%s" % (e.message, msg)) + raise ImportError("%s\n%s" % (_exception_message(e), msg)) else: pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) @@ -1913,7 +1921,17 @@ class DataFrame(object): for f, t in dtype.items(): pdf[f] = pdf[f].astype(t, copy=False) - return pdf + + if timezone is None: + return pdf + else: + from pyspark.sql.types import _check_series_convert_timestamps_local_tz + for field in self.schema: + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? 
+ if isinstance(field.dataType, TimestampType): + pdf[field.name] = \ + _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) + return pdf def _collectAsArrow(self): """ http://git-wip-us.apache.org/repos/asf/spark/blob/64817c42/python/pyspark/sql/session.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 47c58bb..e2435e0 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -34,8 +34,9 @@ from pyspark.sql.conf import RuntimeConfig from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader -from pyspark.sql.types import Row, DataType, StringType, StructType, _make_type_verifier, \ - _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string +from pyspark.sql.types import Row, DataType, StringType, StructType, TimestampType, \ + _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \ + _parse_datatype_string from pyspark.sql.utils import install_exception_handler __all__ = ["SparkSession"] @@ -444,11 +445,34 @@ class SparkSession(object): record_type_list.append((str(col_names[i]), curr_type)) return np.dtype(record_type_list) if has_rec_fix else None - def _convert_from_pandas(self, pdf): + def _convert_from_pandas(self, pdf, schema, timezone): """ Convert a pandas.DataFrame to list of records that can be used to make a DataFrame :return list of records """ + if timezone is not None: + from pyspark.sql.types import _check_series_convert_timestamps_tz_local + copied = False + if isinstance(schema, StructType): + for field in schema: + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if isinstance(field.dataType, TimestampType): + s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone) + if not copied and s is not pdf[field.name]: + # Copy once if the series is modified to prevent the original Pandas + # DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[field.name] = s + else: + for column, series in pdf.iteritems(): + s = _check_series_convert_timestamps_tz_local(pdf[column], timezone) + if not copied and s is not pdf[column]: + # Copy once if the series is modified to prevent the original Pandas + # DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[column] = s # Convert pandas.DataFrame to list of numpy records np_records = pdf.to_records(index=False) @@ -462,15 +486,19 @@ class SparkSession(object): # Convert list of numpy records to python lists return [r.tolist() for r in np_records] - def _create_from_pandas_with_arrow(self, pdf, schema): + def _create_from_pandas_with_arrow(self, pdf, schema, timezone): """ Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the data types will be used to coerce the data in Pandas to Arrow conversion. 
""" from pyspark.serializers import ArrowSerializer, _create_batch - from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType - from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + from pyspark.sql.types import from_arrow_schema, to_arrow_type, \ + _old_pandas_exception_message, TimestampType + try: + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + except ImportError as e: + raise ImportError(_old_pandas_exception_message(e)) # Determine arrow types to coerce data when creating batches if isinstance(schema, StructType): @@ -488,7 +516,8 @@ class SparkSession(object): pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step)) # Create Arrow record batches - batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]) + batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)], + timezone) for pdf_slice in pdf_slices] # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing) @@ -606,6 +635,11 @@ class SparkSession(object): except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): + if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ + == "true": + timezone = self.conf.get("spark.sql.session.timeZone") + else: + timezone = None # If no schema supplied by user then get the names of columns only if schema is None: @@ -614,11 +648,11 @@ class SparkSession(object): if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ and len(data) > 0: try: - return self._create_from_pandas_with_arrow(data, schema) + return self._create_from_pandas_with_arrow(data, schema, timezone) except Exception as e: warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e)) # Fallback to create DataFrame without arrow if raise some exception - data = self._convert_from_pandas(data) + data = self._convert_from_pandas(data, schema, timezone) if isinstance(schema, StructType): verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True http://git-wip-us.apache.org/repos/asf/spark/blob/64817c42/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 762afe0..b4d32d8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -49,9 +49,14 @@ else: import unittest _have_pandas = False +_have_old_pandas = False try: import pandas - _have_pandas = True + try: + import pandas.api + _have_pandas = True + except: + _have_old_pandas = True except: # No Pandas, but that's okay, we'll skip those tests pass @@ -2565,21 +2570,38 @@ class SQLTests(ReusedSQLTestCase): .mode("overwrite").saveAsTable("pyspark_bucket")) self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - @unittest.skipIf(not _have_pandas, "Pandas not installed") - def test_to_pandas(self): + def _to_pandas(self): + from datetime import datetime, date import numpy as np schema = StructType().add("a", IntegerType()).add("b", StringType())\ - .add("c", BooleanType()).add("d", FloatType()) + .add("c", BooleanType()).add("d", FloatType())\ + .add("dt", DateType()).add("ts", TimestampType()) data = [ - (1, "foo", True, 3.0), (2, "foo", True, 5.0), - (3, "bar", False, -1.0), (4, "bar", False, 6.0), + (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + (2, "foo", True, 5.0, None, None), + (3, 
"bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)), + (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)), ] df = self.spark.createDataFrame(data, schema) - types = df.toPandas().dtypes + return df.toPandas() + + @unittest.skipIf(not _have_pandas, "Pandas not installed") + def test_to_pandas(self): + import numpy as np + pdf = self._to_pandas() + types = pdf.dtypes self.assertEquals(types[0], np.int32) self.assertEquals(types[1], np.object) self.assertEquals(types[2], np.bool) self.assertEquals(types[3], np.float32) + self.assertEquals(types[4], 'datetime64[ns]') + self.assertEquals(types[5], 'datetime64[ns]') + + @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") + def test_to_pandas_old(self): + with QuietTest(self.sc): + with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'): + self._to_pandas() @unittest.skipIf(not _have_pandas, "Pandas not installed") def test_to_pandas_avoid_astype(self): @@ -2614,6 +2636,16 @@ class SQLTests(ReusedSQLTestCase): self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) + @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") + def test_create_dataframe_from_old_pandas(self): + import pandas as pd + from datetime import datetime + pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], + "d": [pd.Timestamp.now().date()]}) + with QuietTest(self.sc): + with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'): + self.spark.createDataFrame(pdf) + class HiveSparkSubmitTests(SparkSubmitTests): @@ -3103,7 +3135,7 @@ class DataTypeVerificationTests(unittest.TestCase): _make_type_verifier(data_type, nullable=False)(obj) -@unittest.skipIf(not _have_arrow, "Arrow not installed") +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class ArrowTests(ReusedSQLTestCase): @classmethod @@ -3169,16 +3201,47 @@ class ArrowTests(ReusedSQLTestCase): null_counts = pdf.isnull().sum().tolist() self.assertTrue(all([c == 1 for c in null_counts])) - def test_toPandas_arrow_toggle(self): - df = self.spark.createDataFrame(self.data, schema=self.schema) + def _toPandas_arrow_toggle(self, df): self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") try: pdf = df.toPandas() finally: self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") pdf_arrow = df.toPandas() + return pdf, pdf_arrow + + def test_toPandas_arrow_toggle(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) self.assertFramesEqual(pdf_arrow, pdf) + def test_toPandas_respect_session_timezone(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + orig_tz = self.spark.conf.get("spark.sql.session.timeZone") + try: + timezone = "America/New_York" + self.spark.conf.set("spark.sql.session.timeZone", timezone) + self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") + try: + pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) + self.assertFramesEqual(pdf_arrow_la, pdf_la) + finally: + self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) + self.assertFramesEqual(pdf_arrow_ny, pdf_ny) + + self.assertFalse(pdf_ny.equals(pdf_la)) + + from pyspark.sql.types import _check_series_convert_timestamps_local_tz + pdf_la_corrected = pdf_la.copy() + for field in self.schema: + if 
isinstance(field.dataType, TimestampType): + pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( + pdf_la_corrected[field.name], timezone) + self.assertFramesEqual(pdf_ny, pdf_la_corrected) + finally: + self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(self.data, schema=self.schema) @@ -3192,16 +3255,50 @@ class ArrowTests(ReusedSQLTestCase): self.assertEqual(pdf.columns[0], "i") self.assertTrue(pdf.empty) - def test_createDataFrame_toggle(self): - pdf = self.create_pandas_data_frame() + def _createDataFrame_toggle(self, pdf, schema=None): self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") try: - df_no_arrow = self.spark.createDataFrame(pdf) + df_no_arrow = self.spark.createDataFrame(pdf, schema=schema) finally: self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") - df_arrow = self.spark.createDataFrame(pdf) + df_arrow = self.spark.createDataFrame(pdf, schema=schema) + return df_no_arrow, df_arrow + + def test_createDataFrame_toggle(self): + pdf = self.create_pandas_data_frame() + df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema) self.assertEquals(df_no_arrow.collect(), df_arrow.collect()) + def test_createDataFrame_respect_session_timezone(self): + from datetime import timedelta + pdf = self.create_pandas_data_frame() + orig_tz = self.spark.conf.get("spark.sql.session.timeZone") + try: + timezone = "America/New_York" + self.spark.conf.set("spark.sql.session.timeZone", timezone) + self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") + try: + df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) + result_la = df_no_arrow_la.collect() + result_arrow_la = df_arrow_la.collect() + self.assertEqual(result_la, result_arrow_la) + finally: + self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema) + result_ny = df_no_arrow_ny.collect() + result_arrow_ny = df_arrow_ny.collect() + self.assertEqual(result_ny, result_arrow_ny) + + self.assertNotEqual(result_ny, result_la) + + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York + result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v + for k, v in row.asDict().items()}) + for row in result_la] + self.assertEqual(result_ny, result_la_corrected) + finally: + self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + def test_createDataFrame_with_schema(self): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(pdf, schema=self.schema) @@ -3385,6 +3482,27 @@ class PandasUDFTests(ReusedSQLTestCase): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class VectorizedUDFTests(ReusedSQLTestCase): + @classmethod + def setUpClass(cls): + ReusedSQLTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + + cls.sc.environment["TZ"] = tz + cls.spark.conf.set("spark.sql.session.timeZone", tz) + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + ReusedSQLTestCase.tearDownClass() + def test_vectorized_udf_basic(self): from pyspark.sql.functions 
import pandas_udf, col df = self.spark.range(10).select( @@ -3621,29 +3739,37 @@ class VectorizedUDFTests(ReusedSQLTestCase): data = [(0, datetime(1969, 1, 1, 1, 1, 1)), (1, datetime(2012, 2, 2, 2, 2, 2)), (2, None), - (3, datetime(2100, 4, 4, 4, 4, 4))] + (3, datetime(2100, 3, 3, 3, 3, 3))] + df = self.spark.createDataFrame(data, schema=schema) # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType()) df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp"))) - @pandas_udf(returnType=BooleanType()) + @pandas_udf(returnType=StringType()) def check_data(idx, timestamp, timestamp_copy): + import pandas as pd + msgs = [] is_equal = timestamp.isnull() # use this array to check values are equal for i in range(len(idx)): # Check that timestamps are as expected in the UDF - is_equal[i] = (is_equal[i] and data[idx[i]][1] is None) or \ - timestamp[i].to_pydatetime() == data[idx[i]][1] - return is_equal - - result = df.withColumn("is_equal", check_data(col("idx"), col("timestamp"), - col("timestamp_copy"))).collect() + if (is_equal[i] and data[idx[i]][1] is None) or \ + timestamp[i].to_pydatetime() == data[idx[i]][1]: + msgs.append(None) + else: + msgs.append( + "timestamp values are not equal (timestamp='%s': data[%d][1]='%s')" + % (timestamp[i], idx[i], data[idx[i]][1])) + return pd.Series(msgs) + + result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"), + col("timestamp_copy"))).collect() # Check that collection values are correct self.assertEquals(len(data), len(result)) for i in range(len(result)): self.assertEquals(data[i][1], result[i][1]) # "timestamp" col - self.assertTrue(result[i][3]) # "is_equal" data in udf was as expected + self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_return_timestamp_tz(self): from pyspark.sql.functions import pandas_udf, col @@ -3683,6 +3809,48 @@ class VectorizedUDFTests(ReusedSQLTestCase): else: self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value) + def test_vectorized_udf_timestamps_respect_session_timezone(self): + from pyspark.sql.functions import pandas_udf, col + from datetime import datetime + import pandas as pd + schema = StructType([ + StructField("idx", LongType(), True), + StructField("timestamp", TimestampType(), True)]) + data = [(1, datetime(1969, 1, 1, 1, 1, 1)), + (2, datetime(2012, 2, 2, 2, 2, 2)), + (3, None), + (4, datetime(2100, 3, 3, 3, 3, 3))] + df = self.spark.createDataFrame(data, schema=schema) + + f_timestamp_copy = pandas_udf(lambda ts: ts, TimestampType()) + internal_value = pandas_udf( + lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType()) + + orig_tz = self.spark.conf.get("spark.sql.session.timeZone") + try: + timezone = "America/New_York" + self.spark.conf.set("spark.sql.session.timeZone", timezone) + self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") + try: + df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ + .withColumn("internal_value", internal_value(col("timestamp"))) + result_la = df_la.select(col("idx"), col("internal_value")).collect() + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York + diff = 3 * 60 * 60 * 1000 * 1000 * 1000 + result_la_corrected = \ + df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() + finally: + self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", 
"true") + + df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ + .withColumn("internal_value", internal_value(col("timestamp"))) + result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect() + + self.assertNotEqual(result_ny, result_la) + self.assertEqual(result_ny, result_la_corrected) + finally: + self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): http://git-wip-us.apache.org/repos/asf/spark/blob/64817c42/python/pyspark/sql/types.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index fe62f60..78abc32 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -35,6 +35,7 @@ from py4j.java_gateway import JavaClass from pyspark import SparkContext from pyspark.serializers import CloudPickleSerializer +from pyspark.util import _exception_message __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", @@ -1678,37 +1679,105 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) -def _check_dataframe_localize_timestamps(pdf): +def _old_pandas_exception_message(e): + """ Create an error message for importing old Pandas. """ - Convert timezone aware timestamps to timezone-naive in local time + msg = "note: Pandas (>=0.19.2) must be installed and available on calling Python process" + return "%s\n%s" % (_exception_message(e), msg) + + +def _check_dataframe_localize_timestamps(pdf, timezone): + """ + Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone :param pdf: pandas.DataFrame - :return pandas.DataFrame where any timezone aware columns have be converted to tz-naive + :param timezone: the timezone to convert. if None then use local timezone + :return pandas.DataFrame where any timezone aware columns have been converted to tz-naive """ - from pandas.api.types import is_datetime64tz_dtype + try: + from pandas.api.types import is_datetime64tz_dtype + except ImportError as e: + raise ImportError(_old_pandas_exception_message(e)) + tz = timezone or 'tzlocal()' for column, series in pdf.iteritems(): # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64tz_dtype(series.dtype): - pdf[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None) + pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None) return pdf -def _check_series_convert_timestamps_internal(s): +def _check_series_convert_timestamps_internal(s, timezone): """ - Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage + Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for + Spark internal storage + :param s: a pandas.Series + :param timezone: the timezone to convert. if None then use local timezone :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone """ - from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + try: + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + except ImportError as e: + raise ImportError(_old_pandas_exception_message(e)) # TODO: handle nested timestamps, such as ArrayType(TimestampType())? 
if is_datetime64_dtype(s.dtype): - return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC') + tz = timezone or 'tzlocal()' + return s.dt.tz_localize(tz).dt.tz_convert('UTC') elif is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert('UTC') else: return s +def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): + """ + Convert timestamp to timezone-naive in the specified timezone or local timezone + + :param s: a pandas.Series + :param from_timezone: the timezone to convert from. if None then use local timezone + :param to_timezone: the timezone to convert to. if None then use local timezone + :return pandas.Series where if it is a timestamp, has been converted to tz-naive + """ + try: + import pandas as pd + from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype + except ImportError as e: + raise ImportError(_old_pandas_exception_message(e)) + from_tz = from_timezone or 'tzlocal()' + to_tz = to_timezone or 'tzlocal()' + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert(to_tz).dt.tz_localize(None) + elif is_datetime64_dtype(s.dtype) and from_tz != to_tz: + # `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT. + return s.apply(lambda ts: ts.tz_localize(from_tz).tz_convert(to_tz).tz_localize(None) + if ts is not pd.NaT else pd.NaT) + else: + return s + + +def _check_series_convert_timestamps_local_tz(s, timezone): + """ + Convert timestamp to timezone-naive in the specified timezone or local timezone + + :param s: a pandas.Series + :param timezone: the timezone to convert to. if None then use local timezone + :return pandas.Series where if it is a timestamp, has been converted to tz-naive + """ + return _check_series_convert_timestamps_localize(s, None, timezone) + + +def _check_series_convert_timestamps_tz_local(s, timezone): + """ + Convert timestamp to timezone-naive in the specified timezone or local timezone + + :param s: a pandas.Series + :param timezone: the timezone to convert from. 
if None then use local timezone + :return pandas.Series where if it is a timestamp, has been converted to tz-naive + """ + return _check_series_convert_timestamps_localize(s, timezone, None) + + def _test(): import doctest from pyspark.context import SparkContext http://git-wip-us.apache.org/repos/asf/spark/blob/64817c42/python/pyspark/worker.py ---------------------------------------------------------------------- diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 9396430..e6737ae 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -150,7 +150,8 @@ def read_udfs(pickleSer, infile, eval_type): if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: - ser = ArrowStreamPandasSerializer() + timezone = utf8_deserializer.loads(infile) + ser = ArrowStreamPandasSerializer(timezone) else: ser = BatchedSerializer(PickleSerializer(), 100) http://git-wip-us.apache.org/repos/asf/spark/blob/64817c42/python/setup.py ---------------------------------------------------------------------- diff --git a/python/setup.py b/python/setup.py index 02612ff..310670e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -201,7 +201,7 @@ try: extras_require={ 'ml': ['numpy>=1.7'], 'mllib': ['numpy>=1.7'], - 'sql': ['pandas>=0.13.0'] + 'sql': ['pandas>=0.19.2'] }, classifiers=[ 'Development Status :: 5 - Production/Stable', http://git-wip-us.apache.org/repos/asf/spark/blob/64817c42/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ce68dbb..8abb426 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -998,6 +998,15 @@ object SQLConf { .intConf .createWithDefault(10000) + val PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE = + buildConf("spark.sql.execution.pandas.respectSessionTimeZone") + .internal() + .doc("When true, make Pandas DataFrame with timestamp type respecting session local " + + "timezone when converting to/from Pandas DataFrame. 
This configuration will be " + + "deprecated in the future releases.") + .booleanConf + .createWithDefault(true) + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() .doc("When true, the apply function of the rule verifies whether the right node of the" + @@ -1316,6 +1325,8 @@ class SQLConf extends Serializable with Logging { def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) /** ********************** SQLConf functionality methods ************ */ http://git-wip-us.apache.org/repos/asf/spark/blob/64817c42/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index e272101..c06bc7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -63,6 +63,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone protected override def evaluate( funcs: Seq[ChainedPythonFunctions], @@ -81,7 +82,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { http://git-wip-us.apache.org/repos/asf/spark/blob/64817c42/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 94c05b9..9a94d77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -44,7 +44,8 @@ class ArrowPythonRunner( evalType: Int, argOffsets: Array[Array[Int]], schema: StructType, - timeZoneId: String) + timeZoneId: String, + respectTimeZone: Boolean) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { @@ -58,6 +59,11 @@ class ArrowPythonRunner( protected override def writeCommand(dataOut: DataOutputStream): Unit = { PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + if (respectTimeZone) { + PythonRDD.writeUTF(timeZoneId, dataOut) + } else { + dataOut.writeInt(SpecialLengths.NULL) + } } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { 
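[Editor's note] To make the JVM-to-worker handshake in the `ArrowPythonRunner` hunk above concrete: `writeCommand` writes either a length-prefixed UTF-8 timezone ID (`PythonRDD.writeUTF(timeZoneId, dataOut)`) or a single NULL marker int (`dataOut.writeInt(SpecialLengths.NULL)`), and the Python worker reads it back with `utf8_deserializer.loads(infile)` as shown in the `worker.py` hunk earlier. The sketch below is only an illustration of that framing, not the actual PySpark API; `NULL_MARKER` and `read_optional_timezone` are hypothetical names, and the real marker value lives in `pyspark/serializers.py`.

```python
import io
import struct

# Hypothetical constant standing in for PySpark's SpecialLengths.NULL marker.
NULL_MARKER = -5


def read_optional_timezone(infile):
    """Read an optional timezone ID written by the JVM side.

    The JVM writes a big-endian int: either the byte length of a UTF-8
    string (the session timezone, when respectTimeZone is true) or the
    NULL marker (legacy behavior), in which case the worker falls back
    to the local timezone by passing None downstream.
    """
    length = struct.unpack(">i", infile.read(4))[0]
    if length == NULL_MARKER:
        return None
    return infile.read(length).decode("utf-8")


# Usage example with an in-memory stream:
buf = io.BytesIO(struct.pack(">i", 19) + b"America/Los_Angeles")
print(read_optional_timezone(buf))  # -> 'America/Los_Angeles'
```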
http://git-wip-us.apache.org/repos/asf/spark/blob/64817c42/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index ee49581..59db66b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -78,6 +78,7 @@ case class FlatMapGroupsInPandasExec( val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) val schema = StructType(child.schema.drop(groupingAttributes.length)) val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { @@ -95,7 +96,8 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output))
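[Editor's note] As a usage-level recap of the change, here is a minimal PySpark sketch mirroring the behavior described in the commit message and the migration-guide note. The `local[1]` master, app name, and chosen timezones are arbitrary; it assumes a Spark build that contains this patch, Pandas 0.19.2 or newer, and pyarrow only if `spark.sql.execution.arrow.enabled` is turned on.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("SPARK-22395-demo").getOrCreate()
spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")

# 28801 seconds past the epoch is 1970-01-01 00:00:01 in Los Angeles.
df = spark.createDataFrame([28801], "long").selectExpr("timestamp(value) as ts")
df.show()              # |1970-01-01 00:00:01|

# With this patch, toPandas() renders the value in the session timezone as well,
# regardless of the Python process's local timezone.
print(df.toPandas())   # ts column shows 1970-01-01 00:00:01

# The same applies in the other direction: naive timestamps in a pandas DataFrame
# are interpreted in the session timezone by createDataFrame().
pdf = df.toPandas()
spark.createDataFrame(pdf).show()   # still 1970-01-01 00:00:01

# Legacy escape hatch (slated for deprecation per the SQLConf doc above):
# fall back to the old local-timezone behavior.
spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
print(df.toPandas())   # rendered in the local Python timezone again, e.g. 17:00:01 in Asia/Tokyo

spark.stop()
```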