stale[bot] closed pull request #3139: [AIRFLOW-2224] Add support for CSV files in mysql_to_gcs operator
URL: https://github.com/apache/incubator-airflow/pull/3139
 
 
   

This is a pull request from a forked repository.
As GitHub hides the original diff once such a pull request is closed, it is
displayed below for the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/airflow/contrib/operators/mysql_to_gcs.py 
b/airflow/contrib/operators/mysql_to_gcs.py
index 9ba84c7556..c0c48c5c68 100644
--- a/airflow/contrib/operators/mysql_to_gcs.py
+++ b/airflow/contrib/operators/mysql_to_gcs.py
@@ -25,13 +25,14 @@
 from MySQLdb.constants import FIELD_TYPE
 from tempfile import NamedTemporaryFile
 from six import string_types
+import unicodecsv as csv
 
 PY3 = sys.version_info[0] == 3
 
 
 class MySqlToGoogleCloudStorageOperator(BaseOperator):
     """
-    Copy data from MySQL to Google cloud storage in JSON format.
+    Copy data from MySQL to Google cloud storage in JSON or CSV format.
     """
     template_fields = ('sql', 'bucket', 'filename', 'schema_filename', 
'schema')
     template_ext = ('.sql',)
@@ -48,6 +49,7 @@ def __init__(self,
                  google_cloud_storage_conn_id='google_cloud_storage_default',
                  schema=None,
                  delegate_to=None,
+                 export_format={'file_format': 'json'},
                  *args,
                  **kwargs):
         """
@@ -82,6 +84,50 @@ def __init__(self,
         :param delegate_to: The account to impersonate, if any. For this to
             work, the service account making the request must have domain-wide
             delegation enabled.
+        :param export_format: Details for files to be exported into GCS.
+            Allows to specify 'json' or 'csv', and also addiitional details for
+            CSV file exports (quotes, separators, etc.)
+            This is a dict with the following key-value pairs:
+              * file_format: 'json' or 'csv'. If using CSV, more details can
+                              be added
+              * csv_dialect: preconfigured set of CSV export parameters
+                             (i.e.: 'excel', 'excel-tab', 'unix_dialect').
+                             If present, will ignore all other 'csv_' options.
+                             See https://docs.python.org/3/library/csv.html
+              * csv_delimiter: A one-character string used to separate fields.
+                               It defaults to ','.
+              * csv_doublequote: If doublequote is False and no escapechar is 
set,
+                                 Error is raised if a quotechar is found in a 
field.
+                                 It defaults to True.
+              * csv_escapechar: A one-character string used to escape the 
delimiter
+                                if quoting is set to QUOTE_NONE and the 
quotechar
+                                if doublequote is False.
+                                It defaults to None, which disables escaping.
+              * csv_lineterminator: The string used to terminate lines.
+                                    It defaults to '\r\n'.
+              * csv_quotechar: A one-character string used to quote fields
+                                containing special characters, such as the 
delimiter
+                                or quotechar, or which contain new-line 
characters.
+                                It defaults to '"'.
+              * csv_quoting: Controls when quotes should be generated.
+                             It can take on any of the QUOTE_* constants
+                             Defaults to csv.QUOTE_MINIMAL.
+                             Valid values are:
+                             'csv.QUOTE_ALL': Quote all fields
+                             'csv.QUOTE_MINIMAL': only quote those fields 
which contain
+                                                    special characters such as 
delimiter,
+                                                    quotechar or any of the 
characters
+                                                    in lineterminator.
+                             'csv.QUOTE_NONNUMERIC': Quote all non-numeric 
fields.
+                             'csv.QUOTE_NONE': never quote fields. When the 
current
+                                                delimiter occurs in output 
data it is
+                                                preceded by the current 
escapechar
+                                                character. If escapechar is 
not set,
+                                                the writer will raise Error if 
any
+                                                characters that require 
escaping are
+                                                encountered.
+              * csv_columnheader: If True, first row in the file will include 
column
+                                  names. Defaults to False.
         """
         super(MySqlToGoogleCloudStorageOperator, self).__init__(*args, 
**kwargs)
         self.sql = sql
