This is an automated email from the ASF dual-hosted git repository.

turbaszek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 297e34a  Add log of affected sql rows in PostgresOperator (#9841)
297e34a is described below

commit 297e34afa0797dcc6e78d7d55c51c0f63818a166
Author: Johan Eklund <[email protected]>
AuthorDate: Fri Jul 17 08:06:07 2020 +0100

    Add log of affected sql rows in PostgresOperator (#9841)
    
    Co-authored-by: Johan Eklund <[email protected]>
    Co-authored-by: Tomek Urbaszek <[email protected]>
---
 airflow/hooks/dbapi_hook.py                     |  8 +++++---
 tests/hooks/test_dbapi_hook.py                  |  6 ++++++
 tests/providers/postgres/hooks/test_postgres.py | 13 +++++++++++++
 tests/providers/sqlite/hooks/test_sqlite.py     |  6 ++++++
 4 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py
index 5e4cbec..c5bcbd0 100644
--- a/airflow/hooks/dbapi_hook.py
+++ b/airflow/hooks/dbapi_hook.py
@@ -181,12 +181,14 @@ class DbApiHook(BaseHook):
 
             with closing(conn.cursor()) as cur:
                 for sql_statement in sql:
-                    if parameters is not None:
-                        self.log.info("%s with parameters %s", sql_statement, parameters)
+
+                    self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
+                    if parameters:
                         cur.execute(sql_statement, parameters)
                     else:
-                        self.log.info(sql_statement)
                         cur.execute(sql_statement)
+                    if hasattr(cur, 'rowcount'):
+                        self.log.info("Rows affected: %s", cur.rowcount)
 
             # If autocommit was set to False for db that supports autocommit,
             # or if db does not supports autocommit, we do a manual commit.
diff --git a/tests/hooks/test_dbapi_hook.py b/tests/hooks/test_dbapi_hook.py
index 1365742..f8c63da 100644
--- a/tests/hooks/test_dbapi_hook.py
+++ b/tests/hooks/test_dbapi_hook.py
@@ -36,6 +36,7 @@ class TestDbApiHook(unittest.TestCase):
 
         class UnitTestDbApiHook(DbApiHook):
             conn_name_attr = 'test_conn_id'
+            log = mock.MagicMock()
 
             def get_conn(self):
                 return conn
@@ -171,3 +172,8 @@ class TestDbApiHook(unittest.TestCase):
             port=1
         ))
         self.assertEqual("conn_type://login:password@host:1/", self.db_hook.get_uri())
+
+    def test_run_log(self):
+        statement = 'SQL'
+        self.db_hook.run(statement)
+        assert self.db_hook.log.info.call_count == 2
diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py
index 4121ed0..cac19b2 100644
--- a/tests/providers/postgres/hooks/test_postgres.py
+++ b/tests/providers/postgres/hooks/test_postgres.py
@@ -258,3 +258,16 @@ class TestPostgresHook(unittest.TestCase):
                 (2, "world",)]
         fields = ("id", "value")
         self.db_hook.insert_rows(table, rows, fields, replace=True)
+
+    @pytest.mark.backend("postgres")
+    def test_rowcount(self):
+        hook = PostgresHook()
+        input_data = ["foo", "bar", "baz"]
+
+        with hook.get_conn() as conn:
+            with conn.cursor() as cur:
+                cur.execute("CREATE TABLE {} (c VARCHAR)".format(self.table))
+                values = ",".join("('{}')".format(data) for data in input_data)
+                cur.execute("INSERT INTO {} VALUES {}".format(self.table, values))
+                conn.commit()
+                self.assertEqual(cur.rowcount, len(input_data))
diff --git a/tests/providers/sqlite/hooks/test_sqlite.py b/tests/providers/sqlite/hooks/test_sqlite.py
index 00508ee..8fdcaae 100644
--- a/tests/providers/sqlite/hooks/test_sqlite.py
+++ b/tests/providers/sqlite/hooks/test_sqlite.py
@@ -55,6 +55,7 @@ class TestSqliteHook(unittest.TestCase):
 
         class UnitTestSqliteHook(SqliteHook):
             conn_name_attr = 'test_conn_id'
+            log = mock.MagicMock()
 
             def get_conn(self):
                 return conn
@@ -95,3 +96,8 @@ class TestSqliteHook(unittest.TestCase):
         self.assertEqual(result_sets[1][0], df.values.tolist()[1][0])
 
         self.cur.execute.assert_called_once_with(statement)
+
+    def test_run_log(self):
+        statement = 'SQL'
+        self.db_hook.run(statement)
+        assert self.db_hook.log.info.call_count == 2

Reply via email to