Repository: incubator-airflow
Updated Branches:
  refs/heads/master b9cb54f87 -> 4d153ad4e


[AIRFLOW-2627] Add a sensor for Cassandra

Closes #3510 from sekikn/AIRFLOW-2627


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/4d153ad4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/4d153ad4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/4d153ad4

Branch: refs/heads/master
Commit: 4d153ad4e85a945defca618d8fc3dc22b8535f93
Parents: b9cb54f
Author: Kengo Seki <[email protected]>
Authored: Sun Jun 17 19:10:48 2018 +0100
Committer: Kaxil Naik <[email protected]>
Committed: Sun Jun 17 19:10:48 2018 +0100

----------------------------------------------------------------------
 .travis.yml                                    |  1 +
 airflow/contrib/hooks/cassandra_hook.py        | 28 +++++++++-
 airflow/contrib/sensors/cassandra_sensor.py    | 60 +++++++++++++++++++++
 docs/code.rst                                  |  1 +
 tests/contrib/hooks/test_cassandra_hook.py     | 20 +++++++
 tests/contrib/sensors/test_cassandra_sensor.py | 58 ++++++++++++++++++++
 6 files changed, 166 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4d153ad4/.travis.yml
----------------------------------------------------------------------
diff --git a/.travis.yml b/.travis.yml
index cbec39a..01c08d9 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -22,6 +22,7 @@ language: python
 jdk:
   - oraclejdk8
 services:
+  - cassandra
   - mysql
   - postgresql
   - rabbitmq

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4d153ad4/airflow/contrib/hooks/cassandra_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/cassandra_hook.py b/airflow/contrib/hooks/cassandra_hook.py
index 6fce98b..704ba0d 100644
--- a/airflow/contrib/hooks/cassandra_hook.py
+++ b/airflow/contrib/hooks/cassandra_hook.py
@@ -106,9 +106,10 @@ class CassandraHook(BaseHook, LoggingMixin):
         """
         Returns a cassandra Session object
         """
-        if self.session:
+        if self.session and not self.session.is_shutdown:
             return self.session
-        return self.cluster.connect(self.keyspace)
+        self.session = self.cluster.connect(self.keyspace)
+        return self.session
 
     def get_cluster(self):
         return self.cluster
@@ -156,3 +157,26 @@ class CassandraHook(BaseHook, LoggingMixin):
                 child_policy = CassandraHook.get_lb_policy(child_policy_name,
                                                            child_policy_args)
                 return TokenAwarePolicy(child_policy)
+
+    def record_exists(self, table, keys):
+        """
+        Checks if a record exists in Cassandra
+
+        :param table: Target Cassandra table.
+                      Use dot notation to target a specific keyspace.
+        :type table: string
+        :param keys: The keys and their values to check the existence.
+        :type keys: dict
+        """
+        keyspace = None
+        if '.' in table:
+            keyspace, table = table.split('.', 1)
+        ks = " AND ".join("{}=%({})s".format(key, key) for key in keys.keys())
+        cql = "SELECT * FROM {keyspace}.{table} WHERE {keys}".format(
+            keyspace=(keyspace or self.keyspace), table=table, keys=ks)
+
+        try:
+            rs = self.get_conn().execute(cql, keys)
+            return rs.one() is not None
+        except Exception:
+            return False

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4d153ad4/airflow/contrib/sensors/cassandra_sensor.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/sensors/cassandra_sensor.py b/airflow/contrib/sensors/cassandra_sensor.py
new file mode 100644
index 0000000..aef6612
--- /dev/null
+++ b/airflow/contrib/sensors/cassandra_sensor.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from airflow.contrib.hooks.cassandra_hook import CassandraHook
+from airflow.sensors.base_sensor_operator import BaseSensorOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class CassandraRecordSensor(BaseSensorOperator):
+    """
+    Checks for the existence of a record in a Cassandra cluster.
+
+    For example, if you want to wait for a record that has values 'v1' and 
'v2' for each
+    primary keys 'p1' and 'p2' to be populated in keyspace 'k' and table 't',
+    instantiate it as follows:
+
+    >>> CassandraRecordSensor(table="k.t", keys={"p1": "v1", "p2": "v2"},
+    ...     cassandra_conn_id="cassandra_default", task_id="cassandra_sensor")
+    <Task(CassandraRecordSensor): cassandra_sensor>
+    """
+    template_fields = ('table', 'keys')
+
+    @apply_defaults
+    def __init__(self, table, keys, cassandra_conn_id, *args, **kwargs):
+        """
+        Create a new CassandraRecordSensor
+
+        :param table: Target Cassandra table.
+                      Use dot notation to target a specific keyspace.
+        :type table: string
+        :param keys: The keys and their values to be monitored
+        :type keys: dict
+        :param cassandra_conn_id: The connection ID to use
+                                  when connecting to Cassandra cluster
+        :type cassandra_conn_id: string
+        """
+        super(CassandraRecordSensor, self).__init__(*args, **kwargs)
+        self.cassandra_conn_id = cassandra_conn_id
+        self.table = table
+        self.keys = keys
+
+    def poke(self, context):
+        self.log.info('Sensor check existence of record: %s', self.keys)
+        hook = CassandraHook(self.cassandra_conn_id)
+        return hook.record_exists(self.table, self.keys)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4d153ad4/docs/code.rst
----------------------------------------------------------------------
diff --git a/docs/code.rst b/docs/code.rst
index 3b51484..6b3a84a 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -198,6 +198,7 @@ Sensors
 .. autoclass:: airflow.contrib.sensors.aws_redshift_cluster_sensor.AwsRedshiftClusterSensor
 .. autoclass:: airflow.contrib.sensors.bash_sensor.BashSensor
 .. autoclass:: airflow.contrib.sensors.bigquery_sensor.BigQueryTableSensor
