bolkedebruin closed pull request #3708: [AIRFLOW-2859] Implement own UtcDateTime
URL: https://github.com/apache/incubator-airflow/pull/3708
 
 
   

This is a PR merged from a forked repository.
Because GitHub hides the original diff of a pull request from a fork once
it is merged, the diff is reproduced below for the sake of provenance:

diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index e2001789d9..45b7903d3e 100644
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -1007,7 +1007,6 @@ def initdb(args):  # noqa
     print("Done.")
 
 
-@cli_utils.action_logging
 def resetdb(args):
     print("DB: " + repr(settings.engine.url))
     if args.yes or input("This will drop existing tables "
diff --git a/airflow/jobs.py b/airflow/jobs.py
index cc26feee53..e8ba437e0b 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -40,7 +40,6 @@
     Column, Integer, String, func, Index, or_, and_, not_)
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm.session import make_transient
-from sqlalchemy_utc import UtcDateTime
 from tabulate import tabulate
 from time import sleep
 
@@ -52,6 +51,7 @@
 from airflow.task.task_runner import get_task_runner
 from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
 from airflow.utils import asciiart, helpers, timezone
+from airflow.utils.configuration import tmp_configuration_copy
 from airflow.utils.dag_processing import (AbstractDagFileProcessor,
                                           DagFileProcessorManager,
                                           SimpleDag,
@@ -60,9 +60,9 @@
 from airflow.utils.db import create_session, provide_session
 from airflow.utils.email import send_email
 from airflow.utils.log.logging_mixin import LoggingMixin, set_context, 
StreamLogWriter
-from airflow.utils.state import State
-from airflow.utils.configuration import tmp_configuration_copy
 from airflow.utils.net import get_hostname
+from airflow.utils.state import State
+from airflow.utils.sqlalchemy import UtcDateTime
 
 Base = models.Base
 ID_LEN = models.ID_LEN
diff --git a/airflow/models.py b/airflow/models.py
index 288bd4c937..e7d38ebd65 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -60,7 +60,6 @@
 from sqlalchemy import func, or_, and_, true as sqltrue
 from sqlalchemy.ext.declarative import declarative_base, declared_attr
 from sqlalchemy.orm import reconstructor, relationship, synonym
-from sqlalchemy_utc import UtcDateTime
 
 from croniter import croniter
 import six
@@ -88,6 +87,7 @@
     as_tuple, is_container, validate_key, pprinttable)
 from airflow.utils.operator_resources import Resources
 from airflow.utils.state import State
+from airflow.utils.sqlalchemy import UtcDateTime
 from airflow.utils.timeout import timeout
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.weight_rule import WeightRule
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index baddd9dcf1..76c112785f 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -22,15 +22,19 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
+import datetime
 import os
+import pendulum
 import time
 import random
 
 from sqlalchemy import event, exc, select
+from sqlalchemy.types import DateTime, TypeDecorator
 
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 log = LoggingMixin().log
+utc = pendulum.timezone('UTC')
 
 
 def setup_event_handlers(
@@ -101,13 +105,21 @@ def ping_connection(connection, branch):
     def connect(dbapi_connection, connection_record):
         connection_record.info['pid'] = os.getpid()
 
-    @event.listens_for(engine, "connect")
-    def set_sqlite_pragma(dbapi_connection, connection_record):
-        if 'sqlite3.Connection' in str(type(dbapi_connection)):
+    if engine.dialect.name == "sqlite":
+        @event.listens_for(engine, "connect")
+        def set_sqlite_pragma(dbapi_connection, connection_record):
             cursor = dbapi_connection.cursor()
             cursor.execute("PRAGMA foreign_keys=ON")
             cursor.close()
 
+    # this ensures sanity in mysql when storing datetimes (not required for postgres)
+    if engine.dialect.name == "mysql":
+        @event.listens_for(engine, "connect")
+        def set_mysql_timezone(dbapi_connection, connection_record):
+            cursor = dbapi_connection.cursor()
+            cursor.execute("SET time_zone = '+00:00'")
+            cursor.close()
+
     @event.listens_for(engine, "checkout")
     def checkout(dbapi_connection, connection_record, connection_proxy):
         pid = os.getpid()
@@ -117,3 +129,63 @@ def checkout(dbapi_connection, connection_record, connection_proxy):
                 "Connection record belongs to pid {}, "
                 "attempting to check out in pid {}".format(connection_record.info['pid'], pid)
             )
