Repository: incubator-airflow
Updated Branches:
  refs/heads/master 8873a8df8 -> f5115b7e6


[AIRFLOW-2458] Add cassandra-to-gcs operator

Closes #3354 from jgao54/cassandra-to-gcs


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/f5115b7e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/f5115b7e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/f5115b7e

Branch: refs/heads/master
Commit: f5115b7e6a105e6baedd8efa9b4d4afc12ee880d
Parents: 8873a8d
Author: Joy Gao <joy...@apache.org>
Authored: Fri May 18 02:01:41 2018 +0100
Committer: Kaxil Naik <kaxiln...@apache.org>
Committed: Fri May 18 02:02:57 2018 +0100

----------------------------------------------------------------------
 airflow/contrib/hooks/cassandra_hook.py         |  88 +++++
 airflow/contrib/operators/cassandra_to_gcs.py   | 351 +++++++++++++++++++
 airflow/models.py                               |   4 +
 airflow/utils/db.py                             |   4 +
 docs/code.rst                                   |   2 +
 setup.py                                        |   7 +-
 tests/contrib/hooks/test_cassandra_hook.py      |  56 +++
 .../operators/test_cassandra_to_gcs_operator.py |  92 +++++
 8 files changed, 602 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/airflow/contrib/hooks/cassandra_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/cassandra_hook.py b/airflow/contrib/hooks/cassandra_hook.py