+.. autoclass:: airflow.contrib.sensors.cassandra_sensor.CassandraRecordSensor
 .. autoclass:: airflow.contrib.sensors.datadog_sensor.DatadogSensor
 .. autoclass:: airflow.contrib.sensors.emr_base_sensor.EmrBaseSensor
 .. autoclass:: airflow.contrib.sensors.emr_job_flow_sensor.EmrJobFlowSensor

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4d153ad4/tests/contrib/hooks/test_cassandra_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_cassandra_hook.py b/tests/contrib/hooks/test_cassandra_hook.py
index fd9e93c..e420ec0 100644
--- a/tests/contrib/hooks/test_cassandra_hook.py
+++ b/tests/contrib/hooks/test_cassandra_hook.py
@@ -117,6 +117,26 @@ class CassandraHookTest(unittest.TestCase):
             thrown = True
         self.assertEqual(should_throw, thrown)
 
+    def test_record_exists(self):
+        hook = CassandraHook()
+        session = hook.get_conn()
+
+        cqls = [
+            "DROP SCHEMA IF EXISTS s",
+            """
+                CREATE SCHEMA s WITH REPLICATION =
+                    { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }
+            """,
+            "DROP TABLE IF EXISTS s.t",
+            "CREATE TABLE s.t (pk1 text, pk2 text, c text, PRIMARY KEY (pk1, 
pk2))",
+            "INSERT INTO s.t (pk1, pk2, c) VALUES ('foo', 'bar', 'baz')",
+        ]
+        for cql in cqls:
+            session.execute(cql)
+
+        self.assertTrue(hook.record_exists("s.t", {"pk1": "foo", "pk2": 
"bar"}))
+        self.assertFalse(hook.record_exists("s.t", {"pk1": "foo", "pk2": 
"baz"}))
+
 
 if __name__ == '__main__':
     unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4d153ad4/tests/contrib/sensors/test_cassandra_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/test_cassandra_sensor.py b/tests/contrib/sensors/test_cassandra_sensor.py
new file mode 100644
index 0000000..0f0e7f5
--- /dev/null
+++ b/tests/contrib/sensors/test_cassandra_sensor.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import unittest
+
+from mock import patch
+
+from airflow import DAG
+from airflow import configuration
+from airflow.contrib.sensors.cassandra_sensor import CassandraRecordSensor
+from airflow.utils import timezone
+
+
+DEFAULT_DATE = timezone.datetime(2017, 1, 1)
+
+
+class TestCassandraRecordSensor(unittest.TestCase):
+
+    def setUp(self):
+        configuration.load_test_config()
+        args = {
+            'owner': 'airflow',
+            'start_date': DEFAULT_DATE
+        }
+        self.dag = DAG('test_dag_id', default_args=args)
+        self.sensor = CassandraRecordSensor(
+            task_id='test_task',
+            cassandra_conn_id='cassandra_default',
+            dag=self.dag,
+            table='t',
+            keys={'foo': 'bar'}
+        )
+
+    @patch("airflow.contrib.hooks.cassandra_hook.CassandraHook.record_exists")
+    def test_poke(self, mock_record_exists):
+        self.sensor.poke(None)
+        mock_record_exists.assert_called_once_with('t', {'foo': 'bar'})
+
+
+if __name__ == '__main__':
+    unittest.main()

Reply via email to