You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/09/07 16:41:05 UTC

[GitHub] Fokko closed pull request #3838: [AIRFLOW-2997] Support for Bigquery clustered tables

Fokko closed pull request #3838: [AIRFLOW-2997] Support for Bigquery clustered tables
URL: https://github.com/apache/incubator-airflow/pull/3838
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request came from a fork, GitHub no longer shows its
diff once it has been merged, so the full diff is reproduced below:

diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py
index 44ecd49e9e..245e3a8233 100644
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ b/airflow/contrib/hooks/bigquery_hook.py
@@ -496,7 +496,8 @@ def run_query(self,
                   schema_update_options=(),
                   priority='INTERACTIVE',
                   time_partitioning=None,
-                  api_resource_configs=None):
+                  api_resource_configs=None,
+                  cluster_fields=None):
         """
         Executes a BigQuery SQL query. Optionally persists results in a BigQuery
         table. See here:
@@ -565,8 +566,12 @@ def run_query(self,
             expiration as per API specifications. Note that 'field' is not available in
             conjunction with dataset.table$partition.
         :type time_partitioning: dict
-
+        :param cluster_fields: Request that the result of this query be stored sorted
+            by one or more columns. This is only available in combination with
+            time_partitioning. The order of columns given determines the sort order.
+        :type cluster_fields: list of str
         """
+
         if not api_resource_configs:
             api_resource_configs = self.api_resource_configs
         else:
@@ -631,6 +636,9 @@ def run_query(self,
                 'tableId': destination_table,
             }
 
+        if cluster_fields:
+            cluster_fields = {'fields': cluster_fields}
+
         query_param_list = [
             (sql, 'query', None, str),
             (priority, 'priority', 'INTERACTIVE', str),
@@ -641,7 +649,8 @@ def run_query(self,
             (maximum_bytes_billed, 'maximumBytesBilled', None, float),
             (time_partitioning, 'timePartitioning', {}, dict),
             (schema_update_options, 'schemaUpdateOptions', None, tuple),
-            (destination_dataset_table, 'destinationTable', None, dict)
+            (destination_dataset_table, 'destinationTable', None, dict),
+            (cluster_fields, 'clustering', None, dict),
         ]
 
         for param_tuple in query_param_list:
@@ -856,7 +865,8 @@ def run_load(self,
                  allow_jagged_rows=False,
                  schema_update_options=(),
                  src_fmt_configs=None,
-                 time_partitioning=None):
+                 time_partitioning=None,
+                 cluster_fields=None):
         """
         Executes a BigQuery load command to load data from Google Cloud Storage
         to BigQuery. See here:
@@ -920,6 +930,10 @@ def run_load(self,
             expiration as per API specifications. Note that 'field' is not available in
             conjunction with dataset.table$partition.
         :type time_partitioning: dict
+        :param cluster_fields: Request that the result of this load be stored sorted
+            by one or more columns. This is only available in combination with
+            time_partitioning. The order of columns given determines the sort order.
+        :type cluster_fields: list of str
         """
 
         # bigquery only allows certain source formats
@@ -983,6 +997,9 @@ def run_load(self,
                 'timePartitioning': time_partitioning
             })
 
+        if cluster_fields:
+            configuration['load'].update({'clustering': {'fields': cluster_fields}})
+
         if schema_fields:
             configuration['load']['schema'] = {'fields': schema_fields}
 
diff --git a/airflow/contrib/operators/bigquery_operator.py b/airflow/contrib/operators/bigquery_operator.py
index b0c0ce2d6e..025af034ad 100644
--- a/airflow/contrib/operators/bigquery_operator.py
+++ b/airflow/contrib/operators/bigquery_operator.py
@@ -100,6 +100,10 @@ class BigQueryOperator(BaseOperator):
         expiration as per API specifications. Note that 'field' is not available in
         conjunction with dataset.table$partition.
     :type time_partitioning: dict
