jhtimmins commented on a change in pull request #9029:
URL: https://github.com/apache/airflow/pull/9029#discussion_r434267828



##########
File path: tests/providers/apache/hive/hooks/test_hive.py
##########
@@ -454,34 +608,91 @@ def test_get_conn_with_password(self, mock_connect):
                 database='default')
 
     def test_get_records(self):
-        hook = HiveServer2Hook()
+        hook = MockHiveServer2Hook()
         query = "SELECT * FROM {}".format(self.table)
-        results = hook.get_records(query, schema=self.database)
+
+        with mock.patch.dict('os.environ', {
+            'AIRFLOW_CTX_DAG_ID': 'test_dag_id',
+            'AIRFLOW_CTX_TASK_ID': 'HiveHook_3835',
+            'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00',
+            'AIRFLOW_CTX_DAG_RUN_ID': '55',
+            'AIRFLOW_CTX_DAG_OWNER': 'airflow',
+            'AIRFLOW_CTX_DAG_EMAIL': '[email protected]',
+        }):
+            results = hook.get_records(query, schema=self.database)
+
         self.assertListEqual(results, [(1, 1), (2, 2)])
 
+        hook.get_conn.assert_called_with(self.database)
+        hook.mock_cursor.execute.assert_any_call(
+            'set airflow.ctx.dag_id=test_dag_id')
+        hook.mock_cursor.execute.assert_any_call(
+            'set airflow.ctx.task_id=HiveHook_3835')
+        hook.mock_cursor.execute.assert_any_call(
+            'set airflow.ctx.execution_date=2015-01-01T00:00:00+00:00')
+        hook.mock_cursor.execute.assert_any_call(
+            'set airflow.ctx.dag_run_id=55')
+        hook.mock_cursor.execute.assert_any_call(
+            'set airflow.ctx.dag_owner=airflow')
+        hook.mock_cursor.execute.assert_any_call(
+            'set [email protected]')
+
     def test_get_pandas_df(self):
-        hook = HiveServer2Hook()
+        hook = MockHiveServer2Hook()
         query = "SELECT * FROM {}".format(self.table)
-        df = hook.get_pandas_df(query, schema=self.database)
+
+        with mock.patch.dict('os.environ', {
+            'AIRFLOW_CTX_DAG_ID': 'test_dag_id',
+            'AIRFLOW_CTX_TASK_ID': 'HiveHook_3835',
+            'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00',
+            'AIRFLOW_CTX_DAG_RUN_ID': '55',
+            'AIRFLOW_CTX_DAG_OWNER': 'airflow',
+            'AIRFLOW_CTX_DAG_EMAIL': '[email protected]',
+        }):
+            df = hook.get_pandas_df(query, schema=self.database)
+
         self.assertEqual(len(df), 2)
-        self.assertListEqual(df.columns.tolist(), self.columns)
-        self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
+        self.assertListEqual(df["hive_server_hook.a"].values.tolist(), [1, 2])
+
+        hook.get_conn.assert_called_with(self.database)
+        hook.mock_cursor.execute.assert_any_call(
+            'set airflow.ctx.dag_id=test_dag_id')
+        hook.mock_cursor.execute.assert_any_call(
+            'set airflow.ctx.task_id=HiveHook_3835')
+        hook.mock_cursor.execute.assert_any_call(
+            'set airflow.ctx.execution_date=2015-01-01T00:00:00+00:00')
+        hook.mock_cursor.execute.assert_any_call(
+            'set airflow.ctx.dag_run_id=55')
+        hook.mock_cursor.execute.assert_any_call(
+            'set airflow.ctx.dag_owner=airflow')
+        hook.mock_cursor.execute.assert_any_call(
+            'set [email protected]')

Review comment:
       Generally speaking, I tried to mock and assert on all calls to external
systems. These assertions verify that the following line is run:
https://github.com/apache/airflow/blob/master/airflow/providers/apache/hive/hooks/hive.py#L843




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to