@@ -93,6 +139,7 @@ def __init__(self,
         self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
         self.schema = schema
         self.delegate_to = delegate_to
+        self.export_format = export_format
 
     def execute(self, context):
         cursor = self._query_mysql()
@@ -135,19 +182,63 @@ def _write_local_data_files(self, cursor):
         tmp_file_handle = NamedTemporaryFile(delete=True)
         tmp_file_handles = {self.filename.format(file_no): tmp_file_handle}
 
+        # Save file header for csv if required
+        if(self.export_format['file_format'] == 'csv'):
+
+            # Deal with CSV formatting. Try to use dialect if passed
+            if('csv_dialect' in self.export_format):
+                # Use dialect name from params
+                dialect_name = self.export_format['csv_dialect']
+            else:
+                # Create internal dialect based on parameters passed
+                dialect_name = 'mysql_to_gcs'
+                csv.register_dialect(dialect_name,
+                                     
delimiter=self.export_format.get('csv_delimiter') or
+                                     ',',
+                                     doublequote=self.export_format.get(
+                                         'csv_doublequote') or
+                                     'True',
+                                     escapechar=self.export_format.get(
+                                         'csv_escapechar') or
+                                     None,
+                                     lineterminator=self.export_format.get(
+                                         'csv_lineterminator') or
+                                     '\r\n',
+                                     
quotechar=self.export_format.get('csv_quotechar') or
+                                     '"',
+                                     quoting=eval(self.export_format.get(
+                                         'csv_quoting') or
+                                         'csv.QUOTE_MINIMAL'))
+            # Create CSV writer using either provided or generated dialect
+            csv_writer = csv.writer(tmp_file_handle,
+                                    encoding='utf-8',
+                                    dialect=dialect_name)
+
+            # Include column header in first row
+            if('csv_columnheader' in self.export_format and
+                    eval(self.export_format['csv_columnheader'])):
+                csv_writer.writerow(schema)
+
         for row in cursor:
-            # Convert datetime objects to utc seconds, and decimals to floats
+            # Convert datetimes and longs to BigQuery safe types
             row = map(self.convert_types, row)
-            row_dict = dict(zip(schema, row))
 
-            # TODO validate that row isn't > 2MB. BQ enforces a hard row size 
of 2MB.
-            s = json.dumps(row_dict)
-            if PY3:
-                s = s.encode('utf-8')
-            tmp_file_handle.write(s)
+            # Save rows as CSV
+            if(self.export_format['file_format'] == 'csv'):
+                csv_writer.writerow(row)
+            # Save rows as JSON
+            else:
+                # Convert datetime objects to utc seconds, and decimals to 
floats
+                row_dict = dict(zip(schema, row))
 
-            # Append newline to make dumps BigQuery compatible.
-            tmp_file_handle.write(b'\n')
+                # TODO validate that row isn't > 2MB. BQ enforces a hard row 
size of 2MB.
+                s = json.dumps(row_dict, sort_keys=True)
+                if PY3:
+                    s = s.encode('utf-8')
+                tmp_file_handle.write(s)
+
+                # Append newline to make dumps BigQuery compatible.
+                tmp_file_handle.write(b'\n')
 
             # Stop if the file exceeds the file size limit.
             if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
@@ -155,6 +246,16 @@ def _write_local_data_files(self, cursor):
                 tmp_file_handle = NamedTemporaryFile(delete=True)
                 tmp_file_handles[self.filename.format(file_no)] = 
tmp_file_handle
 
+                # For CSV files, weed to create a new writer with the new 
handle
+                # and write header in first row
+                if(self.export_format['file_format'] == 'csv'):
+                    csv_writer = csv.writer(tmp_file_handle,
+                                            encoding='utf-8',
+                                            dialect=dialect_name)
+                    if('csv_columnheader' in self.export_format and
+                            eval(self.export_format['csv_columnheader'])):
+                        csv_writer.writerow(schema)
+
         return tmp_file_handles
 
     def _write_local_schema_file(self, cursor):
@@ -191,7 +292,7 @@ def _write_local_schema_file(self, cursor):
                         'type': field_type,
                         'mode': field_mode,
                     })
-            s = json.dumps(schema, tmp_schema_file_handle)
+            s = json.dumps(schema, tmp_schema_file_handle, sort_keys=True)
             if PY3:
                 s = s.encode('utf-8')
             tmp_schema_file_handle.write(s)
