Repository: incubator-airflow Updated Branches: refs/heads/master d62a03767 -> ef8a6ca4e
[AIRFLOW-2605] Fix autocommit for MySqlHook Closes #3493 from yrqls21/kevin_yang_fix_mysql_autocommit Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/ef8a6ca4 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/ef8a6ca4 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/ef8a6ca4 Branch: refs/heads/master Commit: ef8a6ca4e43ba41664443bf83ec3982e8effcdb7 Parents: d62a037 Author: Kevin Yang <[email protected]> Authored: Fri Jun 15 13:27:22 2018 +0200 Committer: Fokko Driesprong <[email protected]> Committed: Fri Jun 15 13:27:22 2018 +0200 ---------------------------------------------------------------------- airflow/hooks/dbapi_hook.py | 16 ++++++- airflow/hooks/mysql_hook.py | 16 +++++++ tests/hooks/test_mysql_hook.py | 87 +++++++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/ef8a6ca4/airflow/hooks/dbapi_hook.py ---------------------------------------------------------------------- diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py index 05d2084..358360d 100644 --- a/airflow/hooks/dbapi_hook.py +++ b/airflow/hooks/dbapi_hook.py @@ -171,7 +171,7 @@ class DbApiHook(BaseHook): # If autocommit was set to False for db that supports autocommit, # or if db does not supports autocommit, we do a manual commit. - if not getattr(conn, 'autocommit', False): + if not self.get_autocommit(conn): conn.commit() def set_autocommit(self, conn, autocommit): @@ -185,6 +185,20 @@ class DbApiHook(BaseHook): getattr(self, self.conn_name_attr)) conn.autocommit = autocommit + def get_autocommit(self, conn): + """ + Get autocommit setting for the provided connection. + Return True if conn.autocommit is set to True. + Return False if conn.autocommit is not set or set to False or conn + does not support autocommit. + :param conn: Connection to get autocommit setting from. + :type conn: connection object. + :return: connection autocommit setting. + :rtype bool. + """ + + return getattr(conn, 'autocommit', False) and self.supports_autocommit + def get_cursor(self): """ Returns a cursor http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/ef8a6ca4/airflow/hooks/mysql_hook.py ---------------------------------------------------------------------- diff --git a/airflow/hooks/mysql_hook.py b/airflow/hooks/mysql_hook.py index bd41733..c02c0f4 100644 --- a/airflow/hooks/mysql_hook.py +++ b/airflow/hooks/mysql_hook.py @@ -40,6 +40,22 @@ class MySqlHook(DbApiHook): super(MySqlHook, self).__init__(*args, **kwargs) self.schema = kwargs.pop("schema", None) + def set_autocommit(self, conn, autocommit): + """ + MySql connection sets autocommit in a different way. + """ + conn.autocommit(autocommit) + + def get_autocommit(self, conn): + """ + MySql connection gets autocommit in a different way. + :param conn: connection to get autocommit setting from. + :type conn: connection object. + :return: connection autocommit setting + :rtype bool + """ + return conn.get_autocommit() + def get_conn(self): """ Returns a mysql connection object http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/ef8a6ca4/tests/hooks/test_mysql_hook.py ---------------------------------------------------------------------- diff --git a/tests/hooks/test_mysql_hook.py b/tests/hooks/test_mysql_hook.py new file mode 100644 index 0000000..d112f88 --- /dev/null +++ b/tests/hooks/test_mysql_hook.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import mock +import unittest + +from airflow.hooks.mysql_hook import MySqlHook + + +class TestMySqlHook(unittest.TestCase): + + def setUp(self): + super(TestMySqlHook, self).setUp() + + self.cur = mock.MagicMock() + self.conn = mock.MagicMock() + self.conn.cursor.return_value = self.cur + conn = self.conn + + class SubMySqlHook(MySqlHook): + conn_name_attr = 'test_conn_id' + + def get_conn(self): + return conn + + self.db_hook = SubMySqlHook() + + def test_set_autocommit(self): + autocommit = True + self.db_hook.set_autocommit(self.conn, autocommit) + + self.conn.autocommit.assert_called_once_with(autocommit) + + def test_run_without_autocommit(self): + sql = 'SQL' + self.conn.get_autocommit.return_value = False + + # Default autocommit setting should be False. + # Testing default autocommit value as well as run() behavior. + self.db_hook.run(sql, autocommit=False) + self.conn.autocommit.assert_called_once_with(False) + self.cur.execute.assert_called_once_with(sql) + self.conn.commit.assert_called_once() + + def test_run_with_autocommit(self): + sql = 'SQL' + self.db_hook.run(sql, autocommit=True) + self.conn.autocommit.assert_called_once_with(True) + self.cur.execute.assert_called_once_with(sql) + self.conn.commit.assert_not_called() + + def test_run_with_parameters(self): + sql = 'SQL' + parameters = ('param1', 'param2') + self.db_hook.run(sql, autocommit=True, parameters=parameters) + self.conn.autocommit.assert_called_once_with(True) + self.cur.execute.assert_called_once_with(sql, parameters) + self.conn.commit.assert_not_called() + + def test_run_multi_queries(self): + sql = ['SQL1', 'SQL2'] + self.db_hook.run(sql, autocommit=True) + self.conn.autocommit.assert_called_once_with(True) + for i in range(len(self.cur.execute.call_args_list)): + args, kwargs = self.cur.execute.call_args_list[i] + self.assertEqual(len(args), 1) + self.assertEqual(args[0], sql[i]) + self.assertEqual(kwargs, {}) + self.cur.execute.assert_called_with(sql[1]) + self.conn.commit.assert_not_called()
