feluelle commented on a change in pull request #7422: [AIRFLOW-6809] Test for presto operators
URL: https://github.com/apache/airflow/pull/7422#discussion_r385009987
 
 

 ##########
 File path: tests/providers/presto/operators/test_presto_check.py
 ##########
 @@ -16,23 +16,191 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-import os
+
+
 import unittest
+from datetime import datetime
+
+import mock
+
+from airflow.exceptions import AirflowException
+from airflow.models import DAG
+from airflow.providers.presto.operators.presto_check import (
+    PrestoCheckOperator, PrestoIntervalCheckOperator, PrestoValueCheckOperator,
+)
+
+
+class TestPrestoCheckOperator(unittest.TestCase):
+    @mock.patch.object(PrestoCheckOperator, "get_db_hook")
+    def test_execute_no_records(self, mock_get_db_hook):
+        mock_get_db_hook.return_value.get_first.return_value = []
+
+        with self.assertRaises(AirflowException):
+            PrestoCheckOperator(sql="sql").execute()
+
+    @mock.patch.object(PrestoCheckOperator, "get_db_hook")
+    def test_execute_not_all_records_are_true(self, mock_get_db_hook):
+        mock_get_db_hook.return_value.get_first.return_value = ["data", ""]
+
+        with self.assertRaises(AirflowException):
+            PrestoCheckOperator(sql="sql").execute()
+
+
+class TestValuePrestoCheckOperator(unittest.TestCase):
+    def setUp(self):
+        self.task_id = "test_task"
+        self.conn_id = "default_conn"
+
+    def _construct_operator(self, sql, pass_value, tolerance=None):
+        dag = DAG("test_dag", start_date=datetime(2017, 1, 1))
+
+        return PrestoValueCheckOperator(
+            dag=dag,
+            task_id=self.task_id,
+            conn_id=self.conn_id,
+            sql=sql,
+            pass_value=pass_value,
+            tolerance=tolerance,
+        )
+
+    def test_pass_value_template_string(self):
+        pass_value_str = "2018-03-22"
+        operator = self._construct_operator("select date from tab1;", "{{ ds }}")
+
+        operator.render_template_fields({"ds": pass_value_str})
+
+        self.assertEqual(operator.task_id, self.task_id)
+        self.assertEqual(operator.pass_value, pass_value_str)
+
+    def test_pass_value_template_string_float(self):
+        pass_value_float = 4.0
+        operator = self._construct_operator("select date from tab1;", pass_value_float)
+
+        operator.render_template_fields({})
+
+        self.assertEqual(operator.task_id, self.task_id)
+        self.assertEqual(operator.pass_value, str(pass_value_float))
+
+    @mock.patch.object(PrestoValueCheckOperator, "get_db_hook")
+    def test_execute_pass(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [10]
+        mock_get_db_hook.return_value = mock_hook
+        sql = "select value from tab1 limit 1;"
+        operator = self._construct_operator(sql, 5, 1)
+
+        operator.execute(None)
+
+        mock_hook.get_first.assert_called_once_with(sql)
+
+    @mock.patch.object(PrestoValueCheckOperator, "get_db_hook")
+    def test_execute_fail(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [11]
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator("select value from tab1 limit 1;", 5, 1)
+
+        with self.assertRaisesRegex(AirflowException, "Tolerance:100.0%"):
+            operator.execute()
+
+
+class TestPrestoIntervalCheckOperator(unittest.TestCase):
+    def _construct_operator(self, table, metric_thresholds, ratio_formula, ignore_zero):
+        return PrestoIntervalCheckOperator(
+            task_id="test_task",
+            table=table,
+            metrics_thresholds=metric_thresholds,
+            ratio_formula=ratio_formula,
+            ignore_zero=ignore_zero,
+        )
+
+    def test_invalid_ratio_formula(self):
+        with self.assertRaisesRegex(AirflowException, "Invalid diff_method"):
+            self._construct_operator(
+                table="test_table",
+                metric_thresholds={"f1": 1},
+                ratio_formula="abs",
+                ignore_zero=False,
+            )
+
+    @mock.patch.object(PrestoIntervalCheckOperator, "get_db_hook")
+    def test_execute_not_ignore_zero(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [0]
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator(
+            table="test_table",
+            metric_thresholds={"f1": 1},
+            ratio_formula="max_over_min",
+            ignore_zero=False,
+        )
+
+        with self.assertRaises(AirflowException):
+            operator.execute()
+
+    @mock.patch.object(PrestoIntervalCheckOperator, "get_db_hook")
+    def test_execute_ignore_zero(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [0]
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator(
+            table="test_table",
+            metric_thresholds={"f1": 1},
+            ratio_formula="max_over_min",
+            ignore_zero=True,
+        )
+
+        operator.execute()
+
+    @mock.patch.object(PrestoIntervalCheckOperator, "get_db_hook")
+    def test_execute_min_max(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+
+        def returned_row():
+            rows = [
+                [2, 2, 2, 2],  # reference
+                [1, 1, 1, 1],  # current
+            ]
+
+            yield from rows
+
+        mock_hook.get_first.side_effect = returned_row()
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator(
+            table="test_table",
+            metric_thresholds={"f0": 1.0, "f1": 1.5, "f2": 2.0, "f3": 2.5},
+            ratio_formula="max_over_min",
+            ignore_zero=True,
+        )
+
+        with self.assertRaisesRegex(AirflowException, "f0, f1, f2"):
+            operator.execute()
+
+    @mock.patch.object(PrestoIntervalCheckOperator, "get_db_hook")
+    def test_execute_diff(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+
+        def returned_row():
+            rows = [
+                [3, 3, 3, 3],  # reference
+                [1, 1, 1, 1],  # current
+            ]
 
-from airflow.providers.presto.operators.presto_check import PrestoCheckOperator
-from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment
+            yield from rows
 
+        mock_hook.get_first.side_effect = returned_row()
+        mock_get_db_hook.return_value = mock_hook
 
-@unittest.skipIf(
-    'AIRFLOW_RUNALL_TESTS' not in os.environ,
-    "Skipped because AIRFLOW_RUNALL_TESTS is not set")
-class TestPrestoCheckOperator(TestHiveEnvironment):
+        operator = self._construct_operator(
+            table="test_table",
+            metric_thresholds={"f0": 0.5, "f1": 0.6, "f2": 0.7, "f3": 0.8},
+            ratio_formula="relative_diff",
+            ignore_zero=True,
+        )
 
-    def test_presto(self):
-        sql = """
-            SELECT count(1) FROM airflow.static_babynames_partitioned;
-            """
-        op = PrestoCheckOperator(
-            task_id='presto_check', sql=sql, dag=self.dag)
-        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
-               ignore_ti_state=True)
 
 Review comment:
   This test is missing. -> 1.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to