+
+
+class UtcDateTime(TypeDecorator):
+    """
+    Almost equivalent to :class:`~sqlalchemy.types.DateTime` with
+    ``timezone=True`` option, but it differs from that by:
+
+    - Never silently taking a naive :class:`~datetime.datetime`: binding
+      one raises :exc:`ValueError`, and binding a value that is not a
+      :class:`~datetime.datetime` at all raises :exc:`TypeError`.
+    - Always converting bound values to UTC before they are stored.
+    - Never returning a naive :class:`~datetime.datetime` (unlike
+      SQLAlchemy's built-in :class:`~sqlalchemy.types.DateTime`), even
+      with SQLite or MySQL: results are always time zone aware and in UTC.
+    """
+
+    impl = DateTime(timezone=True)
+
+    def process_bind_param(self, value, dialect):
+        """
+        Validates a value before it is sent to the database and converts
+        it to UTC.
+
+        :param value: the value being bound, or ``None``
+        :param dialect: the dialect in use (not consulted)
+        :return: ``value`` converted to UTC, or ``None`` if it is ``None``
+        :raises TypeError: if ``value`` is not a :class:`~datetime.datetime`
+        :raises ValueError: if ``value`` is naive (its ``tzinfo`` is None)
+        """
+        if value is not None:
+            if not isinstance(value, datetime.datetime):
+                raise TypeError('expected datetime.datetime, not ' +
+                                repr(value))
+            elif value.tzinfo is None:
+                raise ValueError('naive datetime is disallowed')
+
+            return value.astimezone(utc)
+
+        return None
+
+    def process_result_value(self, value, dialect):
+        """
+        Processes DateTimes from the DB making sure they are always
+        returned time zone aware and in UTC. We assume UTC datetimes in
+        the database, so a naive value simply gets UTC attached while an
+        aware value is converted. Not using timezone.convert_to_utc as
+        that converts to configured TIMEZONE while the DB might be
+        running with some other setting.
+
+        :param value: the value read from the database, or ``None``
+        :param dialect: the dialect in use (not consulted)
+        :return: ``value`` as an aware datetime in UTC, or ``None``
+        """
+        if value is not None:
+            if value.tzinfo is None:
+                value = value.replace(tzinfo=utc)
+            else:
+                value = value.astimezone(utc)
+
+        return value
diff --git a/run_unit_tests.sh b/run_unit_tests.sh
index 42c78f3439..27e4d08af1 100755
--- a/run_unit_tests.sh
+++ b/run_unit_tests.sh
@@ -8,9 +8,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -91,3 +91,4 @@ nosetests $nose_args
 
 # To run individual tests:
 # nosetests tests.core:CoreTest.test_scheduler_job
+
diff --git a/setup.py b/setup.py
index f350ff5b6f..b92d267aaa 100644
--- a/setup.py
+++ b/setup.py
@@ -322,7 +322,6 @@ def do_setup():
             'requests>=2.5.1, <3',
             'setproctitle>=1.1.8, <2',
             'sqlalchemy>=1.1.15, <1.2.0',
-            'sqlalchemy-utc>=0.9.0',
             'tabulate>=0.7.5, <0.8.0',
             'tenacity==4.8.0',
             'thrift>=0.9.2',
diff --git a/tests/core.py b/tests/core.py
index d336b1bd1f..6efcbfd360 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -38,7 +38,6 @@
 from email.mime.application import MIMEApplication
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
-from freezegun import freeze_time
 from numpy.testing import assert_array_almost_equal
 from six.moves.urllib.parse import urlencode
 from time import sleep
@@ -69,6 +68,7 @@
 from airflow.configuration import AirflowConfigException, run_command
 from jinja2.sandbox import SecurityError
 from jinja2 import UndefinedError