new file mode 100644
index 0000000..90046a8
--- /dev/null
+++ b/airflow/contrib/hooks/cassandra_hook.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from cassandra.cluster import Cluster
+from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
+                                TokenAwarePolicy, HostFilterPolicy,
+                                WhiteListRoundRobinPolicy)
+from cassandra.auth import PlainTextAuthProvider
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+class CassandraHook(BaseHook, LoggingMixin):
+    """
+    Hook used to interact with Cassandra
+
+    Contact points can be specified as a comma-separated string in the 'host'
+    field of the connection. Port can be specified in the port field of the
+    connection and the keyspace in the schema field. load_balancing_policy,
+    load_balancing_policy_args, cql_version and ssl_options can be specified
+    in the extra field of the connection, for example::
+
+        {"load_balancing_policy": "TokenAwarePolicy",
+         "load_balancing_policy_args": {
+             "child_load_balancing_policy": "DCAwareRoundRobinPolicy",
+             "child_load_balancing_policy_args": {"local_dc": "dc1"}}}
+
+    For details of the Cluster config, see cassandra.cluster.
+    """
+    def __init__(self, cassandra_conn_id='cassandra_default'):
+        conn = self.get_connection(cassandra_conn_id)
+        extra = conn.extra_dejson
+
+        conn_config = {}
+        if conn.host:
+            # Strip whitespace so "host-1, host-2" works as well.
+            conn_config['contact_points'] = [
+                host.strip() for host in conn.host.split(',')]
+
+        if conn.port:
+            conn_config['port'] = int(conn.port)
+
+        if conn.login:
+            conn_config['auth_provider'] = PlainTextAuthProvider(
+                username=conn.login, password=conn.password)
+
+        # The driver's Cluster expects a policy instance, not a policy class.
+        policy_name = extra.get('load_balancing_policy', None)
+        if policy_name:
+            lb_policy = self.get_lb_policy(
+                policy_name, extra.get('load_balancing_policy_args', {}))
+            if lb_policy is None:
+                self.log.warning('Unknown load_balancing_policy %s, using the '
+                                 'driver default', policy_name)
+            else:
+                conn_config['load_balancing_policy'] = lb_policy
+
+        cql_version = extra.get('cql_version', None)
+        if cql_version:
+            conn_config['cql_version'] = cql_version
+
+        ssl_options = extra.get('ssl_options', None)
+        if ssl_options:
+            conn_config['ssl_options'] = ssl_options
+
+        self.cluster = Cluster(**conn_config)
+        self.keyspace = conn.schema
+
+    def get_conn(self):
+        """
+        Returns a cassandra Session, using the connection's schema as keyspace.
+        """
+        return self.cluster.connect(self.keyspace)
+
+    def get_cluster(self):
+        """
+        Returns the cassandra.cluster.Cluster built from the connection.
+        """
+        return self.cluster
+
+    def shutdown_cluster(self):
+        """
+        Shuts down the cluster, closing its sessions and connections.
+        """
+        self.cluster.shutdown()
+
+    @classmethod
+    def get_policy(cls, policy_name):
+        """
+        Returns the policy class registered under policy_name, or None.
+
+        The result is a class, not an instance, so it cannot be passed to
+        Cluster directly; use get_lb_policy to build a usable policy.
+        """
+        policies = {
+            'RoundRobinPolicy': RoundRobinPolicy,
+            'DCAwareRoundRobinPolicy': DCAwareRoundRobinPolicy,
+            'TokenAwarePolicy': TokenAwarePolicy,
+            'HostFilterPolicy': HostFilterPolicy,
+            'WhiteListRoundRobinPolicy': WhiteListRoundRobinPolicy,
+        }
+        return policies.get(policy_name)
+
+    @classmethod
+    def get_lb_policy(cls, policy_name, policy_args=None):
+        """
+        Builds a load balancing policy instance from its name and settings.
+
+        Supported names and their optional policy_args keys:
+
+        * RoundRobinPolicy: none
+        * DCAwareRoundRobinPolicy: local_dc, used_hosts_per_remote_dc
+        * WhiteListRoundRobinPolicy: hosts (required)
+        * TokenAwarePolicy: child_load_balancing_policy (default and fallback
+          RoundRobinPolicy), child_load_balancing_policy_args
+
+        :param policy_name: name of a policy class in cassandra.policies
+        :type policy_name: string
+        :param policy_args: settings for the policy, see above
+        :type policy_args: dict
+        :return: a policy instance, or None if policy_name is unknown
+        :raises AirflowException: if the policy cannot be built from the
+            given settings
+        """
+        policy_args = policy_args or {}
+
+        if policy_name == 'RoundRobinPolicy':
+            return RoundRobinPolicy()
+
+        if policy_name == 'DCAwareRoundRobinPolicy':
+            return DCAwareRoundRobinPolicy(
+                local_dc=policy_args.get('local_dc', ''),
+                used_hosts_per_remote_dc=int(
+                    policy_args.get('used_hosts_per_remote_dc', 0)))
+
+        if policy_name == 'WhiteListRoundRobinPolicy':
+            hosts = policy_args.get('hosts')
+            if not hosts:
+                raise AirflowException(
+                    'hosts must be specified for WhiteListRoundRobinPolicy')
+            return WhiteListRoundRobinPolicy(hosts)
+
+        if policy_name == 'TokenAwarePolicy':
+            child_name = policy_args.get('child_load_balancing_policy',
+                                         'RoundRobinPolicy')
+            child_policy = cls.get_lb_policy(
+                child_name,
+                policy_args.get('child_load_balancing_policy_args', {}))
+            if child_policy is None:
+                child_policy = RoundRobinPolicy()
+            return TokenAwarePolicy(child_policy)
+
+        if policy_name == 'HostFilterPolicy':
+            # Its predicate is a Python callable, which JSON extras cannot hold.
+            raise AirflowException(
+                'HostFilterPolicy cannot be configured from connection extras')
+
+        return None

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/airflow/contrib/operators/cassandra_to_gcs.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cassandra_to_gcs.py b/airflow/contrib/operators/cassandra_to_gcs.py
new file mode 100644
index 0000000..b4e216d
--- /dev/null
+++ b/airflow/contrib/operators/cassandra_to_gcs.py
@@ -0,0 +1,404 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import unicode_literals
+
+import json
+from builtins import str
+from base64 import b64encode
+from cassandra.util import Date, Time, SortedSet, OrderedMapSerializedKey
+from datetime import datetime
+from decimal import Decimal
+from six import text_type, binary_type, integer_types, PY3
+from tempfile import NamedTemporaryFile
+from uuid import UUID
+
+from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
+from airflow.contrib.hooks.cassandra_hook import CassandraHook
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class CassandraToGoogleCloudStorageOperator(BaseOperator):
+    """
+    Copy data from Cassandra to Google cloud storage in JSON format
+
+    Each row is written as one JSON object per line (newline-delimited JSON),
+    so the output can be loaded into BigQuery. Optionally a BigQuery schema
+    file describing the result columns is uploaded as well.
+
+    Note: Arrays of arrays are not supported.
+    """
+    template_fields = ('cql', 'bucket', 'filename', 'schema_filename',)
+    template_ext = ('.cql',)
+    ui_color = '#a0e08c'
+
+    @apply_defaults
+    def __init__(self,
+                 cql,
+                 bucket,
+                 filename,
+                 schema_filename=None,
+                 approx_max_file_size_bytes=1900000000,
+                 cassandra_conn_id='cassandra_default',
+                 google_cloud_storage_conn_id='google_cloud_default',
+                 delegate_to=None,
+                 *args,
+                 **kwargs):
+        """
+        :param cql: The CQL to execute on the Cassandra table.
+        :type cql: string
+        :param bucket: The bucket to upload to.
+        :type bucket: string
+        :param filename: The filename to use as the object name when uploading
+            to Google cloud storage. A {} should be specified in the filename
+            to allow the operator to inject file numbers in cases where the
+            file is split due to size.
+        :type filename: string
+        :param schema_filename: If set, the filename to use as the object name
+            when uploading a .json file containing the BigQuery schema fields
+            for the table that was dumped from Cassandra.
+        :type schema_filename: string
+        :param approx_max_file_size_bytes: This operator supports the ability
+            to split large table dumps into multiple files (see notes in the
+            filename param docs above). A new file is started once the
+            current one reaches this size, so files may exceed it by up to
+            one row. Keep it below the per-file limits of whatever consumes
+            the files (e.g. BigQuery load jobs).
+        :type approx_max_file_size_bytes: long
+        :param cassandra_conn_id: Reference to a specific Cassandra hook.
+        :type cassandra_conn_id: string
+        :param google_cloud_storage_conn_id: Reference to a specific Google
+            cloud storage hook.
+        :type google_cloud_storage_conn_id: string
+        :param delegate_to: The account to impersonate, if any. For this to
+            work, the service account making the request must have domain-wide
+            delegation enabled.
+        :type delegate_to: string
+        """
+        super(CassandraToGoogleCloudStorageOperator, self).__init__(*args, **kwargs)
+        self.cql = cql
+        self.bucket = bucket
+        self.filename = filename
+        self.schema_filename = schema_filename
+        self.approx_max_file_size_bytes = approx_max_file_size_bytes
+        self.cassandra_conn_id = cassandra_conn_id
+        self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
+        self.delegate_to = delegate_to
+        # Set by _query_cassandra so execute() can shut the cluster down.
+        self.hook = None
+
+    # Default Cassandra to BigQuery type mapping
+    CQL_TYPE_MAP = {
+        'BytesType': 'BYTES',
+        'DecimalType': 'FLOAT',
+        'UUIDType': 'STRING',
+        'BooleanType': 'BOOL',
+        'ByteType': 'INTEGER',
+        'AsciiType': 'STRING',
+        'FloatType': 'FLOAT',
+        'DoubleType': 'FLOAT',
+        'LongType': 'INTEGER',
+        'Int32Type': 'INTEGER',
+        'IntegerType': 'INTEGER',
+        'InetAddressType': 'STRING',
+        'CounterColumnType': 'INTEGER',
+        'DateType': 'TIMESTAMP',
+        'SimpleDateType': 'DATE',
+        'TimestampType': 'TIMESTAMP',
+        # timeuuid values are UUIDs, which convert_value renders as strings.
+        'TimeUUIDType': 'STRING',
+        'ShortType': 'INTEGER',
+        'TimeType': 'TIME',
+        # NOTE(review): convert_value has no case for duration values;
+        # confirm how the driver returns them before relying on this.
+        'DurationType': 'INTEGER',
+        'UTF8Type': 'STRING',
+        'VarcharType': 'STRING',
+    }
+
+    def execute(self, context):
+        """
+        Runs the CQL, writes the rows (and optionally the BigQuery schema) to
+        temporary files and uploads them to GCS.
+        """
+        cursor = self._query_cassandra()
+        files_to_upload = {}
+        try:
+            files_to_upload = self._write_local_data_files(cursor)
+
+            # If a schema is set, create a BQ schema JSON file.
+            if self.schema_filename:
+                files_to_upload.update(self._write_local_schema_file(cursor))
+
+            # Flush all files before uploading
+            for file_handle in files_to_upload.values():
+                file_handle.flush()
+
+            self._upload_to_gcs(files_to_upload)
+        finally:
+            # Close (and thereby delete) the temp files even if a step failed.
+            for file_handle in files_to_upload.values():
+                file_handle.close()
+            # Paging through the results needs the session, so the cluster
+            # is only shut down once the files are written or a step failed.
+            if self.hook is not None:
+                self.hook.get_cluster().shutdown()
+
+    def _query_cassandra(self):
+        """
+        Queries cassandra and returns a cursor to the results.
+
+        The hook is kept on self.hook so execute() can shut the cluster down
+        after the (possibly paged) results have been read.
+        """
+        self.hook = CassandraHook(cassandra_conn_id=self.cassandra_conn_id)
+        session = self.hook.get_conn()
+        cursor = session.execute(self.cql)
+        return cursor
+
+    def _write_local_data_files(self, cursor):
+        """
+        Takes a cursor, and writes results to a local file.
+
+        A new file is started before writing a row whenever the current file
+        has reached approx_max_file_size_bytes, so no trailing empty file is
+        produced when the last row fills a file.
+
+        :return: A dictionary where keys are filenames to be used as object
+            names in GCS, and values are file handles to local files that
+            contain the data for the GCS objects.
+        """
+        file_no = 0
+        tmp_file_handle = NamedTemporaryFile(delete=True)
+        tmp_file_handles = {self.filename.format(file_no): tmp_file_handle}
+        for row in cursor:
+            if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
+                file_no += 1
+                tmp_file_handle = NamedTemporaryFile(delete=True)
+                tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle
+
+            row_dict = self.generate_data_dict(row._fields, row)
+            s = json.dumps(row_dict)
+            if PY3:
+                s = s.encode('utf-8')
+            tmp_file_handle.write(s)
+
+            # Append newline to make dumps BigQuery compatible.
+            tmp_file_handle.write(b'\n')
+
+        return tmp_file_handles
+
+    def _write_local_schema_file(self, cursor):
+        """
+        Takes a cursor, and writes the BigQuery schema for the results to a
+        local file system.
+
+        :return: A dictionary where key is a filename to be used as an object
+            name in GCS, and values are file handles to local files that
+            contain the BigQuery schema fields in .json format.
+        """
+        schema = [self.generate_schema_dict(name, cql_type)
+                  for name, cql_type in zip(cursor.column_names,
+                                            cursor.column_types)]
+        json_serialized_schema = json.dumps(schema)
+        if PY3:
+            json_serialized_schema = json_serialized_schema.encode('utf-8')
+
+        tmp_schema_file_handle = NamedTemporaryFile(delete=True)
+        tmp_schema_file_handle.write(json_serialized_schema)
+        return {self.schema_filename: tmp_schema_file_handle}
+
+    def _upload_to_gcs(self, files_to_upload):
+        """
+        Uploads each local file to self.bucket under its object name.
+        """
+        hook = GoogleCloudStorageHook(
+            google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
+            delegate_to=self.delegate_to)
+        for object_name, tmp_file_handle in files_to_upload.items():
+            hook.upload(self.bucket, object_name, tmp_file_handle.name,
+                        'application/json')
+
+    @classmethod
+    def generate_data_dict(cls, names, values):
+        """
+        Builds a dict mapping each name to its value converted by
+        convert_value.
+        """
+        return dict((name, cls.convert_value(name, value))
+                    for name, value in zip(names, values))
+
+    @classmethod
+    def convert_value(cls, name, value):
+        """
+        Converts a value returned by the Cassandra driver into something
+        json.dumps can serialize and BigQuery can load.
+
+        None is returned as-is; other falsy values such as Decimal(0), b''
+        or an empty set/map are still converted, so the result stays
+        JSON-serializable.
+        :raises AirflowException: if the value's type is not supported.
+        """
+        if value is None:
+            return value
+        elif isinstance(value, (text_type, bool, float, dict) + integer_types):
+            return value
+        elif isinstance(value, binary_type):
+            # Decode on Python 2 too, so the result is always text rather
+            # than a byte string that could be base64-encoded again.
+            return b64encode(value).decode('ascii')
+        elif isinstance(value, (datetime, Date, UUID)):
+            return str(value)
+        elif isinstance(value, Decimal):
+            return float(value)
+        elif isinstance(value, Time):
+            return str(value).split('.')[0]
+        elif isinstance(value, (list, SortedSet)):
+            return cls.convert_array_types(name, value)
+        elif hasattr(value, '_fields'):
+            return cls.convert_user_type(name, value)
+        elif isinstance(value, tuple):
+            return cls.convert_tuple_type(name, value)
+        elif isinstance(value, OrderedMapSerializedKey):
+            return cls.convert_map_type(name, value)
+        else:
+            raise AirflowException('unexpected value: ' + str(value))
+
+    @classmethod
+    def convert_array_types(cls, name, value):
+        """
+        Converts each element of a list or set with convert_value.
+        """
+        return [cls.convert_value(name, nested_value) for nested_value in value]
+
+    @classmethod
+    def convert_user_type(cls, name, value):
+        """
+        Converts a user type to RECORD that contains n fields, where n is the
+        number of attributes. Each element in the user type class will be
+        converted to its corresponding data type in BQ.
+        """
+        # Build the dict directly so each attribute is converted exactly once.
+        return dict((field, cls.convert_value(field, getattr(value, field)))
+                    for field in value._fields)
+
+    @classmethod
+    def convert_tuple_type(cls, name, value):
+        """
+        Converts a tuple to RECORD that contains n fields, each will be
+        converted to its corresponding data type in bq and will be named
+        'field_<index>', where index is determined by the order of the tuple
+        elements defined in cassandra.
+        """
+        converted = {}
+        for index, element in enumerate(value):
+            field = 'field_' + str(index)
+            converted[field] = cls.convert_value(field, element)
+        return converted
+
+    @classmethod
+    def convert_map_type(cls, name, value):
+        """
+        Converts a map to a repeated RECORD that contains two fields: 'key'
+        and 'value', each will be converted to its corresponding data type
+        in BQ.
+        """
+        return [{'key': cls.convert_value('key', k),
+                 'value': cls.convert_value('value', v)}
+                for k, v in value.items()]
+
+    @classmethod
+    def generate_schema_dict(cls, name, type):
+        """
+        Builds the BigQuery schema field (name, type, mode and, for records,
+        nested fields) for a column of the given Cassandra type.
+        """
+        field_schema = {
+            'name': name,
+            'type': cls.get_bq_type(type),
+            'mode': cls.get_bq_mode(type),
+        }
+        fields = cls.get_bq_fields(name, type)
+        if fields:
+            field_schema['fields'] = fields
+        return field_schema
+
+    @classmethod
+    def get_bq_fields(cls, name, type):
+        """
+        Returns the nested schema fields of a record type, or of the record
+        elements of a list/set; returns an empty list for any other type.
+        """
+        if cls.is_simple_type(type):
+            return []
+
+        # A list/set of records is described by its element type's fields.
+        record_type = type.subtypes[0] if cls.is_array_type(type) else type
+        if not cls.is_record_type(record_type):
+            return []
+
+        types = record_type.subtypes
+        # User types expose fieldnames; tuples and maps get generated names.
+        # Inspecting record_type (not the outer list/set) means a list of
+        # tuples or maps gets named fields too.
+        names = getattr(record_type, 'fieldnames', None)
+        if types and not names:
+            if record_type.cassname == 'TupleType':
+                names = ['field_' + str(i) for i in range(len(types))]
+            elif record_type.cassname == 'MapType':
+                names = ['key', 'value']
+
+        return [cls.generate_schema_dict(field_name, field_type)
+                for field_name, field_type in zip(names or [], types)]
+
+    @classmethod
+    def is_simple_type(cls, type):
+        # Looked up through cls so subclasses can extend CQL_TYPE_MAP.
+        return type.cassname in cls.CQL_TYPE_MAP
+
+    @classmethod
+    def is_array_type(cls, type):
+        return type.cassname in ['ListType', 'SetType']
+
+    @classmethod
+    def is_record_type(cls, type):
+        return type.cassname in ['UserType', 'TupleType', 'MapType']
+
+    @classmethod
+    def get_bq_type(cls, type):
+        if cls.is_simple_type(type):
+            return cls.CQL_TYPE_MAP[type.cassname]
+        elif cls.is_record_type(type):
+            return 'RECORD'
+        elif cls.is_array_type(type):
+            return cls.get_bq_type(type.subtypes[0])
+        else:
+            raise AirflowException('Not a supported type: ' + type.cassname)
+
+    @classmethod
+    def get_bq_mode(cls, type):
+        if cls.is_array_type(type) or type.cassname == 'MapType':
+            return 'REPEATED'
+        elif cls.is_record_type(type) or cls.is_simple_type(type):
+            return 'NULLABLE'
+        else:
+            raise AirflowException('Not a supported type: ' + type.cassname)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index c9fee0c..7aab4b5 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -603,6 +603,7 @@ class Connection(Base, LoggingMixin):
         ('snowflake', 'Snowflake',),
         ('segment', 'Segment',),
         ('azure_data_lake', 'Azure Data Lake'),
