Posted to commits@airflow.apache.org by el...@apache.org on 2023/01/10 05:55:23 UTC
[airflow] branch main updated: Support partition_columns in BaseSQLToGCSOperator (#28677)
This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 35a8ffc55a Support partition_columns in BaseSQLToGCSOperator (#28677)
35a8ffc55a is described below
commit 35a8ffc55af220b16ea345d770f80f698dcae3fb
Author: Victor Chiapaikeo <vc...@gmail.com>
AuthorDate: Tue Jan 10 00:55:15 2023 -0500
Support partition_columns in BaseSQLToGCSOperator (#28677)
* Support partition_columns in BaseSQLToGCSOperator
Co-authored-by: eladkal <45...@users.noreply.github.com>
---
.../providers/google/cloud/transfers/sql_to_gcs.py | 116 ++++++++++++++++-----
.../google/cloud/transfers/test_sql_to_gcs.py | 97 +++++++++++++++++
2 files changed, 188 insertions(+), 25 deletions(-)
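For readers of the diff below, a minimal usage sketch of the new parameter, assuming the MySQL-backed subclass; the table, bucket, and column names here are hypothetical:

    # A minimal sketch, assuming MySQLToGCSOperator; table/bucket names are made up.
    from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

    upload_orders = MySQLToGCSOperator(
        task_id="orders_to_gcs",
        # The result set must arrive sorted by the partition columns.
        sql="SELECT * FROM orders ORDER BY country, order_date",
        bucket="my-bucket",
        filename="orders/export_{}.parquet",
        export_format="parquet",
        # One local file (and one Hive-style object path) per distinct value pair.
        partition_columns=["country", "order_date"],
    )

With rows sorted that way, each run of identical (country, order_date) values becomes its own file, uploaded under a Hive-style path such as orders/country=US/order_date=2023-01-01/export_0.parquet.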
diff --git a/airflow/providers/google/cloud/transfers/sql_to_gcs.py b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
index e4a3f3e942..12043b05d6 100644
--- a/airflow/providers/google/cloud/transfers/sql_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import abc
import json
+import os
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence
@@ -77,6 +78,10 @@ class BaseSQLToGCSOperator(BaseOperator):
account from the list granting this role to the originating account (templated).
:param upload_metadata: whether to upload the row count metadata as blob metadata
:param exclude_columns: set of columns to exclude from transmission
+ :param partition_columns: list of columns to use for file partitioning. To use
+ this parameter, the dataset must be sorted by partition_columns: pass an
+ ORDER BY clause over those columns in the SQL query. Files are uploaded to GCS
+ as objects with a Hive-style partitioning directory structure (templated).
"""
template_fields: Sequence[str] = (
@@ -87,6 +92,7 @@ class BaseSQLToGCSOperator(BaseOperator):
"schema",
"parameters",
"impersonation_chain",
+ "partition_columns",
)
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"sql": "sql"}
@@ -111,7 +117,8 @@ class BaseSQLToGCSOperator(BaseOperator):
delegate_to: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
upload_metadata: bool = False,
- exclude_columns=None,
+ exclude_columns: set | None = None,
+ partition_columns: list | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -135,8 +142,16 @@ class BaseSQLToGCSOperator(BaseOperator):
self.impersonation_chain = impersonation_chain
self.upload_metadata = upload_metadata
self.exclude_columns = exclude_columns
+ self.partition_columns = partition_columns
def execute(self, context: Context):
+ if self.partition_columns:
+ self.log.info(
+ f"Found partition columns: {','.join(self.partition_columns)}. "
+ "Assuming the SQL statement is properly sorted by these columns in "
+ "ascending or descending order."
+ )
+
self.log.info("Executing query")
cursor = self.query()
@@ -158,6 +173,7 @@ class BaseSQLToGCSOperator(BaseOperator):
total_files = 0
self.log.info("Writing local data files")
for file_to_upload in self._write_local_data_files(cursor):
+
# Flush file before uploading
file_to_upload["file_handle"].flush()
@@ -204,27 +220,13 @@ class BaseSQLToGCSOperator(BaseOperator):
names in GCS, and values are file handles to local files that
contain the data for the GCS objects.
"""
- import os
-
org_schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
schema = [column for column in org_schema if column not in self.exclude_columns]
col_type_dict = self._get_col_type_dict()
file_no = 0
-
- tmp_file_handle = NamedTemporaryFile(delete=True)
- if self.export_format == "csv":
- file_mime_type = "text/csv"
- elif self.export_format == "parquet":
- file_mime_type = "application/octet-stream"
- else:
- file_mime_type = "application/json"
- file_to_upload = {
- "file_name": self.filename.format(file_no),
- "file_handle": tmp_file_handle,
- "file_mime_type": file_mime_type,
- "file_row_count": 0,
- }
+ file_mime_type = self._get_file_mime_type()
+ file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)
if self.export_format == "csv":
csv_writer = self._configure_csv_file(tmp_file_handle, schema)
@@ -232,8 +234,42 @@ class BaseSQLToGCSOperator(BaseOperator):
parquet_schema = self._convert_parquet_schema(cursor)
parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
+ prev_partition_values = None
+ curr_partition_values = None
for row in cursor:
+ if self.partition_columns:
+ row_dict = dict(zip(schema, row))
+ curr_partition_values = tuple(
+ row_dict.get(partition_column, "") for partition_column in self.partition_columns
+ )
+
+ if prev_partition_values is None:
+ # First row seen: initialize prev_partition_values to the current values
+ prev_partition_values = curr_partition_values
+
+ elif prev_partition_values != curr_partition_values:
+ # The partition values changed, so flush the current local file:
+ # yield it before writing the current record to a fresh file
+ file_no += 1
+
+ if self.export_format == "parquet":
+ parquet_writer.close()
+
+ file_to_upload["partition_values"] = prev_partition_values
+ yield file_to_upload
+ file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)
+ if self.export_format == "csv":
+ csv_writer = self._configure_csv_file(tmp_file_handle, schema)
+ if self.export_format == "parquet":
+ parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
+
+ # Reset previous to current after writing out the file
+ prev_partition_values = curr_partition_values
+
+ # Increment file_row_count after any partition yield so this row is counted toward the new file
file_to_upload["file_row_count"] += 1
+
+ # Proceed to write the row to the local file
if self.export_format == "csv":
row = self.convert_types(schema, col_type_dict, row)
if self.null_marker is not None:
@@ -268,24 +304,44 @@ class BaseSQLToGCSOperator(BaseOperator):
if self.export_format == "parquet":
parquet_writer.close()
+
+ file_to_upload["partition_values"] = curr_partition_values
yield file_to_upload
- tmp_file_handle = NamedTemporaryFile(delete=True)
- file_to_upload = {
- "file_name": self.filename.format(file_no),
- "file_handle": tmp_file_handle,
- "file_mime_type": file_mime_type,
- "file_row_count": 0,
- }
+ file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)
if self.export_format == "csv":
csv_writer = self._configure_csv_file(tmp_file_handle, schema)
if self.export_format == "parquet":
parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
+
if self.export_format == "parquet":
parquet_writer.close()
# Last file may have 0 rows, don't yield if empty
if file_to_upload["file_row_count"] > 0:
+ file_to_upload["partition_values"] = curr_partition_values
yield file_to_upload
+ def _get_file_to_upload(self, file_mime_type, file_no):
+ """Returns a dictionary that represents the file to upload"""
+ tmp_file_handle = NamedTemporaryFile(delete=True)
+ return (
+ {
+ "file_name": self.filename.format(file_no),
+ "file_handle": tmp_file_handle,
+ "file_mime_type": file_mime_type,
+ "file_row_count": 0,
+ },
+ tmp_file_handle,
+ )
+
+ def _get_file_mime_type(self):
+ if self.export_format == "csv":
+ file_mime_type = "text/csv"
+ elif self.export_format == "parquet":
+ file_mime_type = "application/octet-stream"
+ else:
+ file_mime_type = "application/json"
+ return file_mime_type
+
def _configure_csv_file(self, file_handle, schema):
"""Configure a csv writer with the file_handle and write schema
as headers for the new file.
@@ -400,9 +456,19 @@ class BaseSQLToGCSOperator(BaseOperator):
if is_data_file and self.upload_metadata:
metadata = {"row_count": file_to_upload["file_row_count"]}
+ object_name = file_to_upload.get("file_name")
+ if is_data_file and self.partition_columns:
+ # Add partition column values to object_name
+ partition_values = file_to_upload.get("partition_values")
+ head_path, tail_path = os.path.split(object_name)
+ partition_subprefix = [
+ f"{col}={val}" for col, val in zip(self.partition_columns, partition_values)
+ ]
+ object_name = os.path.join(head_path, *partition_subprefix, tail_path)
+
hook.upload(
self.bucket,
- file_to_upload.get("file_name"),
+ object_name,
file_to_upload.get("file_handle").name,
mime_type=file_to_upload.get("file_mime_type"),
gzip=self.gzip if is_data_file else False,
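The object-name construction in the hunk above is plain path joining; here is a standalone sketch of the same computation, with made-up column names and values:

    import os

    # Illustrative values only; the column names and file name are hypothetical.
    partition_columns = ["column_b", "column_c"]
    partition_values = ("b-value", "c-value")
    file_name = "test_results_0.csv"

    head_path, tail_path = os.path.split(file_name)  # ("", "test_results_0.csv")
    partition_subprefix = [f"{col}={val}" for col, val in zip(partition_columns, partition_values)]
    object_name = os.path.join(head_path, *partition_subprefix, tail_path)
    print(object_name)  # column_b=b-value/column_c=c-value/test_results_0.csv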
diff --git a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
index 100ad1976c..bcb2a39100 100644
--- a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -62,6 +62,7 @@ APP_JSON = "application/json"
OUTPUT_DF = pd.DataFrame([["convert_type_return_value"] * 3] * 3, columns=COLUMNS)
EXCLUDE_COLUMNS = set("column_c")
+PARTITION_COLUMNS = ["column_b", "column_c"]
NEW_COLUMNS = [c for c in COLUMNS if c not in EXCLUDE_COLUMNS]
OUTPUT_DF_WITH_EXCLUDE_COLUMNS = pd.DataFrame(
[["convert_type_return_value"] * len(NEW_COLUMNS)] * 3, columns=NEW_COLUMNS
@@ -305,6 +306,74 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
)
mock_close.assert_called_once()
+ mock_query.reset_mock()
+ mock_flush.reset_mock()
+ mock_upload.reset_mock()
+ mock_close.reset_mock()
+ cursor_mock.reset_mock()
+
+ cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
+
+ # Test partition columns
+ operator = DummySQLToGCSOperator(
+ sql=SQL,
+ bucket=BUCKET,
+ filename=FILENAME,
+ task_id=TASK_ID,
+ export_format="parquet",
+ schema=SCHEMA,
+ partition_columns=PARTITION_COLUMNS,
+ )
+ result = operator.execute(context=dict())
+
+ assert result == {
+ "bucket": "TEST-BUCKET-1",
+ "total_row_count": 3,
+ "total_files": 3,
+ "files": [
+ {
+ "file_name": "test_results_0.csv",
+ "file_mime_type": "application/octet-stream",
+ "file_row_count": 1,
+ },
+ {
+ "file_name": "test_results_1.csv",
+ "file_mime_type": "application/octet-stream",
+ "file_row_count": 1,
+ },
+ {
+ "file_name": "test_results_2.csv",
+ "file_mime_type": "application/octet-stream",
+ "file_row_count": 1,
+ },
+ ],
+ }
+
+ mock_query.assert_called_once()
+ assert mock_flush.call_count == 3
+ assert mock_close.call_count == 3
+ mock_upload.assert_has_calls(
+ [
+ mock.call(
+ BUCKET,
+ f"column_b={row[1]}/column_c={row[2]}/test_results_{i}.csv",
+ TMP_FILE_NAME,
+ mime_type="application/octet-stream",
+ gzip=False,
+ metadata=None,
+ )
+ for i, row in enumerate(INPUT_DATA)
+ ]
+ )
+
+ mock_query.reset_mock()
+ mock_flush.reset_mock()
+ mock_upload.reset_mock()
+ mock_close.reset_mock()
+ cursor_mock.reset_mock()
+
+ cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
+
# Test null marker
cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
mock_convert_type.return_value = None
@@ -423,3 +492,31 @@ class TestBaseSQLToGCSOperator(unittest.TestCase):
file.flush()
df = pd.read_json(file.name, orient="records", lines=True)
assert df.equals(OUTPUT_DF_WITH_EXCLUDE_COLUMNS)
+
+ def test__write_local_data_files_parquet_with_partition_columns(self):
+ op = DummySQLToGCSOperator(
+ sql=SQL,
+ bucket=BUCKET,
+ filename=FILENAME,
+ task_id=TASK_ID,
+ schema_filename=SCHEMA_FILE,
+ export_format="parquet",
+ gzip=False,
+ schema=SCHEMA,
+ gcp_conn_id="google_cloud_default",
+ partition_columns=PARTITION_COLUMNS,
+ )
+ cursor = MagicMock()
+ cursor.__iter__.return_value = INPUT_DATA
+ cursor.description = CURSOR_DESCRIPTION
+
+ local_data_files = op._write_local_data_files(cursor)
+ concat_dfs = []
+ for local_data_file in local_data_files:
+ file = local_data_file["file_handle"]
+ file.flush()
+ df = pd.read_parquet(file.name)
+ concat_dfs.append(df)
+
+ concat_df = pd.concat(concat_dfs, ignore_index=True)
+ assert concat_df.equals(OUTPUT_DF)
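As a closing note, the partition-splitting loop in _write_local_data_files boils down to a run-length split over sorted rows. A self-contained sketch of that idea, with the file handling stripped out; the helper name split_by_partition and the sample data are made up:

    # Consecutive rows with equal partition-key values form one chunk; this is
    # why the SQL statement must be sorted by the partition columns.
    def split_by_partition(rows, schema, partition_columns):
        chunk, prev_values = [], None
        for row in rows:
            row_dict = dict(zip(schema, row))
            curr_values = tuple(row_dict.get(col, "") for col in partition_columns)
            if prev_values is not None and curr_values != prev_values:
                yield prev_values, chunk  # partition boundary: flush the chunk
                chunk = []
            chunk.append(row)
            prev_values = curr_values
        if chunk:
            yield prev_values, chunk  # last chunk holds the final partition

    rows = [(1, "US"), (2, "US"), (3, "FR")]  # sorted by country
    for values, chunk in split_by_partition(rows, ["id", "country"], ["country"]):
        print(values, len(chunk))  # ('US',) 2, then ('FR',) 1

If the input is not sorted, the same partition values can recur in separate runs, producing multiple files for one partition; hence the operator's requirement for an ORDER BY over the partition columns.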