Repository: incubator-airflow
Updated Branches:
  refs/heads/master b76d560ce -> c1d583f91


[AIRFLOW-2213] Add Qubole check operator

Closes #3300 from sakshi2894/AIRFLOW-2213


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/c1d583f9
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/c1d583f9
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/c1d583f9

Branch: refs/heads/master
Commit: c1d583f91a0b4185f760a64acbeae86739479cdb
Parents: b76d560
Author: Sakshi Bansal <saks...@qubole.com>
Authored: Tue May 15 14:51:35 2018 +0530
Committer: Sumit Maheshwari <sum...@qubole.com>
Committed: Tue May 15 14:51:35 2018 +0530

----------------------------------------------------------------------
 airflow/contrib/hooks/qubole_check_hook.py      | 117 ++++++++++
 .../contrib/operators/qubole_check_operator.py  | 225 +++++++++++++++++++
 docs/code.rst                                   |   2 +
 tests/contrib/hooks/test_qubole_check_hook.py   |  43 ++++
 .../operators/test_qubole_check_operator.py     | 115 ++++++++++
 5 files changed, 502 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c1d583f9/airflow/contrib/hooks/qubole_check_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/qubole_check_hook.py b/airflow/contrib/hooks/qubole_check_hook.py
new file mode 100644
index 0000000..303c19b
--- /dev/null
+++ b/airflow/contrib/hooks/qubole_check_hook.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.contrib.hooks.qubole_hook import QuboleHook
+from airflow.exceptions import AirflowException
+from qds_sdk.commands import Command
+
+try:
+    from cStringIO import StringIO
+except ImportError:
+    from io import StringIO
+
+
+COL_DELIM = '\t'
+ROW_DELIM = '\r\n'
+
+
+def isint(value):
+    try:
+        int(value)
+        return True
+    except ValueError:
+        return False
+
+
+def isfloat(value):
+    try:
+        float(value)
+        return True
+    except ValueError:
+        return False
+
+
+def isbool(value):
+    try:
+        if value.lower() in ["true", "false"]:
+            return True
+    except ValueError:
+        return False
+
+
+def parse_first_row(row_list):
+    record_list = []
+    first_row = row_list[0] if row_list else ""
+
+    for col_value in first_row.split(COL_DELIM):
+        if isint(col_value):
+            col_value = int(col_value)
+        elif isfloat(col_value):
+            col_value = float(col_value)
+        elif isbool(col_value):
+            col_value = (col_value.lower() == "true")
+        record_list.append(col_value)
+
+    return record_list
+
+
+class QuboleCheckHook(QuboleHook):
+    def __init__(self, context, *args, **kwargs):
+        super(QuboleCheckHook, self).__init__(*args, **kwargs)
+        self.results_parser_callable = parse_first_row
+        if 'results_parser_callable' in kwargs and \
+                kwargs['results_parser_callable'] is not None:
+            if not callable(kwargs['results_parser_callable']):
+                raise AirflowException('`results_parser_callable` param must be callable')
+            self.results_parser_callable = kwargs['results_parser_callable']
+        self.context = context
+
+    @staticmethod
+    def handle_failure_retry(context):
+        ti = context['ti']
+        cmd_id = ti.xcom_pull(key='qbol_cmd_id', task_ids=ti.task_id)
+
+        if cmd_id is not None:
+            cmd = Command.find(cmd_id)
+            if cmd is not None:
+                if cmd.status == 'running':
+                    log = LoggingMixin().log
+                    log.info('Cancelling the Qubole Command Id: %s', cmd_id)
+                    cmd.cancel()
+
+    def get_first(self, sql):
+        self.execute(context=self.context)
+        query_result = self.get_query_results()
+        row_list = list(filter(None, query_result.split(ROW_DELIM)))
+        record_list = self.results_parser_callable(row_list)
+        return record_list
+
+    def get_query_results(self):
+        log = LoggingMixin().log
+        if self.cmd is not None:
+            cmd_id = self.cmd.id
+            log.info("command id: " + str(cmd_id))
+            query_result_buffer = StringIO()
+            self.cmd.get_results(fp=query_result_buffer, inline=True, delim=COL_DELIM)
+            query_result = query_result_buffer.getvalue()
+            query_result_buffer.close()
+            return query_result
+        else:
+            log.info("Qubole command not found")

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c1d583f9/airflow/contrib/operators/qubole_check_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/qubole_check_operator.py b/airflow/contrib/operators/qubole_check_operator.py
new file mode 100644
index 0000000..0e8d75e
--- /dev/null
+++ b/airflow/contrib/operators/qubole_check_operator.py
@@ -0,0 +1,225 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from airflow.contrib.operators.qubole_operator import QuboleOperator
+from airflow.utils.decorators import apply_defaults
+from airflow.contrib.hooks.qubole_check_hook import QuboleCheckHook
+from airflow.operators.check_operator import CheckOperator, ValueCheckOperator
+from airflow.exceptions import AirflowException
+
+
+class QuboleCheckOperator(CheckOperator, QuboleOperator):
+    """
+    Performs checks against Qubole Commands. ``QuboleCheckOperator`` expects
+    a command that will be executed on QDS.
+    By default, each value on first row of the result of this Qubole Commmand
+    is evaluated using python ``bool`` casting. If any of the
+    values return ``False``, the check is failed and errors out.
+
+    Note that Python bool casting evals the following as ``False``:
+
+    * ``False``
+    * ``0``
+    * Empty string (``""``)
+    * Empty list (``[]``)
+    * Empty dictionary or set (``{}``)
+
+    Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if
+    the count ``== 0``. You can craft much more complex query that could,
+    for instance, check that the table has the same number of rows as
+    the source table upstream, or that the count of today's partition is
+    greater than yesterday's partition, or that a set of metrics are less
+    than 3 standard deviation for the 7 day average.
+
+    This operator can be used as a data quality check in your pipeline, and
+    depending on where you put it in your DAG, you have the choice to
+    stop the critical path, preventing from
+    publishing dubious data, or on the side and receive email alerts
+    without stopping the progress of the DAG.
+
+    :param qubole_conn_id: Connection id which consists of qds auth_token
+    :type qubole_conn_id: str
+
+    kwargs:
+
+        Arguments specific to Qubole command can be referred from QuboleOperator docs.
+
+        :results_parser_callable: This is an optional parameter to
+            extend the flexibility of parsing the results of Qubole
+            command to the users. This is a python callable which
+            can hold the logic to parse list of rows returned by Qubole command.
+            By default, only the values on first row are used for performing checks.
+            This callable should return a list of records on
+            which the checks have to be performed.
+
+    .. note:: All fields in common with template fields of
+            QuboleOperator and CheckOperator are template-supported.
+    """
+
+    template_fields = QuboleOperator.template_fields + CheckOperator.template_fields
+    template_ext = QuboleOperator.template_ext
+    ui_fgcolor = '#000'
+
+    @apply_defaults
+    def __init__(self, qubole_conn_id="qubole_default", *args, **kwargs):
+        sql = get_sql_from_qbol_cmd(kwargs)
+        super(QuboleCheckOperator, self)\
+            .__init__(qubole_conn_id=qubole_conn_id, sql=sql, *args, **kwargs)
+        self.on_failure_callback = QuboleCheckHook.handle_failure_retry
+        self.on_retry_callback = QuboleCheckHook.handle_failure_retry
+
+    def execute(self, context=None):
+        try:
+            self.hook = self.get_hook(context=context)
+            super(QuboleCheckOperator, self).execute(context=context)
+        except AirflowException as e:
+            handle_airflow_exception(e, self.get_hook())
+
+    def get_db_hook(self):
+        return self.get_hook()
+
+    def get_hook(self, context=None):
+        if hasattr(self, 'hook') and (self.hook is not None):
+            return self.hook
+        else:
+            return QuboleCheckHook(context=context, *self.args, **self.kwargs)
+
+    def __getattribute__(self, name):
+        if name in QuboleCheckOperator.template_fields:
+            if name in self.kwargs:
+                return self.kwargs[name]
+            else:
+                return ''
+        else:
+            return object.__getattribute__(self, name)
+
+    def __setattr__(self, name, value):
+        if name in QuboleCheckOperator.template_fields:
+            self.kwargs[name] = value
+        else:
+            object.__setattr__(self, name, value)
+
+
+class QuboleValueCheckOperator(ValueCheckOperator, QuboleOperator):
+    """
+    Performs a simple value check using Qubole command.
+    By default, each value on the first row of this
+    Qubole command is compared with a pre-defined value.
+    The check fails and errors out if the output of the command
+    is not within the permissible limit of expected value.
+
+    :param qubole_conn_id: Connection id which consists of qds auth_token
+    :type qubole_conn_id: str
+
+    :param pass_value: Expected value of the query results.
+    :type pass_value: str/int/float
+
+    :param tolerance: Defines the permissible pass_value range, for example if
+        tolerance is 2, the Qubole command output can be anything between
+        -2*pass_value and 2*pass_value, without the operator erring out.
+
+    :type tolerance: int/float
+
+
+    kwargs:
+
+        Arguments specific to Qubole command can be referred from QuboleOperator docs.
+
+        :results_parser_callable: This is an optional parameter to
+            extend the flexibility of parsing the results of Qubole
+            command to the users. This is a python callable which
+            can hold the logic to parse list of rows returned by Qubole command.
+            By default, only the values on first row are used for performing checks.
+            This callable should return a list of records on
+            which the checks have to be performed.
+
+
+    .. note:: All fields in common with template fields of
+            QuboleOperator and ValueCheckOperator are template-supported.
+    """
+
+    template_fields = QuboleOperator.template_fields + ValueCheckOperator.template_fields
+    template_ext = QuboleOperator.template_ext
+    ui_fgcolor = '#000'
+
+    @apply_defaults
+    def __init__(self, pass_value, tolerance=None,
+                 qubole_conn_id="qubole_default", *args, **kwargs):
+
+        sql = get_sql_from_qbol_cmd(kwargs)
+        super(QuboleValueCheckOperator, self).__init__(
+            qubole_conn_id=qubole_conn_id,
+            sql=sql, pass_value=pass_value, tolerance=tolerance,
+            *args, **kwargs)
+
+        self.on_failure_callback = QuboleCheckHook.handle_failure_retry
+        self.on_retry_callback = QuboleCheckHook.handle_failure_retry
+
+    def execute(self, context=None):
+        try:
+            self.hook = self.get_hook(context=context)
+            super(QuboleValueCheckOperator, self).execute(context=context)
+        except AirflowException as e:
+            handle_airflow_exception(e, self.get_hook())
+
+    def get_db_hook(self):
+        return self.get_hook()
+
+    def get_hook(self, context=None):
+        if hasattr(self, 'hook') and (self.hook is not None):
+            return self.hook
+        else:
+            return QuboleCheckHook(context=context, *self.args, **self.kwargs)
+
+    def __getattribute__(self, name):
+        if name in QuboleValueCheckOperator.template_fields:
+            if name in self.kwargs:
+                return self.kwargs[name]
+            else:
+                return ''
+        else:
+            return object.__getattribute__(self, name)
+
+    def __setattr__(self, name, value):
+        if name in QuboleValueCheckOperator.template_fields:
+            self.kwargs[name] = value
+        else:
+            object.__setattr__(self, name, value)
+
+
+def get_sql_from_qbol_cmd(params):
+    sql = ''
+    if 'query' in params:
+        sql = params['query']
+    elif 'sql' in params:
+        sql = params['sql']
+    return sql
+
+
+def handle_airflow_exception(airflow_exception, hook):
+    cmd = hook.cmd
+    if cmd is not None:
+        if cmd.is_success:
+            qubole_command_results = hook.get_query_results()
+            qubole_command_id = cmd.id
+            exception_message = '\nQubole Command Id: {qubole_command_id}' \
+                                '\nQubole Command Results:' \
+                                '\n{qubole_command_results}'.format(**locals())
+            raise AirflowException(str(airflow_exception) + exception_message)
+    raise AirflowException(airflow_exception.message)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c1d583f9/docs/code.rst
----------------------------------------------------------------------
diff --git a/docs/code.rst b/docs/code.rst
index 53c9313..857bf67 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -170,6 +170,8 @@ Operators
 .. autoclass:: airflow.contrib.operators.pubsub_operator.PubSubSubscriptionCreateOperator
 .. autoclass:: airflow.contrib.operators.pubsub_operator.PubSubSubscriptionDeleteOperator
 .. autoclass:: airflow.contrib.operators.pubsub_operator.PubSubPublishOperator
