You are viewing a plain text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/04/04 15:11:31 UTC

[airflow] branch main updated: Modify transfer operators to handle more data (#22495)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 99b0211d50 Modify transfer operators to handle more data (#22495)
99b0211d50 is described below

commit 99b0211d5087cf486415b5fc8399d3f15d84ed69
Author: Matthew Wallace <gi...@matthewwallace.me>
AuthorDate: Mon Apr 4 09:11:23 2022 -0600

    Modify transfer operators to handle more data (#22495)
    
    * Modify transfer operators to handle more data
    
    This addresses an issue where large data imports can fill all available
    disk space and cause the task to fail.
    
    Previously, all data was written out to disk before any of it was
    uploaded to GCS. Now each data chunk is written to a local temporary
    file, uploaded to GCS, and then closed, which deletes the temporary
    file and frees its disk space.
---
 .../google/cloud/transfers/cassandra_to_gcs.py     | 71 +++++++++++-------
 .../providers/google/cloud/transfers/sql_to_gcs.py | 87 +++++++++++-----------
 .../google/cloud/transfers/test_sql_to_gcs.py      | 22 ++++--
 3 files changed, 104 insertions(+), 76 deletions(-)

diff --git a/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py b/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py
index 2c4bff7217..9fadba0bf0 100644
--- a/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py
@@ -169,21 +169,30 @@ class CassandraToGCSOperator(BaseOperator):
 
         cursor = hook.get_conn().execute(self.cql, **query_extra)
 
-        files_to_upload = self._write_local_data_files(cursor)
-
         # If a schema is set, create a BQ schema JSON file.
         if self.schema_filename:
-            files_to_upload.update(self._write_local_schema_file(cursor))
+            self.log.info('Writing local schema file')
+            schema_file = self._write_local_schema_file(cursor)
+
+            # Flush file before uploading
+            schema_file['file_handle'].flush()
+
+            self.log.info('Uploading schema file to GCS.')
+            self._upload_to_gcs(schema_file)
+            schema_file['file_handle'].close()
 
-        # Flush all files before uploading
-        for file_handle in files_to_upload.values():
-            file_handle.flush()
+        counter = 0
+        self.log.info('Writing local data files')
+        for file_to_upload in self._write_local_data_files(cursor):
+            # Flush file before uploading
+            file_to_upload['file_handle'].flush()
 
-        self._upload_to_gcs(files_to_upload)
+            self.log.info('Uploading chunk file #%d to GCS.', counter)
+            self._upload_to_gcs(file_to_upload)
 
-        # Close all temp file handles.
-        for file_handle in files_to_upload.values():
-            file_handle.close()
+            self.log.info('Removing local file')
+            file_to_upload['file_handle'].close()
+            counter += 1
 
         # Close all sessions and connection associated with this Cassandra cluster
         hook.shutdown_cluster()
@@ -197,8 +206,12 @@ class CassandraToGCSOperator(BaseOperator):
             contain the data for the GCS objects.
         """
         file_no = 0
+
         tmp_file_handle = NamedTemporaryFile(delete=True)
-        tmp_file_handles = {self.filename.format(file_no): tmp_file_handle}
+        file_to_upload = {
+            'file_name': self.filename.format(file_no),
+            'file_handle': tmp_file_handle,
+        }
         for row in cursor:
             row_dict = self.generate_data_dict(row._fields, row)
             content = json.dumps(row_dict).encode('utf-8')
@@ -209,10 +222,14 @@ class CassandraToGCSOperator(BaseOperator):
 
             if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
                 file_no += 1
-                tmp_file_handle = NamedTemporaryFile(delete=True)
-                tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle
 
-        return tmp_file_handles
+                yield file_to_upload
+                tmp_file_handle = NamedTemporaryFile(delete=True)
+                file_to_upload = {
+                    'file_name': self.filename.format(file_no),
+                    'file_handle': tmp_file_handle,
+                }
+        yield file_to_upload
 
     def _write_local_schema_file(self, cursor):
         """
@@ -231,22 +248,26 @@ class CassandraToGCSOperator(BaseOperator):
         json_serialized_schema = json.dumps(schema).encode('utf-8')
 
         tmp_schema_file_handle.write(json_serialized_schema)
-        return {self.schema_filename: tmp_schema_file_handle}
-
-    def _upload_to_gcs(self, files_to_upload: Dict[str, Any]):
+        schema_file_to_upload = {
+            'file_name': self.schema_filename,
+            'file_handle': tmp_schema_file_handle,
+        }
+        return schema_file_to_upload
+
+    def _upload_to_gcs(self, file_to_upload):
+        """Upload a file (data split or schema .json file) to Google Cloud Storage."""
         hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
             impersonation_chain=self.impersonation_chain,
         )
-        for obj, tmp_file_handle in files_to_upload.items():
-            hook.upload(
-                bucket_name=self.bucket,
-                object_name=obj,
-                filename=tmp_file_handle.name,
-                mime_type='application/json',
-                gzip=self.gzip,
-            )
+        hook.upload(
+            bucket_name=self.bucket,
+            object_name=file_to_upload.get('file_name'),
+            filename=file_to_upload.get('file_handle').name,
+            mime_type='application/json',
+            gzip=self.gzip,
+        )
 
     @classmethod
     def generate_data_dict(cls, names: Iterable[str], values: Any) -> Dict[str, Any]:
diff --git a/airflow/providers/google/cloud/transfers/sql_to_gcs.py b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
index ed2e28cb30..077e8d2085 100644
--- a/airflow/providers/google/cloud/transfers/sql_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -139,24 +139,30 @@ class BaseSQLToGCSOperator(BaseOperator):
         self.log.info("Executing query")
         cursor = self.query()
 
-        self.log.info("Writing local data files")
-        files_to_upload = self._write_local_data_files(cursor)
         # If a schema is set, create a BQ schema JSON file.
         if self.schema_filename:
-            self.log.info("Writing local schema file")
-            files_to_upload.append(self._write_local_schema_file(cursor))
+            self.log.info('Writing local schema file')
+            schema_file = self._write_local_schema_file(cursor)
 
-        # Flush all files before uploading
-        for tmp_file in files_to_upload:
-            tmp_file['file_handle'].flush()
+            # Flush file before uploading
+            schema_file['file_handle'].flush()
 
-        self.log.info("Uploading %d files to GCS.", len(files_to_upload))
-        self._upload_to_gcs(files_to_upload)
+            self.log.info('Uploading schema file to GCS.')
+            self._upload_to_gcs(schema_file)
+            schema_file['file_handle'].close()
 
-        self.log.info("Removing local files")
-        # Close all temp file handles.
-        for tmp_file in files_to_upload:
-            tmp_file['file_handle'].close()
+        counter = 0
+        self.log.info('Writing local data files')
+        for file_to_upload in self._write_local_data_files(cursor):
+            # Flush file before uploading
+            file_to_upload['file_handle'].flush()
+
+            self.log.info('Uploading chunk file #%d to GCS.', counter)
+            self._upload_to_gcs(file_to_upload)
+
+            self.log.info('Removing local file')
+            file_to_upload['file_handle'].close()
+            counter += 1
 
     def convert_types(self, schema, col_type_dict, row) -> list:
         """Convert values from DBAPI to output-friendly formats."""
@@ -181,14 +187,11 @@ class BaseSQLToGCSOperator(BaseOperator):
             file_mime_type = 'application/octet-stream'
         else:
             file_mime_type = 'application/json'
-        files_to_upload = [
-            {
-                'file_name': self.filename.format(file_no),
-                'file_handle': tmp_file_handle,
-                'file_mime_type': file_mime_type,
-            }
-        ]
-        self.log.info("Current file count: %d", len(files_to_upload))
+        file_to_upload = {
+            'file_name': self.filename.format(file_no),
+            'file_handle': tmp_file_handle,
+            'file_mime_type': file_mime_type,
+        }
 
         if self.export_format == 'csv':
             csv_writer = self._configure_csv_file(tmp_file_handle, schema)
@@ -225,20 +228,22 @@ class BaseSQLToGCSOperator(BaseOperator):
             if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
                 file_no += 1
 
+                if self.export_format == 'parquet':
+                    parquet_writer.close()
+                yield file_to_upload
                 tmp_file_handle = NamedTemporaryFile(delete=True)
-                files_to_upload.append(
-                    {
-                        'file_name': self.filename.format(file_no),
-                        'file_handle': tmp_file_handle,
-                        'file_mime_type': file_mime_type,
-                    }
-                )
-                self.log.info("Current file count: %d", len(files_to_upload))
+                file_to_upload = {
+                    'file_name': self.filename.format(file_no),
+                    'file_handle': tmp_file_handle,
+                    'file_mime_type': file_mime_type,
+                }
                 if self.export_format == 'csv':
                     csv_writer = self._configure_csv_file(tmp_file_handle, schema)
                 if self.export_format == 'parquet':
                     parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
-        return files_to_upload
+        if self.export_format == 'parquet':
+            parquet_writer.close()
+        yield file_to_upload
 
     def _configure_csv_file(self, file_handle, schema):
         """Configure a csv writer with the file_handle and write schema
@@ -338,21 +343,17 @@ class BaseSQLToGCSOperator(BaseOperator):
         }
         return schema_file_to_upload
 
-    def _upload_to_gcs(self, files_to_upload):
-        """
-        Upload all of the file splits (and optionally the schema .json file) to
-        Google Cloud Storage.
-        """
+    def _upload_to_gcs(self, file_to_upload):
+        """Upload a file (data split or schema .json file) to Google Cloud Storage."""
         hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
             impersonation_chain=self.impersonation_chain,
         )
-        for tmp_file in files_to_upload:
-            hook.upload(
-                self.bucket,
-                tmp_file.get('file_name'),
-                tmp_file.get('file_handle').name,
-                mime_type=tmp_file.get('file_mime_type'),
-                gzip=self.gzip if tmp_file.get('file_name') != self.schema_filename else False,
-            )
+        hook.upload(
+            self.bucket,
+            file_to_upload.get('file_name'),
+            file_to_upload.get('file_handle').name,
+            mime_type=file_to_upload.get('file_mime_type'),
+            gzip=self.gzip if file_to_upload.get('file_name') != self.schema_filename else False,
+        )
diff --git a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
index 4f5e7a8f34..668e8e48e4 100644
--- a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -29,7 +29,7 @@ from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOper
 
 SQL = "SELECT * FROM test_table"
 BUCKET = "TEST-BUCKET-1"
-FILENAME = "test_results.csv"
+FILENAME = "test_results_{}.csv"
 TASK_ID = "TEST_TASK_ID"
 SCHEMA = [
     {"name": "column_a", "type": "3"},
@@ -137,9 +137,13 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
             ]
         )
         mock_flush.assert_has_calls([mock.call(), mock.call(), mock.call(), mock.call(), mock.call()])
-        csv_call = mock.call(BUCKET, FILENAME, TMP_FILE_NAME, mime_type='text/csv', gzip=True)
+        csv_calls = []
+        for i in range(0, 3):
+            csv_calls.append(
+                mock.call(BUCKET, FILENAME.format(i), TMP_FILE_NAME, mime_type='text/csv', gzip=True)
+            )
         json_call = mock.call(BUCKET, SCHEMA_FILE, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False)
-        upload_calls = [csv_call, csv_call, csv_call, json_call]
+        upload_calls = [json_call, csv_calls[0], csv_calls[1], csv_calls[2]]
         mock_upload.assert_has_calls(upload_calls)
         mock_close.assert_has_calls([mock.call(), mock.call(), mock.call(), mock.call(), mock.call()])
 
@@ -169,7 +173,9 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
             ]
         )
         mock_flush.assert_called_once()
-        mock_upload.assert_called_once_with(BUCKET, FILENAME, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False)
+        mock_upload.assert_called_once_with(
+            BUCKET, FILENAME.format(0), TMP_FILE_NAME, mime_type=APP_JSON, gzip=False
+        )
         mock_close.assert_called_once()
 
         mock_query.reset_mock()
@@ -189,7 +195,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
         mock_query.assert_called_once()
         mock_flush.assert_called_once()
         mock_upload.assert_called_once_with(
-            BUCKET, FILENAME, TMP_FILE_NAME, mime_type='application/octet-stream', gzip=False
+            BUCKET, FILENAME.format(0), TMP_FILE_NAME, mime_type='application/octet-stream', gzip=False
         )
         mock_close.assert_called_once()
 
@@ -233,7 +239,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
         cursor.description = CURSOR_DESCRIPTION
 
         files = op._write_local_data_files(cursor)
-        file = files[0]['file_handle']
+        file = next(files)['file_handle']
         file.flush()
         df = pd.read_csv(file.name)
         assert df.equals(OUTPUT_DF)
@@ -255,7 +261,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
         cursor.description = CURSOR_DESCRIPTION
 
         files = op._write_local_data_files(cursor)
-        file = files[0]['file_handle']
+        file = next(files)['file_handle']
         file.flush()
         df = pd.read_json(file.name, orient='records', lines=True)
         assert df.equals(OUTPUT_DF)
@@ -277,7 +283,7 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
         cursor.description = CURSOR_DESCRIPTION
 
         files = op._write_local_data_files(cursor)
-        file = files[0]['file_handle']
+        file = next(files)['file_handle']
         file.flush()
         df = pd.read_parquet(file.name)
         assert df.equals(OUTPUT_DF)