You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by el...@apache.org on 2023/01/10 05:55:23 UTC

[airflow] branch main updated: Support partition_columns in BaseSQLToGCSOperator (#28677)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 35a8ffc55a Support partition_columns in BaseSQLToGCSOperator (#28677)
35a8ffc55a is described below

commit 35a8ffc55af220b16ea345d770f80f698dcae3fb
Author: Victor Chiapaikeo <vc...@gmail.com>
AuthorDate: Tue Jan 10 00:55:15 2023 -0500

    Support partition_columns in BaseSQLToGCSOperator (#28677)
    
    * Support partition_columns in BaseSQLToGCSOperator
    
    Co-authored-by: eladkal <45...@users.noreply.github.com>
---
 .../providers/google/cloud/transfers/sql_to_gcs.py | 116 ++++++++++++++++-----
 .../google/cloud/transfers/test_sql_to_gcs.py      |  97 +++++++++++++++++
 2 files changed, 188 insertions(+), 25 deletions(-)

diff --git a/airflow/providers/google/cloud/transfers/sql_to_gcs.py b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
index e4a3f3e942..12043b05d6 100644
--- a/airflow/providers/google/cloud/transfers/sql_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 
 import abc
 import json
+import os
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Sequence
 
@@ -77,6 +78,10 @@ class BaseSQLToGCSOperator(BaseOperator):
         account from the list granting this role to the originating account (templated).
     :param upload_metadata: whether to upload the row count metadata as blob metadata
     :param exclude_columns: set of columns to exclude from transmission
+    :param partition_columns: list of columns to use for file partitioning. In order to use
+        this parameter, you must sort your dataset by partition_columns. Do this by
+        passing an ORDER BY clause to the sql query. Files are uploaded to GCS as objects
+        with a hive style partitioning directory structure (templated).
     """
 
     template_fields: Sequence[str] = (
@@ -87,6 +92,7 @@ class BaseSQLToGCSOperator(BaseOperator):
         "schema",
         "parameters",
         "impersonation_chain",
+        "partition_columns",
     )
     template_ext: Sequence[str] = (".sql",)
     template_fields_renderers = {"sql": "sql"}
@@ -111,7 +117,8 @@ class BaseSQLToGCSOperator(BaseOperator):
         delegate_to: str | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
         upload_metadata: bool = False,
-        exclude_columns=None,
+        exclude_columns: set | None = None,
+        partition_columns: list | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -135,8 +142,16 @@ class BaseSQLToGCSOperator(BaseOperator):
         self.impersonation_chain = impersonation_chain
         self.upload_metadata = upload_metadata
         self.exclude_columns = exclude_columns
+        self.partition_columns = partition_columns
 
     def execute(self, context: Context):
+        if self.partition_columns:
+            self.log.info(
+                f"Found partition columns: {','.join(self.partition_columns)}. "
+                "Assuming the SQL statement is properly sorted by these columns in "
+                "ascending or descending order."
+            )
+
         self.log.info("Executing query")
         cursor = self.query()
 
@@ -158,6 +173,7 @@ class BaseSQLToGCSOperator(BaseOperator):
         total_files = 0
         self.log.info("Writing local data files")
         for file_to_upload in self._write_local_data_files(cursor):
+
             # Flush file before uploading
             file_to_upload["file_handle"].flush()
 
@@ -204,27 +220,13 @@ class BaseSQLToGCSOperator(BaseOperator):
             names in GCS, and values are file handles to local files that
             contain the data for the GCS objects.
         """
-        import os
-
         org_schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
         schema = [column for column in org_schema if column not in self.exclude_columns]
 
         col_type_dict = self._get_col_type_dict()
         file_no = 0
-
-        tmp_file_handle = NamedTemporaryFile(delete=True)
-        if self.export_format == "csv":
-            file_mime_type = "text/csv"
-        elif self.export_format == "parquet":
-            file_mime_type = "application/octet-stream"
-        else:
-            file_mime_type = "application/json"
-        file_to_upload = {
-            "file_name": self.filename.format(file_no),
-            "file_handle": tmp_file_handle,
-            "file_mime_type": file_mime_type,
-            "file_row_count": 0,
-        }
+        file_mime_type = self._get_file_mime_type()
+        file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)
 
         if self.export_format == "csv":
             csv_writer = self._configure_csv_file(tmp_file_handle, schema)
@@ -232,8 +234,42 @@ class BaseSQLToGCSOperator(BaseOperator):
             parquet_schema = self._convert_parquet_schema(cursor)
             parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
 
+        prev_partition_values = None
+        curr_partition_values = None
         for row in cursor:
+            if self.partition_columns:
+                row_dict = dict(zip(schema, row))
+                curr_partition_values = tuple(
+                    [row_dict.get(partition_column, "") for partition_column in self.partition_columns]
+                )
+
+                if prev_partition_values is None:
+                    # We haven't set prev_partition_values before. Set to current
+                    prev_partition_values = curr_partition_values
+
+                elif prev_partition_values != curr_partition_values:
+                    # If the partition values differ, write the current local file out
+                    # Yield first before we write the current record
+                    file_no += 1
+
+                    if self.export_format == "parquet":
+                        parquet_writer.close()
+
+                    file_to_upload["partition_values"] = prev_partition_values
+                    yield file_to_upload
+                    file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)
+                    if self.export_format == "csv":
+                        csv_writer = self._configure_csv_file(tmp_file_handle, schema)
+                    if self.export_format == "parquet":
+                        parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
+
+                    # Reset previous to current after writing out the file
+                    prev_partition_values = curr_partition_values
+
+            # Incrementing file_row_count after partition yield ensures all rows are written
             file_to_upload["file_row_count"] += 1