+.. autoclass:: airflow.contrib.operators.qubole_check_operator.QuboleCheckOperator
+.. autoclass:: airflow.contrib.operators.qubole_check_operator.QuboleValueCheckOperator
 .. autoclass:: airflow.contrib.operators.qubole_operator.QuboleOperator
 .. autoclass:: airflow.contrib.operators.s3_list_operator.S3ListOperator
 .. autoclass:: airflow.contrib.operators.s3_to_gcs_operator.S3ToGoogleCloudStorageOperator

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c1d583f9/tests/contrib/hooks/test_qubole_check_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_qubole_check_hook.py b/tests/contrib/hooks/test_qubole_check_hook.py
new file mode 100644
index 0000000..150eb45
--- /dev/null
+++ b/tests/contrib/hooks/test_qubole_check_hook.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import unittest
+from airflow.contrib.hooks.qubole_check_hook import parse_first_row
+
+
+class QuboleCheckHookTest(unittest.TestCase):
+    def test_single_row_bool(self):
+        query_result = ['true\ttrue']
+        record_list = parse_first_row(query_result)
+        self.assertEqual([True, True], record_list)
+
+    def test_multi_row_bool(self):
+        query_result = ['true\tfalse', 'true\tfalse']
+        record_list = parse_first_row(query_result)
+        self.assertEqual([True, False], record_list)
+
+    def test_single_row_float(self):
+        query_result = ['0.23\t34']
+        record_list = parse_first_row(query_result)
+        self.assertEqual([0.23, 34], record_list)
+
+    def test_single_row_mixed_types(self):
+        query_result = ['name\t44\t0.23\tTrue']
+        record_list = parse_first_row(query_result)
+        self.assertEqual(["name", 44, 0.23, True], record_list)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c1d583f9/tests/contrib/operators/test_qubole_check_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_qubole_check_operator.py b/tests/contrib/operators/test_qubole_check_operator.py
new file mode 100644
index 0000000..2904482
--- /dev/null
+++ b/tests/contrib/operators/test_qubole_check_operator.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import unittest
+from datetime import datetime
+from airflow.models import DAG
+from airflow.exceptions import AirflowException
+from airflow.contrib.operators.qubole_check_operator import QuboleValueCheckOperator
+from airflow.contrib.hooks.qubole_check_hook import QuboleCheckHook
+from airflow.contrib.hooks.qubole_hook import QuboleHook
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
+class QuboleValueCheckOperatorTest(unittest.TestCase):
+
+    def setUp(self):
+        self.task_id = 'test_task'
+        self.conn_id = 'default_conn'
+
+    def __construct_operator(self, query, pass_value, tolerance=None,
+                             results_parser_callable=None):
+
+        dag = DAG('test_dag', start_date=datetime(2017, 1, 1))
+
+        return QuboleValueCheckOperator(
+            dag=dag,
+            task_id=self.task_id,
+            conn_id=self.conn_id,
+            query=query,
+            pass_value=pass_value,
+            results_parser_callable=results_parser_callable,
+            command_type='hivecmd',
+            tolerance=tolerance)
+
+    def test_pass_value_template(self):
+        pass_value_str = "2018-03-22"
+        operator = self.__construct_operator('select date from tab1;', "{{ ds }}")
+        result = operator.render_template('pass_value', operator.pass_value,
+                                          {'ds': pass_value_str})
+
+        self.assertEqual(operator.task_id, self.task_id)
+        self.assertEqual(result, pass_value_str)
+
+    @mock.patch.object(QuboleValueCheckOperator, 'get_hook')
+    def test_execute_pass(self, mock_get_hook):
+
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [10]
+        mock_get_hook.return_value = mock_hook
+
+        query = 'select value from tab1 limit 1;'
+
+        operator = self.__construct_operator(query, 5, 1)
+
+        operator.execute(None)
+
+        mock_hook.get_first.assert_called_with(query)
+
+    @mock.patch.object(QuboleValueCheckOperator, 'get_hook')
+    def test_execute_fail(self, mock_get_hook):
+
+        mock_cmd = mock.Mock()
+        mock_cmd.status = 'done'
+        mock_cmd.id = 123
+
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [11]
+        mock_hook.cmd = mock_cmd
+        mock_get_hook.return_value = mock_hook
+
+        operator = self.__construct_operator('select value from tab1 limit 1;', 5, 1)
+
+        with self.assertRaisesRegexp(AirflowException,
+                                     'Qubole Command Id: ' + str(mock_cmd.id)):
+            operator.execute()
+
+    @mock.patch.object(QuboleCheckHook, 'get_query_results')
+    @mock.patch.object(QuboleHook, 'execute')
+    def test_results_parser_callable(self, mock_execute, mock_get_query_results):
+
+        mock_execute.return_value = None
+
+        pass_value = 'pass_value'
+        mock_get_query_results.return_value = pass_value
+
+        results_parser_callable = mock.Mock()
+        results_parser_callable.return_value = [pass_value]
+
+        operator = self.__construct_operator('select value from tab1 limit 1;',
+                                             pass_value, None, results_parser_callable)
+        operator.execute()
+        results_parser_callable.assert_called_once_with([pass_value])

Reply via email to