Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/12/18 04:22:40 UTC

[GitHub] stale[bot] closed pull request #3139: [AIRFLOW-2224] Add support for CSV files in mysql_to_gcs operator

stale[bot] closed pull request #3139: [AIRFLOW-2224] Add support for CSV files in mysql_to_gcs operator
URL: https://github.com/apache/incubator-airflow/pull/3139
 
 
   


As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/airflow/contrib/operators/mysql_to_gcs.py b/airflow/contrib/operators/mysql_to_gcs.py
index 9ba84c7556..c0c48c5c68 100644
--- a/airflow/contrib/operators/mysql_to_gcs.py
+++ b/airflow/contrib/operators/mysql_to_gcs.py
@@ -25,13 +25,14 @@
 from MySQLdb.constants import FIELD_TYPE
 from tempfile import NamedTemporaryFile
 from six import string_types
+import unicodecsv as csv
 
 PY3 = sys.version_info[0] == 3
 
 
 class MySqlToGoogleCloudStorageOperator(BaseOperator):
     """
-    Copy data from MySQL to Google cloud storage in JSON format.
+    Copy data from MySQL to Google cloud storage in JSON or CSV format.
     """
     template_fields = ('sql', 'bucket', 'filename', 'schema_filename', 'schema')
     template_ext = ('.sql',)
@@ -48,6 +49,7 @@ def __init__(self,
                  google_cloud_storage_conn_id='google_cloud_storage_default',
                  schema=None,
                  delegate_to=None,
+                 export_format={'file_format': 'json'},
                  *args,
                  **kwargs):
         """
@@ -82,6 +84,50 @@ def __init__(self,
         :param delegate_to: The account to impersonate, if any. For this to
             work, the service account making the request must have domain-wide
             delegation enabled.
+        :param export_format: Details for the files to be exported into GCS.
+            Allows specifying 'json' or 'csv', as well as additional options for
+            CSV file exports (quotes, separators, etc.).
+            This is a dict with the following key-value pairs:
+              * file_format: 'json' or 'csv'. If using CSV, the additional
+                             csv_* options below may be supplied.
+              * csv_dialect: preconfigured set of CSV export parameters
+                             (e.g. 'excel', 'excel-tab', 'unix_dialect').
+                             If present, all other csv_* formatting options
+                             are ignored.
+                             See https://docs.python.org/3/library/csv.html
+              * csv_delimiter: A one-character string used to separate fields.
+                               It defaults to ','.
+              * csv_doublequote: If doublequote is False and no escapechar is set,
+                                 an Error is raised if a quotechar is found in a field.
+                                 It defaults to True.
+              * csv_escapechar: A one-character string used to escape the delimiter
+                                if quoting is set to QUOTE_NONE and the quotechar
+                                if doublequote is False.
+                                It defaults to None, which disables escaping.
+              * csv_lineterminator: The string used to terminate lines.
+                                    It defaults to '\r\n'.
+              * csv_quotechar: A one-character string used to quote fields
+                                containing special characters, such as the delimiter
+                                or quotechar, or which contain new-line characters.
+                                It defaults to '"'.
+              * csv_quoting: Controls when quotes should be generated.
+                             It can take any of the QUOTE_* constants, passed
+                             as a string. Defaults to 'csv.QUOTE_MINIMAL'.
+                             Valid values are:
+                             'csv.QUOTE_ALL': Quote all fields.
+                             'csv.QUOTE_MINIMAL': Only quote those fields which contain
+                                                    special characters such as the delimiter,
+                                                    quotechar or any of the characters
+                                                    in lineterminator.
+                             'csv.QUOTE_NONNUMERIC': Quote all non-numeric fields.
+                             'csv.QUOTE_NONE': Never quote fields. When the current
+                                                delimiter occurs in output data it is
+                                                preceded by the current escapechar
+                                                character. If escapechar is not set,
+                                                the writer will raise an Error if any
+                                                characters that require escaping are
+                                                encountered.
+              * csv_columnheader: If set to 'True', the first row of each file will
+                                  include the column names. Defaults to False.
         """
         super(MySqlToGoogleCloudStorageOperator, self).__init__(*args, **kwargs)
         self.sql = sql
@@ -93,6 +139,7 @@ def __init__(self,
         self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
         self.schema = schema
         self.delegate_to = delegate_to
+        self.export_format = export_format
 
     def execute(self, context):
         cursor = self._query_mysql()
@@ -135,19 +182,63 @@ def _write_local_data_files(self, cursor):
         tmp_file_handle = NamedTemporaryFile(delete=True)
         tmp_file_handles = {self.filename.format(file_no): tmp_file_handle}
 
+        # For CSV output, set up the writer and optionally write a header row
+        if(self.export_format['file_format'] == 'csv'):
+
+            # Deal with CSV formatting. Try to use dialect if passed
+            if('csv_dialect' in self.export_format):
+                # Use dialect name from params
+                dialect_name = self.export_format['csv_dialect']
+            else:
+                # Create internal dialect based on parameters passed
+                dialect_name = 'mysql_to_gcs'
+                csv.register_dialect(dialect_name,
+                                     delimiter=self.export_format.get('csv_delimiter') or
+                                     ',',
+                                     doublequote=self.export_format.get(
+                                         'csv_doublequote') or
+                                     'True',
+                                     escapechar=self.export_format.get(
+                                         'csv_escapechar') or
+                                     None,
+                                     lineterminator=self.export_format.get(
+                                         'csv_lineterminator') or
+                                     '\r\n',
+                                     quotechar=self.export_format.get('csv_quotechar') or
+                                     '"',
+                                     quoting=eval(self.export_format.get(
+                                         'csv_quoting') or
+                                         'csv.QUOTE_MINIMAL'))
+            # Create CSV writer using either provided or generated dialect
+            csv_writer = csv.writer(tmp_file_handle,
+                                    encoding='utf-8',
+                                    dialect=dialect_name)
+
+            # Include column header in first row
+            if('csv_columnheader' in self.export_format and
+                    eval(self.export_format['csv_columnheader'])):
+                csv_writer.writerow(schema)
+
         for row in cursor:
-            # Convert datetime objects to utc seconds, and decimals to floats
+            # Convert datetimes and longs to BigQuery safe types
             row = map(self.convert_types, row)
-            row_dict = dict(zip(schema, row))
 
-            # TODO validate that row isn't > 2MB. BQ enforces a hard row size of 2MB.
-            s = json.dumps(row_dict)
-            if PY3:
-                s = s.encode('utf-8')
-            tmp_file_handle.write(s)
+            # Save rows as CSV
+            if(self.export_format['file_format'] == 'csv'):
+                csv_writer.writerow(row)
+            # Save rows as JSON
+            else:
+                # Convert datetime objects to utc seconds, and decimals to floats
+                row_dict = dict(zip(schema, row))
 
-            # Append newline to make dumps BigQuery compatible.
-            tmp_file_handle.write(b'\n')
+                # TODO validate that row isn't > 2MB. BQ enforces a hard row size of 2MB.
+                s = json.dumps(row_dict, sort_keys=True)
+                if PY3:
+                    s = s.encode('utf-8')
+                tmp_file_handle.write(s)
+
+                # Append newline to make dumps BigQuery compatible.
+                tmp_file_handle.write(b'\n')
 
             # Stop if the file exceeds the file size limit.
             if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
@@ -155,6 +246,16 @@ def _write_local_data_files(self, cursor):
                 tmp_file_handle = NamedTemporaryFile(delete=True)
                 tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle
 
+                # For CSV files, we need to create a new writer with the new handle
+                # and write the header in the first row
+                if(self.export_format['file_format'] == 'csv'):
+                    csv_writer = csv.writer(tmp_file_handle,
+                                            encoding='utf-8',
+                                            dialect=dialect_name)
+                    if('csv_columnheader' in self.export_format and
+                            eval(self.export_format['csv_columnheader'])):
+                        csv_writer.writerow(schema)
+
         return tmp_file_handles
 
     def _write_local_schema_file(self, cursor):
@@ -191,7 +292,7 @@ def _write_local_schema_file(self, cursor):
                         'type': field_type,
                         'mode': field_mode,
                     })
-            s = json.dumps(schema, tmp_schema_file_handle)
+            s = json.dumps(schema, tmp_schema_file_handle, sort_keys=True)
             if PY3:
                 s = s.encode('utf-8')
             tmp_schema_file_handle.write(s)
@@ -204,11 +305,13 @@ def _upload_to_gcs(self, files_to_upload):
         Upload all of the file splits (and optionally the schema .json file) to
         Google cloud storage.
         """
+        # Compose mime_type using file format passed as param
+        mime_type = 'application/' + self.export_format['file_format']
         hook = GoogleCloudStorageHook(
             google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
             delegate_to=self.delegate_to)
         for object, tmp_file_handle in files_to_upload.items():
-            hook.upload(self.bucket, object, tmp_file_handle.name, 'application/json')
+            hook.upload(self.bucket, object, tmp_file_handle.name, mime_type)
 
     @classmethod
     def convert_types(cls, value):
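
For context, here is an illustrative sketch (not part of the patch) of how the new
export_format parameter could be passed to the operator to produce CSV files with a
header row. The task id, connection ids, bucket, SQL and filename below are
hypothetical placeholders:

    from airflow.contrib.operators.mysql_to_gcs import MySqlToGoogleCloudStorageOperator

    # Hypothetical task: dump a MySQL query to GCS as CSV files with a header row.
    # csv_columnheader is passed as a string because the patch eval()s its value.
    export_orders = MySqlToGoogleCloudStorageOperator(
        task_id='export_orders_csv',
        mysql_conn_id='mysql_default',
        google_cloud_storage_conn_id='google_cloud_storage_default',
        sql='SELECT * FROM orders',
        bucket='my-export-bucket',
        filename='orders/export_{}.csv',
        export_format={
            'file_format': 'csv',
            'csv_delimiter': ',',
            'csv_columnheader': 'True',
        },
    )
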
diff --git a/tests/contrib/operators/test_mysql_to_gcs_operator.py b/tests/contrib/operators/test_mysql_to_gcs_operator.py
new file mode 100644
index 0000000000..8f04466826
--- /dev/null
+++ b/tests/contrib/operators/test_mysql_to_gcs_operator.py
@@ -0,0 +1,207 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import sys
+import unittest
+
+from airflow.contrib.operators.mysql_to_gcs import MySqlToGoogleCloudStorageOperator
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+PY3 = sys.version_info[0] == 3
+
+TASK_ID = 'test-mysql-to-gcs'
+MYSQL_CONN_ID = 'mysql_conn_test'
+SQL = 'select 1'
+BUCKET = 'gs://test'
+FILENAME = 'test_{}.ndjson'
+
+if PY3:
+    ROWS = [
+        ('mock_row_content_1', 42),
+        ('mock_row_content_2', 43),
+        ('mock_row_content_3', 44)
+    ]
+    CURSOR_DESCRIPTION = (
+        ('some_str', 0, 0, 0, 0, 0, False),
+        ('some_num', 1005, 0, 0, 0, 0, False)
+    )
+else:
+    ROWS = [
+        (b'mock_row_content_1', 42),
+        (b'mock_row_content_2', 43),
+        (b'mock_row_content_3', 44)
+    ]
+    CURSOR_DESCRIPTION = (
+        (b'some_str', 0, 0, 0, 0, 0, False),
+        (b'some_num', 1005, 0, 0, 0, 0, False)
+    )
+NDJSON_LINES = [
+    b'{"some_num": 42, "some_str": "mock_row_content_1"}\n',
+    b'{"some_num": 43, "some_str": "mock_row_content_2"}\n',
+    b'{"some_num": 44, "some_str": "mock_row_content_3"}\n'
+]
+CSV_LINES = [
+    b'mock_row_content_1,42\r\n',
+    b'mock_row_content_2,43\r\n',
+    b'mock_row_content_3,44\r\n'
+]
+SCHEMA_FILENAME = 'schema_test.json'
+SCHEMA_JSON = [
+    b'[{"mode": "REQUIRED", "name": "some_str", "type": "FLOAT"}, ',
+    b'{"mode": "REQUIRED", "name": "some_num", "type": "STRING"}]'
+]
+
+
+class MySqlToGoogleCloudStorageOperatorTest(unittest.TestCase):
+    def test_init(self):
+        """Test MySqlToGoogleCloudStorageOperator instance is properly initialized."""
+        op = MySqlToGoogleCloudStorageOperator(
+            task_id=TASK_ID, sql=SQL, bucket=BUCKET, filename=FILENAME)
+        self.assertEqual(op.task_id, TASK_ID)
+        self.assertEqual(op.sql, SQL)
+        self.assertEqual(op.bucket, BUCKET)
+        self.assertEqual(op.filename, FILENAME)
+
+    @mock.patch('airflow.contrib.operators.mysql_to_gcs.MySqlHook')
+    @mock.patch('airflow.contrib.operators.mysql_to_gcs.GoogleCloudStorageHook')
+    def test_exec_success_json(self, gcs_hook_mock_class, mysql_hook_mock_class):
+        """Test the execute function in case where the run is successful."""
+        op = MySqlToGoogleCloudStorageOperator(
+            task_id=TASK_ID,
+            mysql_conn_id=MYSQL_CONN_ID,
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME)
+
+        mysql_hook_mock = mysql_hook_mock_class.return_value
+        mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
+        mysql_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
+
+        gcs_hook_mock = gcs_hook_mock_class.return_value
+
+        def _assert_upload(bucket, obj, tmp_filename, content_type):
+            self.assertEqual(BUCKET, bucket)
+            self.assertEqual(FILENAME.format(0), obj)
+            self.assertEqual('application/json', content_type)
+            with open(tmp_filename, 'rb') as f:
+                self.assertEqual(b''.join(NDJSON_LINES), f.read())
+
+        gcs_hook_mock.upload.side_effect = _assert_upload
+
+        op.execute(None)
+
+        mysql_hook_mock_class.assert_called_once_with(mysql_conn_id=MYSQL_CONN_ID)
+        mysql_hook_mock.get_conn().cursor().execute.assert_called_once_with(SQL)
+
+    @mock.patch('airflow.contrib.operators.mysql_to_gcs.MySqlHook')
+    @mock.patch('airflow.contrib.operators.mysql_to_gcs.GoogleCloudStorageHook')
+    def test_exec_success_csv(self, gcs_hook_mock_class, mysql_hook_mock_class):
+        """Test the execute function in case where the run is successful."""
+        op = MySqlToGoogleCloudStorageOperator(
+            task_id=TASK_ID,
+            mysql_conn_id=MYSQL_CONN_ID,
+            sql=SQL,
+            export_format={'file_format': 'csv', 'csv_dialect': 'excel'},
+            bucket=BUCKET,
+            filename=FILENAME)
+
+        mysql_hook_mock = mysql_hook_mock_class.return_value
+        mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
+        mysql_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
+
+        gcs_hook_mock = gcs_hook_mock_class.return_value
+
+        def _assert_upload(bucket, obj, tmp_filename, content_type):
+            self.assertEqual(BUCKET, bucket)
+            self.assertEqual(FILENAME.format(0), obj)
+            self.assertEqual('application/csv', content_type)
+            with open(tmp_filename, 'rb') as f:
+                self.assertEqual(b''.join(CSV_LINES), f.read())
+
+        gcs_hook_mock.upload.side_effect = _assert_upload
+
+        op.execute(None)
+
+        mysql_hook_mock_class.assert_called_once_with(mysql_conn_id=MYSQL_CONN_ID)
+        mysql_hook_mock.get_conn().cursor().execute.assert_called_once_with(SQL)
+
+    @mock.patch('airflow.contrib.operators.mysql_to_gcs.MySqlHook')
+    @mock.patch('airflow.contrib.operators.mysql_to_gcs.GoogleCloudStorageHook')
+    def test_file_splitting(self, gcs_hook_mock_class, mysql_hook_mock_class):
+        """Test that ndjson is split by approx_max_file_size_bytes param."""
+        mysql_hook_mock = mysql_hook_mock_class.return_value
+        mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
+        mysql_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
+
+        gcs_hook_mock = gcs_hook_mock_class.return_value
+        expected_upload = {
+            FILENAME.format(0): b''.join(NDJSON_LINES[:2]),
+            FILENAME.format(1): NDJSON_LINES[2],
+        }
+
+        def _assert_upload(bucket, obj, tmp_filename, content_type):
+            self.assertEqual(BUCKET, bucket)
+            self.assertEqual('application/json', content_type)
+            with open(tmp_filename, 'rb') as f:
+                self.assertEqual(expected_upload[obj], f.read())
+
+        gcs_hook_mock.upload.side_effect = _assert_upload
+
+        op = MySqlToGoogleCloudStorageOperator(
+            task_id=TASK_ID,
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)]))
+        op.execute(None)
+
+    @mock.patch('airflow.contrib.operators.mysql_to_gcs.MySqlHook')
+    @mock.patch('airflow.contrib.operators.mysql_to_gcs.GoogleCloudStorageHook')
+    def test_schema_file(self, gcs_hook_mock_class, mysql_hook_mock_class):
+        """Test writing schema files."""
+        mysql_hook_mock = mysql_hook_mock_class.return_value
+        mysql_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
+        mysql_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
+
+        gcs_hook_mock = gcs_hook_mock_class.return_value
+
+        def _assert_upload(bucket, obj, tmp_filename, content_type):
+            if obj == SCHEMA_FILENAME:
+                with open(tmp_filename, 'rb') as f:
+                    self.assertEqual(b''.join(SCHEMA_JSON), f.read())
+
+        gcs_hook_mock.upload.side_effect = _assert_upload
+
+        op = MySqlToGoogleCloudStorageOperator(
+            task_id=TASK_ID,
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            schema_filename=SCHEMA_FILENAME)
+        op.execute(None)
+
+        # once for the file and once for the schema
+        self.assertEqual(2, gcs_hook_mock.upload.call_count)
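
For reference, the following minimal sketch (assuming only the unicodecsv package that
the patch imports) reproduces the CSV_LINES fixture expected by test_exec_success_csv
above; the 'excel' dialect used in that test separates fields with ',' and terminates
lines with '\r\n':

    import io

    import unicodecsv as csv

    rows = [
        ('mock_row_content_1', 42),
        ('mock_row_content_2', 43),
        ('mock_row_content_3', 44),
    ]

    buf = io.BytesIO()  # unicodecsv writes encoded bytes, so use a binary buffer
    writer = csv.writer(buf, encoding='utf-8', dialect='excel')
    for row in rows:
        writer.writerow(row)

    # Prints b'mock_row_content_1,42\r\n...', matching the CSV_LINES fixture above
    print(buf.getvalue())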


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services