+
+            # Proceed to write the row to the local file
             if self.export_format == "csv":
                 row = self.convert_types(schema, col_type_dict, row)
                 if self.null_marker is not None:
@@ -268,24 +304,44 @@ class BaseSQLToGCSOperator(BaseOperator):
 
                 if self.export_format == "parquet":
                     parquet_writer.close()
+
+                file_to_upload["partition_values"] = curr_partition_values
                 yield file_to_upload
-                tmp_file_handle = NamedTemporaryFile(delete=True)
-                file_to_upload = {
-                    "file_name": self.filename.format(file_no),
-                    "file_handle": tmp_file_handle,
-                    "file_mime_type": file_mime_type,
-                    "file_row_count": 0,
-                }
+                file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)
                 if self.export_format == "csv":
                     csv_writer = self._configure_csv_file(tmp_file_handle, schema)
                 if self.export_format == "parquet":
                     parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
+
         if self.export_format == "parquet":
             parquet_writer.close()
         # Last file may have 0 rows, don't yield if empty
         if file_to_upload["file_row_count"] > 0:
+            file_to_upload["partition_values"] = curr_partition_values
             yield file_to_upload
 
+    def _get_file_to_upload(self, file_mime_type, file_no):
+        """Returns the file-to-upload dictionary and its temporary file handle as a tuple"""
+        tmp_file_handle = NamedTemporaryFile(delete=True)
+        return (
+            {
+                "file_name": self.filename.format(file_no),
+                "file_handle": tmp_file_handle,
+                "file_mime_type": file_mime_type,
+                "file_row_count": 0,
+            },
+            tmp_file_handle,
+        )
+
+    def _get_file_mime_type(self):
+        if self.export_format == "csv":
+            file_mime_type = "text/csv"
+        elif self.export_format == "parquet":
+            file_mime_type = "application/octet-stream"
+        else:
+            file_mime_type = "application/json"
+        return file_mime_type
+
     def _configure_csv_file(self, file_handle, schema):
         """Configure a csv writer with the file_handle and write schema
         as headers for the new file.
@@ -400,9 +456,19 @@ class BaseSQLToGCSOperator(BaseOperator):
         if is_data_file and self.upload_metadata:
             metadata = {"row_count": file_to_upload["file_row_count"]}
 
+        object_name = file_to_upload.get("file_name")
+        if is_data_file and self.partition_columns:
+            # Add partition column values to object_name
+            partition_values = file_to_upload.get("partition_values")
+            head_path, tail_path = os.path.split(object_name)
+            partition_subprefix = [
+                f"{col}={val}" for col, val in zip(self.partition_columns, partition_values)
+            ]
+            object_name = os.path.join(head_path, *partition_subprefix, tail_path)
+
         hook.upload(
             self.bucket,
-            file_to_upload.get("file_name"),
+            object_name,
             file_to_upload.get("file_handle").name,
             mime_type=file_to_upload.get("file_mime_type"),
             gzip=self.gzip if is_data_file else False,
diff --git a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
index 100ad1976c..bcb2a39100 100644
--- a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -62,6 +62,7 @@ APP_JSON = "application/json"
 OUTPUT_DF = pd.DataFrame([["convert_type_return_value"] * 3] * 3, columns=COLUMNS)
 
 EXCLUDE_COLUMNS = set("column_c")
+PARTITION_COLUMNS = ["column_b", "column_c"]
 NEW_COLUMNS = [c for c in COLUMNS if c not in EXCLUDE_COLUMNS]
 OUTPUT_DF_WITH_EXCLUDE_COLUMNS = pd.DataFrame(
     [["convert_type_return_value"] * len(NEW_COLUMNS)] * 3, columns=NEW_COLUMNS
@@ -305,6 +306,74 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
         )
         mock_close.assert_called_once()
 
+        mock_query.reset_mock()
+        mock_flush.reset_mock()
+        mock_upload.reset_mock()
+        mock_close.reset_mock()
+        cursor_mock.reset_mock()
+
+        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
+
+        # Test partition columns
+        operator = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            export_format="parquet",
+            schema=SCHEMA,
+            partition_columns=PARTITION_COLUMNS,
+        )
+        result = operator.execute(context=dict())
+
+        assert result == {
+            "bucket": "TEST-BUCKET-1",
+            "total_row_count": 3,
+            "total_files": 3,
+            "files": [
+                {
+                    "file_name": "test_results_0.csv",
+                    "file_mime_type": "application/octet-stream",
+                    "file_row_count": 1,
+                },
+                {
+                    "file_name": "test_results_1.csv",
+                    "file_mime_type": "application/octet-stream",
+                    "file_row_count": 1,
+                },
+                {
+                    "file_name": "test_results_2.csv",
+                    "file_mime_type": "application/octet-stream",
+                    "file_row_count": 1,
+                },
+            ],
+        }
+
+        mock_query.assert_called_once()
+        assert mock_flush.call_count == 3
+        assert mock_close.call_count == 3
+        mock_upload.assert_has_calls(
+            [
+                mock.call(
+                    BUCKET,
+                    f"column_b={row[1]}/column_c={row[2]}/test_results_{i}.csv",
+                    TMP_FILE_NAME,
+                    mime_type="application/octet-stream",
+                    gzip=False,
+                    metadata=None,
+                )
+                for i, row in enumerate(INPUT_DATA)
+            ]
+        )
+
+        mock_query.reset_mock()
+        mock_flush.reset_mock()
+        mock_upload.reset_mock()
+        mock_close.reset_mock()
+        cursor_mock.reset_mock()
+
+        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
+
         # Test null marker
         cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
         mock_convert_type.return_value = None
@@ -423,3 +492,31 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
         file.flush()
         df = pd.read_json(file.name, orient="records", lines=True)
         assert df.equals(OUTPUT_DF_WITH_EXCLUDE_COLUMNS)
+
+    def test__write_local_data_files_parquet_with_partition_columns(self):
+        op = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            schema_filename=SCHEMA_FILE,
+            export_format="parquet",
+            gzip=False,
+            schema=SCHEMA,
+            gcp_conn_id="google_cloud_default",
+            partition_columns=PARTITION_COLUMNS,
+        )
+        cursor = MagicMock()
+        cursor.__iter__.return_value = INPUT_DATA
+        cursor.description = CURSOR_DESCRIPTION
+
+        local_data_files = op._write_local_data_files(cursor)
+        concat_dfs = []
+        for local_data_file in local_data_files:
+            file = local_data_file["file_handle"]
+            file.flush()
+            df = pd.read_parquet(file.name)
+            concat_dfs.append(df)
+
+        concat_df = pd.concat(concat_dfs, ignore_index=True)
+        assert concat_df.equals(OUTPUT_DF)