@@ -204,11 +305,13 @@ def _upload_to_gcs(self, files_to_upload):
         Upload all of the file splits (and optionally the schema .json file) to
         Google cloud storage.
         """
+        # Compose mime_type using file format passed as param
+        mime_type = 'application/' + self.export_format['file_format']
         hook = GoogleCloudStorageHook(
             google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
             delegate_to=self.delegate_to)
         for object, tmp_file_handle in files_to_upload.items():
-            hook.upload(self.bucket, object, tmp_file_handle.name, 
'application/json')
+            hook.upload(self.bucket, object, tmp_file_handle.name, mime_type)
 
     @classmethod
     def convert_types(cls, value):
diff --git a/tests/contrib/operators/test_mysql_to_gcs_operator.py 
b/tests/contrib/operators/test_mysql_to_gcs_operator.py
new file mode 100644
index 0000000000..8f04466826
--- /dev/null
+++ b/tests/contrib/operators/test_mysql_to_gcs_operator.py
@@ -0,0 +1,207 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import sys
+import unittest
+
+from airflow.contrib.operators.mysql_to_gcs import 
MySqlToGoogleCloudStorageOperator
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+PY3 = sys.version_info[0] == 3
+
+TASK_ID = 'test-mysql-to-gcs'
+MYSQL_CONN_ID = 'mysql_conn_test'
+SQL = 'select 1'
+BUCKET = 'gs://test'
+FILENAME = 'test_{}.ndjson'
+
+if PY3:
+    ROWS = [
+        ('mock_row_content_1', 42),
+        ('mock_row_content_2', 43),
+        ('mock_row_content_3', 44)
+    ]
+    CURSOR_DESCRIPTION = (
+        ('some_str', 0, 0, 0, 0, 0, False),
+        ('some_num', 1005, 0, 0, 0, 0, False)
+    )
+else:
+    ROWS = [
+        (b'mock_row_content_1', 42),
+        (b'mock_row_content_2', 43),
+        (b'mock_row_content_3', 44)
+    ]
+    CURSOR_DESCRIPTION = (
+        (b'some_str', 0, 0, 0, 0, 0, False),
+        (b'some_num', 1005, 0, 0, 0, 0, False)
+    )
+NDJSON_LINES = [
+    b'{"some_num": 42, "some_str": "mock_row_content_1"}\n',
+    b'{"some_num": 43, "some_str": "mock_row_content_2"}\n',
+    b'{"some_num": 44, "some_str": "mock_row_content_3"}\n'
+]
+CSV_LINES = [
+    b'mock_row_content_1,42\r\n',
+    b'mock_row_content_2,43\r\n',
+    b'mock_row_content_3,44\r\n'
+]
+SCHEMA_FILENAME = 'schema_test.json'
+SCHEMA_JSON = [
+    b'[{"mode": "REQUIRED", "name": "some_str", "type": "FLOAT"}, ',
+    b'{"mode": "REQUIRED", "name": "some_num", "type": "STRING"}]'
+]
+
+
+class MySqlToGoogleCloudStorageOperatorTest(unittest.TestCase):
+    def test_init(self):
+        """Test MySqlToGoogleCloudStorageOperator instance is properly 
initialized."""
+        op = MySqlToGoogleCloudStorageOperator(
+            task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=FILENAME)
+        self.assertEqual(op.task_id, TASK_ID)
+        self.assertEqual(op.sql, SQL)
+        self.assertEqual(op.bucket, BUCKET)
+        self.assertEqual(op.filename, FILENAME)
+
+    @mock.patch('airflow.contrib.operators.mysql_to_gcs.MySqlHook')
+    
@mock.patch('airflow.contrib.operators.mysql_to_gcs.GoogleCloudStorageHook')
+    def test_exec_success_json(self, gcs_hook_mock_class, 
mysql_hook_mock_class):
+        """Test the execute function in case where the run is successful."""
+        op = MySqlToGoogleCloudStorageOperator(
+            task_id=TASK_ID,
+            mysql_conn_id=MYSQL_CONN_ID,
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME)
+
+        mysql_hook_mock = mysql_hook_mock_class.return_value
+        mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
+        mysql_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
+
+        gcs_hook_mock = gcs_hook_mock_class.return_value
+
+        def _assert_upload(bucket, obj, tmp_filename, content_type):
+            self.assertEqual(BUCKET, bucket)
+            self.assertEqual(FILENAME.format(0), obj)
+            self.assertEqual('application/json', content_type)
+            with open(tmp_filename, 'rb') as f:
+                self.assertEqual(b''.join(NDJSON_LINES), f.read())
+
+        gcs_hook_mock.upload.side_effect = _assert_upload
+
+        op.execute(None)
+
+        
mysql_hook_mock_class.assert_called_once_with(mysql_conn_id=MYSQL_CONN_ID)
+        
mysql_hook_mock.get_conn().cursor().execute.assert_called_once_with(SQL)
+
+    @mock.patch('airflow.contrib.operators.mysql_to_gcs.MySqlHook')
+    
@mock.patch('airflow.contrib.operators.mysql_to_gcs.GoogleCloudStorageHook')
+    def test_exec_success_csv(self, gcs_hook_mock_class, 
mysql_hook_mock_class):
+        """Test the execute function in case where the run is successful."""
+        op = MySqlToGoogleCloudStorageOperator(
+            task_id=TASK_ID,
+            mysql_conn_id=MYSQL_CONN_ID,
+            sql=SQL,
+            export_format={'file_format': 'csv', 'csv_dialect': 'excel'},
+            bucket=BUCKET,
+            filename=FILENAME)
+
+        mysql_hook_mock = mysql_hook_mock_class.return_value
+        mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
+        mysql_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
+
+        gcs_hook_mock = gcs_hook_mock_class.return_value
+
+        def _assert_upload(bucket, obj, tmp_filename, content_type):
+            self.assertEqual(BUCKET, bucket)
+            self.assertEqual(FILENAME.format(0), obj)
+            self.assertEqual('application/csv', content_type)
+            with open(tmp_filename, 'rb') as f:
+                self.assertEqual(b''.join(CSV_LINES), f.read())
+
+        gcs_hook_mock.upload.side_effect = _assert_upload
+
+        op.execute(None)
+
+        
mysql_hook_mock_class.assert_called_once_with(mysql_conn_id=MYSQL_CONN_ID)
+        
mysql_hook_mock.get_conn().cursor().execute.assert_called_once_with(SQL)
+
+    @mock.patch('airflow.contrib.operators.mysql_to_gcs.MySqlHook')
+    
@mock.patch('airflow.contrib.operators.mysql_to_gcs.GoogleCloudStorageHook')
+    def test_file_splitting(self, gcs_hook_mock_class, mysql_hook_mock_class):
+        """Test that ndjson is split by approx_max_file_size_bytes param."""
+        mysql_hook_mock = mysql_hook_mock_class.return_value
+        mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
+        mysql_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
+
+        gcs_hook_mock = gcs_hook_mock_class.return_value
+        expected_upload = {
+            FILENAME.format(0): b''.join(NDJSON_LINES[:2]),
+            FILENAME.format(1): NDJSON_LINES[2],
+        }
+
+        def _assert_upload(bucket, obj, tmp_filename, content_type):
+            self.assertEqual(BUCKET, bucket)
+            self.assertEqual('application/json', content_type)
+            with open(tmp_filename, 'rb') as f:
+                self.assertEqual(expected_upload[obj], f.read())
+
+        gcs_hook_mock.upload.side_effect = _assert_upload
+
+        op = MySqlToGoogleCloudStorageOperator(
+            task_id=TASK_ID,
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            
approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)]))
+        op.execute(None)
+
+    @mock.patch('airflow.contrib.operators.mysql_to_gcs.MySqlHook')
+    
@mock.patch('airflow.contrib.operators.mysql_to_gcs.GoogleCloudStorageHook')
+    def test_schema_file(self, gcs_hook_mock_class, mysql_hook_mock_class):
+        """Test writing schema files."""
+        mysql_hook_mock = mysql_hook_mock_class.return_value
+        mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
+        mysql_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
+
+        gcs_hook_mock = gcs_hook_mock_class.return_value
+
+        def _assert_upload(bucket, obj, tmp_filename, content_type):
+            if obj == SCHEMA_FILENAME:
+                with open(tmp_filename, 'rb') as f:
+                    self.assertEqual(b''.join(SCHEMA_JSON), f.read())
+
+        gcs_hook_mock.upload.side_effect = _assert_upload
+
+        op = MySqlToGoogleCloudStorageOperator(
+            task_id=TASK_ID,
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            schema_filename=SCHEMA_FILENAME)
+        op.execute(None)
+
+        # once for the file and once for the schema
+        self.assertEqual(2, gcs_hook_mock.upload.call_count)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to