+        ('cassandra', 'Cassandra',),
     ]
 
     def __init__(
@@ -753,6 +754,9 @@ class Connection(Base, LoggingMixin):
             elif self.conn_type == 'azure_data_lake':
                 from airflow.contrib.hooks.azure_data_lake_hook import 
AzureDataLakeHook
                 return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
+            elif self.conn_type == 'cassandra':
+                from airflow.contrib.hooks.cassandra_hook import CassandraHook
+                return CassandraHook(cassandra_conn_id=self.conn_id)
         except:
             pass
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/airflow/utils/db.py
----------------------------------------------------------------------
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index adda6fd..270939a 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -276,6 +276,10 @@ def initdb(rbac=False):
         models.Connection(
             conn_id='azure_data_lake_default', conn_type='azure_data_lake',
             extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }'))
+    merge_conn(
+        models.Connection(
+            conn_id='cassandra_default', conn_type='cassandra',
+            host='localhost', port=9042))
 
     # Known event types
     KET = models.KnownEventType

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/docs/code.rst
----------------------------------------------------------------------
diff --git a/docs/code.rst b/docs/code.rst
index 857bf67..1737d15 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -121,6 +121,7 @@ Operators
 .. autoclass:: airflow.contrib.operators.bigquery_table_delete_operator.BigQueryTableDeleteOperator
 .. autoclass:: airflow.contrib.operators.bigquery_to_bigquery.BigQueryToBigQueryOperator
 .. autoclass:: airflow.contrib.operators.bigquery_to_gcs.BigQueryToCloudStorageOperator
+.. autoclass:: airflow.contrib.operators.cassandra_to_gcs.CassandraToGoogleCloudStorageOperator
 .. autoclass:: airflow.contrib.operators.databricks_operator.DatabricksSubmitRunOperator
 .. autoclass:: airflow.contrib.operators.dataflow_operator.DataFlowJavaOperator
 .. autoclass:: airflow.contrib.operators.dataflow_operator.DataflowTemplateOperator
@@ -354,6 +355,7 @@ Community contributed hooks
 .. autoclass:: airflow.contrib.hooks.aws_hook.AwsHook
 .. autoclass:: airflow.contrib.hooks.aws_lambda_hook.AwsLambdaHook
 .. autoclass:: airflow.contrib.hooks.bigquery_hook.BigQueryHook
+.. autoclass:: airflow.contrib.hooks.cassandra_hook.CassandraHook
 .. autoclass:: airflow.contrib.hooks.cloudant_hook.CloudantHook
 .. autoclass:: airflow.contrib.hooks.databricks_hook.DatabricksHook
 .. autoclass:: airflow.contrib.hooks.datadog_hook.DatadogHook

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index 9813ea2..97b6883 100644
--- a/setup.py
+++ b/setup.py
@@ -114,7 +114,7 @@ azure_data_lake = [
     'azure-mgmt-datalake-store==0.4.0',
     'azure-datalake-store==0.0.19'
 ]
-sendgrid = ['sendgrid>=5.2.0']
+cassandra = ['cassandra-driver>=3.13.0']
 celery = [
     'celery>=4.0.2',
     'flower>=0.7.3'
@@ -184,6 +184,7 @@ s3 = ['boto3>=1.7.0']
 salesforce = ['simple-salesforce>=0.72']
 samba = ['pysmbclient>=0.1.3']
 segment = ['analytics-python>=1.2.9']
+sendgrid = ['sendgrid>=5.2.0']
 slack = ['slackclient>=1.0.0']
 snowflake = ['snowflake-connector-python>=1.5.2',
              'snowflake-sqlalchemy>=1.1.0']
@@ -194,7 +195,8 @@ webhdfs = ['hdfs[dataframe,avro,kerberos]>=2.0.4']
 winrm = ['pywinrm==0.2.2']
 zendesk = ['zdesk']
 
-all_dbs = postgres + mysql + hive + mssql + hdfs + vertica + cloudant + druid 
+ pinot
+all_dbs = postgres + mysql + hive + mssql + hdfs + vertica + cloudant + druid 
+ pinot \
+    + cassandra
 devel = [
     'click',
     'freezegun',
@@ -290,6 +292,7 @@ def do_setup():
             'async': async,
             'azure_blob_storage': azure_blob_storage,
             'azure_data_lake': azure_data_lake,
+            'cassandra': cassandra,
             'celery': celery,
             'cgroups': cgroups,
             'cloudant': cloudant,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/tests/contrib/hooks/test_cassandra_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_cassandra_hook.py b/tests/contrib/hooks/test_cassandra_hook.py
new file mode 100644
index 0000000..42afd9e
--- /dev/null
+++ b/tests/contrib/hooks/test_cassandra_hook.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+import mock
+
+from airflow import configuration
+from airflow.contrib.hooks.cassandra_hook import CassandraHook
+from cassandra.cluster import Cluster
+from cassandra.policies import TokenAwarePolicy
+from airflow import models
+from airflow.utils import db
+
+
+class CassandraHookTest(unittest.TestCase):
+    def setUp(self):
+        configuration.load_test_config()
+        db.merge_conn(
+            models.Connection(
+                conn_id='cassandra_test', conn_type='cassandra',
+                host='host-1,host-2', port='9042', schema='test_keyspace',
+                # Must be valid JSON: unparseable extras are ignored, and the
+                # driver's default policy could then pass the check below.
+                extra='{"load_balancing_policy":"TokenAwarePolicy"}'))
+
+    def test_get_conn(self):
+        with mock.patch.object(Cluster, "connect") as mock_connect, \
+                mock.patch("socket.getaddrinfo", return_value=[]) as mock_getaddrinfo:
+            mock_connect.return_value = 'session'
+            hook = CassandraHook(cassandra_conn_id='cassandra_test')
+            self.assertEqual(hook.get_conn(), 'session')
+            mock_getaddrinfo.assert_called()
+            mock_connect.assert_called_once_with('test_keyspace')
+
+            cluster = hook.get_cluster()
+            self.assertEqual(cluster.contact_points, ['host-1', 'host-2'])
+            self.assertEqual(cluster.port, 9042)
+            self.assertIsInstance(cluster.load_balancing_policy, TokenAwarePolicy)
+
+
+if __name__ == '__main__':
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/tests/contrib/operators/test_cassandra_to_gcs_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cassandra_to_gcs_operator.py b/tests/contrib/operators/test_cassandra_to_gcs_operator.py
new file mode 100644
index 0000000..add115f
--- /dev/null
+++ b/tests/contrib/operators/test_cassandra_to_gcs_operator.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import unicode_literals
+
+import unittest
+import mock
+from builtins import str
+from airflow.contrib.operators.cassandra_to_gcs import \
+    CassandraToGoogleCloudStorageOperator
+
+
+class CassandraToGCSTest(unittest.TestCase):
+
+    @mock.patch('airflow.contrib.operators.cassandra_to_gcs.'
+                'GoogleCloudStorageHook.upload')
+    @mock.patch('airflow.contrib.hooks.cassandra_hook.CassandraHook.get_conn')
+    def test_execute(self, get_conn, upload):
+        # Stacked mock.patch decorators pass their mocks bottom-up, so the
+        # get_conn mock is the first argument.
+        operator = CassandraToGoogleCloudStorageOperator(
+            task_id='test-cas-to-gcs',
+            cql='select * from keyspace1.table1',
+            bucket='test-bucket',
+            filename='data.json',
+            schema_filename='schema.json')
+
+        operator.execute(None)
+
+        get_conn.assert_called_once_with()
+        # One upload for the data file and one for the schema file.
+        self.assertEqual(upload.call_count, 2)
+        upload.assert_any_call('test-bucket', 'data.json', mock.ANY,
+                               'application/json')
+        upload.assert_any_call('test-bucket', 'schema.json', mock.ANY,
+                               'application/json')
+
+    def test_convert_value(self):
+        op = CassandraToGoogleCloudStorageOperator
+        self.assertEqual(op.convert_value('None', None), None)
+        self.assertEqual(op.convert_value('int', 1), 1)
+        self.assertEqual(op.convert_value('float', 1.0), 1.0)
+        self.assertEqual(op.convert_value('str', "text"), "text")
+        self.assertEqual(op.convert_value('bool', True), True)
+        self.assertEqual(op.convert_value('dict', {"a": "b"}), {"a": "b"})
+
+        from datetime import datetime
+        now = datetime.now()
+        self.assertEqual(op.convert_value('datetime', now), str(now))
+
+        from cassandra.util import Date
+        date_str = '2018-01-01'
+        date = Date(date_str)
+        self.assertEqual(op.convert_value('date', date), str(date_str))
+
+        import uuid
+        test_uuid = uuid.uuid4()
+        self.assertEqual(op.convert_value('uuid', test_uuid), str(test_uuid))
+
+        from decimal import Decimal
+        d = Decimal(1.0)
+        self.assertEqual(op.convert_value('decimal', d), float(d))
+
+        from base64 import b64encode
+        b = b'abc'
+        encoded_b = b64encode(b).decode('ascii')
+        self.assertEqual(op.convert_value('binary', b), encoded_b)
+
+        from cassandra.util import Time
+        time = Time(0)
+        self.assertEqual(op.convert_value('time', time), '00:00:00')
+
+        date_str_lst = ['2018-01-01', '2018-01-02', '2018-01-03']
+        date_lst = [Date(d) for d in date_str_lst]
+        self.assertEqual(op.convert_value('list', date_lst), date_str_lst)
+
+        date_tpl = tuple(date_lst)
+        self.assertEqual(op.convert_value('tuple', date_tpl),
+                         {'field_0': '2018-01-01',
+                          'field_1': '2018-01-02',
+                          'field_2': '2018-01-03', })
+
+
+if __name__ == '__main__':
+    unittest.main()

Reply via email to