blcksrx commented on issue #7422: [AIRFLOW-6809] Test for presto operators
URL: https://github.com/apache/airflow/pull/7422#issuecomment-591839111
 
 
   > You are still missing 1. and 2.
   
   No, that's not true — both are included below!
   
   Here is 1:
   ```
   #
   # Licensed to the Apache Software Foundation (ASF) under one
   # or more contributor license agreements.  See the NOTICE file
   # distributed with this work for additional information
   # regarding copyright ownership.  The ASF licenses this file
   # to you under the Apache License, Version 2.0 (the
   # "License"); you may not use this file except in compliance
   # with the License.  You may obtain a copy of the License at
   #
   #   http://www.apache.org/licenses/LICENSE-2.0
   #
   # Unless required by applicable law or agreed to in writing,
   # software distributed under the License is distributed on an
   # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
   # KIND, either express or implied.  See the License for the
   # specific language governing permissions and limitations
   # under the License.
   #
   
   import unittest
   from unittest import mock
   from unittest.mock import patch
   
   from prestodb.transaction import IsolationLevel
   
   from airflow.models import Connection
   from airflow.providers.presto.hooks.presto import PrestoHook
   
   
   class TestPrestoHookConn(unittest.TestCase):
   
       def setUp(self):
           super().setUp()
   
           self.connection = Connection(
               login='login',
               password='password',
               host='host',
               schema='hive',
           )
   
           class UnitTestPrestoHook(PrestoHook):
               conn_name_attr = 'presto_conn_id'
   
           self.db_hook = UnitTestPrestoHook()
           self.db_hook.get_connection = mock.Mock()
           self.db_hook.get_connection.return_value = self.connection
   
       
@patch('airflow.providers.presto.hooks.presto.prestodb.auth.BasicAuthentication')
       @patch('airflow.providers.presto.hooks.presto.prestodb.dbapi.connect')
       def test_get_conn(self, mock_connect, mock_basic_auth):
           self.db_hook.get_conn()
           mock_connect.assert_called_once_with(catalog='hive', host='host', 
port=None, http_scheme='http',
                                                schema='hive', 
source='airflow', user='login', isolation_level=0,
                                                auth=mock_basic_auth('login', 
'password'))
   
   
   class TestPrestoHook(unittest.TestCase):
   
       def setUp(self):
           super().setUp()
   
           self.cur = mock.MagicMock()
           self.conn = mock.MagicMock()
           self.conn.cursor.return_value = self.cur
           conn = self.conn
   
           class UnitTestPrestoHook(PrestoHook):
               conn_name_attr = 'test_conn_id'
   
               def get_conn(self):
                   return conn
   
               def get_isolation_level(self):
                   return IsolationLevel.READ_COMMITTED
   
           self.db_hook = UnitTestPrestoHook()
   
       @patch('airflow.hooks.dbapi_hook.DbApiHook.insert_rows')
       def test_insert_rows(self, mock_insert_rows):
           table = "table"
           rows = [("hello",),
                   ("world",)]
           target_fields = None
           commit_every = 10
           self.db_hook.insert_rows(table, rows, target_fields, commit_every)
           mock_insert_rows.assert_called_once_with(table, rows, None, 10)
   
       def test_get_first_record(self):
           statement = 'SQL'
           result_sets = [('row1',), ('row2',)]
           self.cur.fetchone.return_value = result_sets[0]
   
           self.assertEqual(result_sets[0], self.db_hook.get_first(statement))
           self.conn.close.assert_called_once_with()
           self.cur.close.assert_called_once_with()
           self.cur.execute.assert_called_once_with(statement)
   
       def test_get_records(self):
           statement = 'SQL'
           result_sets = [('row1',), ('row2',)]
           self.cur.fetchall.return_value = result_sets
   
           self.assertEqual(result_sets, self.db_hook.get_records(statement))
           self.conn.close.assert_called_once_with()
           self.cur.close.assert_called_once_with()
           self.cur.execute.assert_called_once_with(statement)
   
       def test_get_pandas_df(self):
           statement = 'SQL'
           column = 'col'
           result_sets = [('row1',), ('row2',)]
           self.cur.description = [(column,)]
           self.cur.fetchall.return_value = result_sets
           df = self.db_hook.get_pandas_df(statement)
   
           self.assertEqual(column, df.columns[0])
   
           self.assertEqual(result_sets[0][0], df.values.tolist()[0][0])
           self.assertEqual(result_sets[1][0], df.values.tolist()[1][0])
   
           self.cur.execute.assert_called_once_with(statement, None)
   ```
    And here is 2:
   ```
   #
   # Licensed to the Apache Software Foundation (ASF) under one
   # or more contributor license agreements.  See the NOTICE file
   # distributed with this work for additional information
   # regarding copyright ownership.  The ASF licenses this file
   # to you under the Apache License, Version 2.0 (the
   # "License"); you may not use this file except in compliance
   # with the License.  You may obtain a copy of the License at
   #
   #   http://www.apache.org/licenses/LICENSE-2.0
   #
   # Unless required by applicable law or agreed to in writing,
   # software distributed under the License is distributed on an
   # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
   # KIND, either express or implied.  See the License for the
   # specific language governing permissions and limitations
   # under the License.
   
   import os
   import unittest
   from unittest import mock
   
   from airflow.configuration import conf
   from airflow.models import TaskInstance
   from airflow.providers.apache.hive.operators.hive import HiveOperator
   from airflow.utils import timezone
   from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment
   
   
   class HiveOperatorConfigTest(TestHiveEnvironment):
   
       def test_hive_airflow_default_config_queue(self):
           op = HiveOperator(
               task_id='test_default_config_queue',
               hql=self.hql,
               mapred_queue_priority='HIGH',
               mapred_job_name='airflow.test_default_config_queue',
               dag=self.dag)
   
           # just check that the correct default value in test_default.cfg is 
used
           test_config_hive_mapred_queue = conf.get(
               'hive',
               'default_hive_mapred_queue'
           )
           self.assertEqual(op.get_hook().mapred_queue, 
test_config_hive_mapred_queue)
   
       def test_hive_airflow_default_config_queue_override(self):
           specific_mapred_queue = 'default'
           op = HiveOperator(
               task_id='test_default_config_queue',
               hql=self.hql,
               mapred_queue=specific_mapred_queue,
               mapred_queue_priority='HIGH',
               mapred_job_name='airflow.test_default_config_queue',
               dag=self.dag)
   
           self.assertEqual(op.get_hook().mapred_queue, specific_mapred_queue)
   
   
   class HiveOperatorTest(TestHiveEnvironment):
   
       def test_hiveconf_jinja_translate(self):
           hql = "SELECT ${num_col} FROM ${hiveconf:table};"
           op = HiveOperator(
               hiveconf_jinja_translate=True,
               task_id='dry_run_basic_hql', hql=hql, dag=self.dag)
           op.prepare_template()
           self.assertEqual(op.hql, "SELECT {{ num_col }} FROM {{ table }};")
   
       def test_hiveconf(self):
           hql = "SELECT * FROM ${hiveconf:table} PARTITION (${hiveconf:day});"
           op = HiveOperator(
               hiveconfs={'table': 'static_babynames', 'day': '{{ ds }}'},
               task_id='dry_run_basic_hql', hql=hql, dag=self.dag)
           op.prepare_template()
           self.assertEqual(
               op.hql,
               "SELECT * FROM ${hiveconf:table} PARTITION (${hiveconf:day});")
   
       
@mock.patch('airflow.providers.apache.hive.operators.hive.HiveOperator.get_hook')
       def test_mapred_job_name(self, mock_get_hook):
           mock_hook = mock.MagicMock()
           mock_get_hook.return_value = mock_hook
           op = HiveOperator(
               task_id='test_mapred_job_name',
               hql=self.hql,
               dag=self.dag)
   
           fake_execution_date = timezone.datetime(2018, 6, 19)
           fake_ti = TaskInstance(task=op, execution_date=fake_execution_date)
           fake_ti.hostname = 'fake_hostname'
           fake_context = {'ti': fake_ti}
   
           op.execute(fake_context)
           self.assertEqual(
               "Airflow HiveOperator task for {}.{}.{}.{}"
               .format(fake_ti.hostname,
                       self.dag.dag_id, op.task_id,
                       fake_execution_date.isoformat()), 
mock_hook.mapred_job_name)
   
   
   @unittest.skipIf(
       'AIRFLOW_RUNALL_TESTS' not in os.environ,
       "Skipped because AIRFLOW_RUNALL_TESTS is not set")
   class TestHivePresto(TestHiveEnvironment):
       def test_hive(self):
           op = HiveOperator(
               task_id='basic_hql', hql=self.hql, dag=self.dag)
           op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
                  ignore_ti_state=True)
   
       def test_hive_queues(self):
           op = HiveOperator(
               task_id='test_hive_queues', hql=self.hql,
               mapred_queue='default', mapred_queue_priority='HIGH',
               mapred_job_name='airflow.test_hive_queues',
               dag=self.dag)
           op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
                  ignore_ti_state=True)
   
       def test_hive_dryrun(self):
           op = HiveOperator(
               task_id='dry_run_basic_hql', hql=self.hql, dag=self.dag)
           op.dry_run()
   
       def test_beeline(self):
           op = HiveOperator(
               task_id='beeline_hql', hive_cli_conn_id='hive_cli_default',
               hql=self.hql, dag=self.dag)
           op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
                  ignore_ti_state=True)
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to