+from pendulum import utcnow
 
 import six
 
@@ -261,7 +261,6 @@ def test_schedule_dag_start_end_dates(self):
 
         self.assertIsNone(additional_dag_run)
 
-    @freeze_time('2016-01-01')
     def test_schedule_dag_no_end_date_up_to_today_only(self):
         """
         Tests that a Dag created without an end_date can only be scheduled up
@@ -273,8 +272,11 @@ def test_schedule_dag_no_end_date_up_to_today_only(self):
         """
         session = settings.Session()
         delta = timedelta(days=1)
-        start_date = DEFAULT_DATE
-        runs = 365
+        now = utcnow()
+        start_date = now.subtract(weeks=1)
+
+        runs = (now - start_date).days
+
         dag = DAG(TEST_DAG_ID + 
'test_schedule_dag_no_end_date_up_to_today_only',
                   start_date=start_date,
                   schedule_interval=delta)
diff --git a/tests/test_utils/fake_datetime.py b/tests/test_utils/fake_datetime.py
index 42bb01df01..8182d83e9a 100644
--- a/tests/test_utils/fake_datetime.py
+++ b/tests/test_utils/fake_datetime.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py
new file mode 100644
index 0000000000..66f00f9427
--- /dev/null
+++ b/tests/utils/test_sqlalchemy.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import datetime
+import unittest
+
+from airflow import settings
+from airflow.models import DAG
+from airflow.settings import Session
+from airflow.utils.state import State
+from airflow.utils.timezone import utcnow
+
+from sqlalchemy.exc import StatementError
+
+
+class TestSqlAlchemyUtils(unittest.TestCase):
+    """
+    Checks UtcDateTime (airflow.utils.sqlalchemy) through DAG.create_dagrun:
+    aware datetimes round-trip in UTC, naive ones are refused.
+    """
+
+    def setUp(self):
+        session = Session()
+
+        # make sure NOT to run in UTC. Only postgres supports storing
+        # timezone information in the datetime field
+        if session.bind.dialect.name == "postgresql":
+            session.execute("SET timezone='Europe/Amsterdam'")
+
+        self.session = session
+
+    def test_utc_transformations(self):
+        """
+        Test whether what we are storing is what we are retrieving
+        for datetimes
+        """
+        dag_id = 'test_utc_transformations'
+        start_date = utcnow()
+        iso_date = start_date.isoformat()
+        execution_date = start_date + datetime.timedelta(hours=1, days=1)
+
+        dag = DAG(
+            dag_id=dag_id,
+            start_date=start_date,
+        )
+        dag.clear()
+
+        run = dag.create_dagrun(
+            run_id=iso_date,
+            state=State.NONE,
+            execution_date=execution_date,
+            start_date=start_date,
+            session=self.session,
+        )
+
+        self.assertEqual(execution_date, run.execution_date)
+        self.assertEqual(start_date, run.start_date)
+
+        # check the values held by the run, not the inputs that were
+        # built from utcnow() above
+        self.assertEqual(run.execution_date.utcoffset().total_seconds(), 0.0)
+        self.assertEqual(run.start_date.utcoffset().total_seconds(), 0.0)
+
+        self.assertEqual(iso_date, run.run_id)
+        self.assertEqual(run.start_date.isoformat(), run.run_id)
+
+        dag.clear()
+
+    def test_process_bind_param_naive(self):
+        """
+        Check if naive datetimes are prevented from saving to the db
+        """
+        dag_id = 'test_process_bind_param_naive'
+
+        # naive
+        start_date = datetime.datetime.now()
+        dag = DAG(dag_id=dag_id, start_date=start_date)
+        dag.clear()
+
+        with self.assertRaises((ValueError, StatementError)):
+            dag.create_dagrun(
+                # run_id must be a string: passing the bound method could
+                # fail the insert for another reason and mask the check
+                run_id=start_date.isoformat(),
+                state=State.NONE,
+                execution_date=start_date,
+                start_date=start_date,
+                session=self.session
+            )
+        dag.clear()
+
+    def tearDown(self):
+        self.session.close()
+        settings.engine.dispose()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to