+    :param cluster_fields: Request that the result of this query be stored sorted
+        by one or more columns. This is only available in conjunction with
+        time_partitioning. The order of columns given determines the sort order.
+    :type cluster_fields: list of str
     """
 
     template_fields = ('bql', 'sql', 'destination_dataset_table', 'labels')
@@ -127,6 +131,7 @@ def __init__(self,
                  priority='INTERACTIVE',
                  time_partitioning=None,
                  api_resource_configs=None,
+                 cluster_fields=None,
                  *args,
                  **kwargs):
         super(BigQueryOperator, self).__init__(*args, **kwargs)
@@ -152,6 +157,7 @@ def __init__(self,
             self.time_partitioning = {}
         if api_resource_configs is None:
             self.api_resource_configs = {}
+        self.cluster_fields = cluster_fields
 
         # TODO remove `bql` in Airflow 2.0
         if self.bql:
@@ -192,6 +198,7 @@ def execute(self, context):
             priority=self.priority,
             time_partitioning=self.time_partitioning,
             api_resource_configs=self.api_resource_configs,
+            cluster_fields=self.cluster_fields,
         )
 
     def on_kill(self):
diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py
index 69acb61659..003e828a8a 100644
--- a/airflow/contrib/operators/gcs_to_bq.py
+++ b/airflow/contrib/operators/gcs_to_bq.py
@@ -114,6 +114,11 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
         Note that 'field' is not available in concurrency with
         dataset.table$partition.
     :type time_partitioning: dict
+    :param cluster_fields: Request that the result of this load be stored sorted
+        by one or more columns. This is only available in conjunction with
+        time_partitioning. The order of columns given determines the sort order.
+        Not applicable for external tables.
+    :type cluster_fields: list of str
     """
     template_fields = ('bucket', 'source_objects',
                        'schema_object', 'destination_project_dataset_table')
@@ -146,6 +151,7 @@ def __init__(self,
                  src_fmt_configs=None,
                  external_table=False,
                  time_partitioning=None,
+                 cluster_fields=None,
                  *args, **kwargs):
 
         super(GoogleCloudStorageToBigQueryOperator, self).__init__(*args, **kwargs)
@@ -183,6 +189,7 @@ def __init__(self,
         self.schema_update_options = schema_update_options
         self.src_fmt_configs = src_fmt_configs
         self.time_partitioning = time_partitioning
+        self.cluster_fields = cluster_fields
 
     def execute(self, context):
         bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
@@ -238,7 +245,8 @@ def execute(self, context):
                 allow_jagged_rows=self.allow_jagged_rows,
                 schema_update_options=self.schema_update_options,
                 src_fmt_configs=self.src_fmt_configs,
-                time_partitioning=self.time_partitioning)
+                time_partitioning=self.time_partitioning,
+                cluster_fields=self.cluster_fields)
 
         if self.max_id_key:
             cursor.execute('SELECT MAX({}) FROM {}'.format(
diff --git a/tests/contrib/hooks/test_bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py
index 69a103bf06..e1379dde79 100644
--- a/tests/contrib/hooks/test_bigquery_hook.py
+++ b/tests/contrib/hooks/test_bigquery_hook.py
@@ -455,6 +455,94 @@ def test_cant_add_dollar_and_field_name(self):
             )
 
 
+class TestClusteringInRunJob(unittest.TestCase):
+
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_run_load_default(self, mocked_rwc, mocked_time, mocked_logging):
+        project_id = 12345
+
+        def run_with_config(config):
+            self.assertIsNone(config['load'].get('clustering'))
+        mocked_rwc.side_effect = run_with_config
+
+        bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id)
+        bq_hook.run_load(
+            destination_project_dataset_table='my_dataset.my_table',
+            schema_fields=[],
+            source_uris=[],
+        )
+
+        mocked_rwc.assert_called_once()
+
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_run_load_with_arg(self, mocked_rwc, mocked_time, mocked_logging):
+        project_id = 12345
+
+        def run_with_config(config):
+            self.assertEqual(
+                config['load']['clustering'],
+                {
+                    'fields': ['field1', 'field2']
+                }
+            )
+        mocked_rwc.side_effect = run_with_config
+
+        bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id)
+        bq_hook.run_load(
+            destination_project_dataset_table='my_dataset.my_table',
+            schema_fields=[],
+            source_uris=[],
+            cluster_fields=['field1', 'field2'],
+            time_partitioning={'type': 'DAY'}
+        )
+
+        mocked_rwc.assert_called_once()
+
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_run_query_default(self, mocked_rwc, mocked_time, mocked_logging):
+        project_id = 12345
+
+        def run_with_config(config):
+            self.assertIsNone(config['query'].get('clustering'))
+        mocked_rwc.side_effect = run_with_config
+
+        bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id)
+        bq_hook.run_query(sql='select 1')
+
+        mocked_rwc.assert_called_once()
+
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_run_query_with_arg(self, mocked_rwc, mocked_time, mocked_logging):
+        project_id = 12345
+
+        def run_with_config(config):
+            self.assertEqual(
+                config['query']['clustering'],
+                {
+                    'fields': ['field1', 'field2']
+                }
+            )
+        mocked_rwc.side_effect = run_with_config
+
+        bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id)
+        bq_hook.run_query(
+            sql='select 1',
+            destination_dataset_table='my_dataset.my_table',
+            cluster_fields=['field1', 'field2'],
+            time_partitioning={'type': 'DAY'}
+        )
+
+        mocked_rwc.assert_called_once()
+
+
 class TestBigQueryHookLegacySql(unittest.TestCase):
     """Ensure `use_legacy_sql` param in `BigQueryHook` propagates properly."""
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services