Repository: incubator-airflow
Updated Branches:
  refs/heads/master d62a03767 -> ef8a6ca4e


[AIRFLOW-2605] Fix autocommit for MySqlHook

Closes #3493 from
yrqls21/kevin_yang_fix_mysql_autocommit


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/ef8a6ca4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/ef8a6ca4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/ef8a6ca4

Branch: refs/heads/master
Commit: ef8a6ca4e43ba41664443bf83ec3982e8effcdb7
Parents: d62a037
Author: Kevin Yang <[email protected]>
Authored: Fri Jun 15 13:27:22 2018 +0200
Committer: Fokko Driesprong <[email protected]>
Committed: Fri Jun 15 13:27:22 2018 +0200

----------------------------------------------------------------------
 airflow/hooks/dbapi_hook.py    | 16 ++++++-
 airflow/hooks/mysql_hook.py    | 16 +++++++
 tests/hooks/test_mysql_hook.py | 87 +++++++++++++++++++++++++++++++++++++
 3 files changed, 118 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/ef8a6ca4/airflow/hooks/dbapi_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py
index 05d2084..358360d 100644
--- a/airflow/hooks/dbapi_hook.py
+++ b/airflow/hooks/dbapi_hook.py
@@ -171,7 +171,7 @@ class DbApiHook(BaseHook):
 
             # If autocommit was set to False for db that supports autocommit,
             # or if db does not supports autocommit, we do a manual commit.
-            if not getattr(conn, 'autocommit', False):
+            if not self.get_autocommit(conn):
                 conn.commit()
 
     def set_autocommit(self, conn, autocommit):
@@ -185,6 +185,20 @@ class DbApiHook(BaseHook):
                 getattr(self, self.conn_name_attr))
         conn.autocommit = autocommit
 
+    def get_autocommit(self, conn):
+        """
+        Get autocommit setting for the provided connection.
+        Return True if conn.autocommit is set to True.
+        Return False if conn.autocommit is not set or set to False or conn
+        does not support autocommit.
+        :param conn: Connection to get autocommit setting from.
+        :type conn: connection object.
+        :return: connection autocommit setting.
+        :rtype bool.
+        """
+
+        return getattr(conn, 'autocommit', False) and self.supports_autocommit
+
     def get_cursor(self):
         """
         Returns a cursor

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/ef8a6ca4/airflow/hooks/mysql_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/mysql_hook.py b/airflow/hooks/mysql_hook.py
index bd41733..c02c0f4 100644
--- a/airflow/hooks/mysql_hook.py
+++ b/airflow/hooks/mysql_hook.py
@@ -40,6 +40,22 @@ class MySqlHook(DbApiHook):
         super(MySqlHook, self).__init__(*args, **kwargs)
         self.schema = kwargs.pop("schema", None)
 
+    def set_autocommit(self, conn, autocommit):
+        """
+        MySQL connections expose autocommit as a method, not an attribute.
+        """
+        conn.autocommit(autocommit)
+
+    def get_autocommit(self, conn):
+        """
+        MySQL connections report autocommit via conn.get_autocommit().
+        :param conn: connection to get autocommit setting from.
+        :type conn: connection object.
+        :return: connection autocommit setting
+        :rtype: bool
+        """
+        return conn.get_autocommit()
+
     def get_conn(self):
         """
         Returns a mysql connection object

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/ef8a6ca4/tests/hooks/test_mysql_hook.py
----------------------------------------------------------------------
diff --git a/tests/hooks/test_mysql_hook.py b/tests/hooks/test_mysql_hook.py
new file mode 100644
index 0000000..d112f88
--- /dev/null
+++ b/tests/hooks/test_mysql_hook.py
@@ -0,0 +1,87 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import mock
+import unittest
+
+from airflow.hooks.mysql_hook import MySqlHook
+
+
+class TestMySqlHook(unittest.TestCase):
+
+    def setUp(self):
+        super(TestMySqlHook, self).setUp()
+
+        self.cur = mock.MagicMock()
+        self.conn = mock.MagicMock()
+        self.conn.cursor.return_value = self.cur
+        conn = self.conn
+
+        class SubMySqlHook(MySqlHook):
+            conn_name_attr = 'test_conn_id'
+
+            def get_conn(self):
+                return conn
+
+        self.db_hook = SubMySqlHook()
+
+    def test_set_autocommit(self):
+        autocommit = True
+        self.db_hook.set_autocommit(self.conn, autocommit)
+
+        self.conn.autocommit.assert_called_once_with(autocommit)
+
+    def test_run_without_autocommit(self):
+        sql = 'SQL'
+        self.conn.get_autocommit.return_value = False
+
+        # Default autocommit setting should be False.
+        # Testing default autocommit value as well as run() behavior.
+        self.db_hook.run(sql, autocommit=False)
+        self.conn.autocommit.assert_called_once_with(False)
+        self.cur.execute.assert_called_once_with(sql)
+        self.conn.commit.assert_called_once()
+
+    def test_run_with_autocommit(self):
+        sql = 'SQL'
+        self.db_hook.run(sql, autocommit=True)
+        self.conn.autocommit.assert_called_once_with(True)
+        self.cur.execute.assert_called_once_with(sql)
+        self.conn.commit.assert_not_called()
+
+    def test_run_with_parameters(self):
+        sql = 'SQL'
+        parameters = ('param1', 'param2')
+        self.db_hook.run(sql, autocommit=True, parameters=parameters)
+        self.conn.autocommit.assert_called_once_with(True)
+        self.cur.execute.assert_called_once_with(sql, parameters)
+        self.conn.commit.assert_not_called()
+
+    def test_run_multi_queries(self):
+        sql = ['SQL1', 'SQL2']
+        self.db_hook.run(sql, autocommit=True)
+        self.conn.autocommit.assert_called_once_with(True)
+        for i in range(len(self.cur.execute.call_args_list)):
+            args, kwargs = self.cur.execute.call_args_list[i]
+            self.assertEqual(len(args), 1)
+            self.assertEqual(args[0], sql[i])
+            self.assertEqual(kwargs, {})
+        self.cur.execute.assert_called_with(sql[1])
+        self.conn.commit.assert_not_called()

Reply via email to