Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/09/03 20:30:24 UTC

[GitHub] kaxil closed pull request #3733: [AIRFLOW-491] Add cache parameter in BigQuery query method - with 'api_resource_configs'

kaxil closed pull request #3733: [AIRFLOW-491] Add cache parameter in BigQuery query method - with 'api_resource_configs'
URL: https://github.com/apache/incubator-airflow/pull/3733
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py
index aee84e9eef..44ecd49e9e 100644
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ b/airflow/contrib/hooks/bigquery_hook.py
@@ -24,6 +24,7 @@
 
 import time
 from builtins import range
+from copy import deepcopy
 
 from past.builtins import basestring
 
@@ -195,10 +196,19 @@ class BigQueryBaseCursor(LoggingMixin):
     PEP 249 cursor isn't needed.
     """
 
-    def __init__(self, service, project_id, use_legacy_sql=True):
+    def __init__(self,
+                 service,
+                 project_id,
+                 use_legacy_sql=True,
+                 api_resource_configs=None):
+
         self.service = service
         self.project_id = project_id
         self.use_legacy_sql = use_legacy_sql
+        if api_resource_configs:
+            _validate_value("api_resource_configs", api_resource_configs, dict)
+        self.api_resource_configs = api_resource_configs \
+            if api_resource_configs else {}
         self.running_job_id = None
 
     def create_empty_table(self,
@@ -238,8 +248,7 @@ def create_empty_table(self,
 
         :return:
         """
-        if time_partitioning is None:
-            time_partitioning = dict()
+
         project_id = project_id if project_id is not None else self.project_id
 
         table_resource = {
@@ -473,11 +482,11 @@ def create_external_table(self,
     def run_query(self,
                   bql=None,
                   sql=None,
-                  destination_dataset_table=False,
+                  destination_dataset_table=None,
                   write_disposition='WRITE_EMPTY',
                   allow_large_results=False,
                   flatten_results=None,
-                  udf_config=False,
+                  udf_config=None,
                   use_legacy_sql=None,
                   maximum_billing_tier=None,
                   maximum_bytes_billed=None,
@@ -486,7 +495,8 @@ def run_query(self,
                   labels=None,
                   schema_update_options=(),
                   priority='INTERACTIVE',
-                  time_partitioning=None):
+                  time_partitioning=None,
+                  api_resource_configs=None):
         """
         Executes a BigQuery SQL query. Optionally persists results in a BigQuery
         table. See here:
@@ -518,6 +528,13 @@ def run_query(self,
         :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false).
             If `None`, defaults to `self.use_legacy_sql`.
         :type use_legacy_sql: boolean
+        :param api_resource_configs: a dictionary containing parameters for
+            the 'configuration' object of the Google BigQuery Jobs API:
+            https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs
+            For example, {'query': {'useQueryCache': False}}. Use it to pass
+            parameters that are not exposed as dedicated run_query()
+            arguments.
+        :type api_resource_configs: dict
         :type udf_config: list
         :param maximum_billing_tier: Positive integer that serves as a
             multiplier of the basic price.
@@ -550,12 +567,22 @@ def run_query(self,
         :type time_partitioning: dict
 
         """
+        if not api_resource_configs:
+            api_resource_configs = self.api_resource_configs
+        else:
+            _validate_value('api_resource_configs',
+                            api_resource_configs, dict)
+        configuration = deepcopy(api_resource_configs)
+        if 'query' not in configuration:
+            configuration['query'] = {}
+
+        else:
+            _validate_value("api_resource_configs['query']",
+                            configuration['query'], dict)
 
-        # TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513]
-        if time_partitioning is None:
-            time_partitioning = {}
         sql = bql if sql is None else sql
 
+        # TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513]
         if bql:
             import warnings
             warnings.warn('Deprecated parameter `bql` used in '
@@ -566,95 +593,109 @@ def run_query(self,
                           'Airflow.',
                           category=DeprecationWarning)
 
-        if sql is None:
-            raise TypeError('`BigQueryBaseCursor.run_query` missing 1 required '
-                            'positional argument: `sql`')
+        if sql is None and not configuration['query'].get('query', None):
+            raise TypeError('`BigQueryBaseCursor.run_query` '
+                            'missing 1 required positional argument: `sql`')
 
         # BigQuery also allows you to define how you want a table's schema to change
         # as a side effect of a query job
         # for more details:
         #   https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions
+
         allowed_schema_update_options = [
             'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION"
         ]
-        if not set(allowed_schema_update_options).issuperset(
-                set(schema_update_options)):
-            raise ValueError(
-                "{0} contains invalid schema update options. "
-                "Please only use one or more of the following options: {1}"
-                .format(schema_update_options, allowed_schema_update_options))
 
-        if use_legacy_sql is None:
-            use_legacy_sql = self.use_legacy_sql
+        if not set(allowed_schema_update_options
+                   ).issuperset(set(schema_update_options)):
+            raise ValueError("{0} contains invalid schema update options. "
+                             "Please only use one or more of the following "
+                             "options: {1}"
+                             .format(schema_update_options,
+                                     allowed_schema_update_options))
 
-        configuration = {
-            'query': {
-                'query': sql,
-                'useLegacySql': use_legacy_sql,
-                'maximumBillingTier': maximum_billing_tier,
-                'maximumBytesBilled': maximum_bytes_billed,
-                'priority': priority
-            }
-        }
+        if schema_update_options:
+            if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
+                raise ValueError("schema_update_options is only "
+                                 "allowed if write_disposition is "
+                                 "'WRITE_APPEND' or 'WRITE_TRUNCATE'.")
 
         if destination_dataset_table:
-            if '.' not in destination_dataset_table:
-                raise ValueError(
-                    'Expected destination_dataset_table name in the format of '
-                    '<dataset>.<table>. Got: {}'.format(
-                        destination_dataset_table))
             destination_project, destination_dataset, destination_table = \
                 _split_tablename(table_input=destination_dataset_table,
                                  default_project_id=self.project_id)
-            configuration['query'].update({
-                'allowLargeResults': allow_large_results,
-                'flattenResults': flatten_results,
-                'writeDisposition': write_disposition,
-                'createDisposition': create_disposition,
-                'destinationTable': {
-                    'projectId': destination_project,
-                    'datasetId': destination_dataset,
-                    'tableId': destination_table,
-                }
-            })
-        if udf_config:
-            if not isinstance(udf_config, list):
-                raise TypeError("udf_config argument must have a type 'list'"
-                                " not {}".format(type(udf_config)))
-            configuration['query'].update({
-                'userDefinedFunctionResources': udf_config
-            })
 
-        if query_params:
-            if self.use_legacy_sql:
-                raise ValueError("Query parameters are not allowed when using "
-                                 "legacy SQL")
-            else:
-                configuration['query']['queryParameters'] = query_params
+            destination_dataset_table = {
+                'projectId': destination_project,
+                'datasetId': destination_dataset,
+                'tableId': destination_table,
+            }
 
-        if labels:
-            configuration['labels'] = labels
+        query_param_list = [
+            (sql, 'query', None, str),
+            (priority, 'priority', 'INTERACTIVE', str),
+            (use_legacy_sql, 'useLegacySql', self.use_legacy_sql, bool),
+            (query_params, 'queryParameters', None, dict),
+            (udf_config, 'userDefinedFunctionResources', None, list),
+            (maximum_billing_tier, 'maximumBillingTier', None, int),
+            (maximum_bytes_billed, 'maximumBytesBilled', None, float),
+            (time_partitioning, 'timePartitioning', {}, dict),
+            (schema_update_options, 'schemaUpdateOptions', None, tuple),
+            (destination_dataset_table, 'destinationTable', None, dict)
+        ]
 
-        time_partitioning = _cleanse_time_partitioning(
-            destination_dataset_table,
-            time_partitioning
-        )
-        if time_partitioning:
-            configuration['query'].update({
-                'timePartitioning': time_partitioning
-            })
+        for param_tuple in query_param_list:
 
-        if schema_update_options:
-            if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
-                raise ValueError("schema_update_options is only "
-                                 "allowed if write_disposition is "
-                                 "'WRITE_APPEND' or 'WRITE_TRUNCATE'.")
-            else:
-                self.log.info(
-                    "Adding experimental "
-                    "'schemaUpdateOptions': {0}".format(schema_update_options))
-                configuration['query'][
-                    'schemaUpdateOptions'] = schema_update_options
+            param, param_name, param_default, param_type = param_tuple
+
+            if param_name not in configuration['query'] and param in [None, {}, ()]:
+                if param_name == 'timePartitioning':
+                    param_default = _cleanse_time_partitioning(
+                        destination_dataset_table, time_partitioning)
+                param = param_default
+
+            if param not in [None, {}, ()]:
+                _api_resource_configs_duplication_check(
+                    param_name, param, configuration['query'])
+
+                configuration['query'][param_name] = param
+
+                # check the type of the provided param; this is done last
+                # because the param can come from two sources (argument or
+                # api_resource_configs) and has to be resolved first
+
+                _validate_value(param_name, configuration['query'][param_name],
+                                param_type)
+
+                if param_name == 'schemaUpdateOptions' and param:
+                    self.log.info("Adding experimental 'schemaUpdateOptions': "
+                                  "{0}".format(schema_update_options))
+
+                if param_name == 'destinationTable':
+                    for key in ['projectId', 'datasetId', 'tableId']:
+                        if key not in configuration['query']['destinationTable']:
+                            raise ValueError(
+                                "Not correct 'destinationTable' in "
+                                "api_resource_configs. 'destinationTable' "
+                                "must be a dict with {'projectId':'', "
+                                "'datasetId':'', 'tableId':''}")
+
+                    configuration['query'].update({
+                        'allowLargeResults': allow_large_results,
+                        'flattenResults': flatten_results,
+                        'writeDisposition': write_disposition,
+                        'createDisposition': create_disposition,
+                    })
+
+        if configuration['query'].get('useLegacySql') and \
+                'queryParameters' in configuration['query']:
+            raise ValueError("Query parameters are not allowed "
+                             "when using legacy SQL")
+
+        if labels:
+            _api_resource_configs_duplication_check(
+                'labels', labels, configuration)
+            configuration['labels'] = labels
 
         return self.run_with_configuration(configuration)
 
@@ -888,8 +929,7 @@ def run_load(self,
         #   https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat
         if src_fmt_configs is None:
             src_fmt_configs = {}
-        if time_partitioning is None:
-            time_partitioning = {}
+
         source_format = source_format.upper()
         allowed_formats = [
             "CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS",
@@ -1167,10 +1207,6 @@ def run_table_delete(self, deletion_dataset_table,
         :type ignore_if_missing: boolean
         :return:
         """
-        if '.' not in deletion_dataset_table:
-            raise ValueError(
-                'Expected deletion_dataset_table name in the format of '
-                '<dataset>.<table>. Got: {}'.format(deletion_dataset_table))
         deletion_project, deletion_dataset, deletion_table = \
             _split_tablename(table_input=deletion_dataset_table,
                              default_project_id=self.project_id)
@@ -1536,6 +1572,12 @@ def _bq_cast(string_field, bq_type):
 
 
 def _split_tablename(table_input, default_project_id, var_name=None):
+
+    if '.' not in table_input:
+        raise ValueError(
+            'Expected table name in the format of '
+            '<dataset>.<table>. Got: {}'.format(table_input))
+
     if not default_project_id:
         raise ValueError("INTERNAL: No default project is specified")
 
@@ -1597,6 +1639,10 @@ def var_print(var_name):
 
 def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
     # if it is a partitioned table ($ is in the table name) add partition load option
+
+    if time_partitioning_in is None:
+        time_partitioning_in = {}
+
     time_partitioning_out = {}
     if destination_dataset_table and '$' in destination_dataset_table:
         if time_partitioning_in.get('field'):
@@ -1607,3 +1653,20 @@ def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
 
     time_partitioning_out.update(time_partitioning_in)
     return time_partitioning_out
+
+
+def _validate_value(key, value, expected_type):
+    """ function to check expected type and raise
+    error if type is not correct """
+    if not isinstance(value, expected_type):
+        raise TypeError("{} argument must have a type {} not {}".format(
+            key, expected_type, type(value)))
+
+
+def _api_resource_configs_duplication_check(key, value, config_dict):
+    if key in config_dict and value != config_dict[key]:
+        raise ValueError("Values of {param_name} param are duplicated. "
+                         "`api_resource_configs` contained {param_name} param "
+                         "in `query` config and {param_name} was also provided "
+                         "with arg to run_query() method. Please remove duplicates."
+                         .format(param_name=key))
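
For readers skimming the change, here is a minimal usage sketch of the new
hook-level parameter. It is not part of the diff above; it assumes a
configured 'bigquery_default' connection and the cursor layout in
airflow/contrib/hooks/bigquery_hook.py:

    # Sketch only: pass extra Jobs API settings through api_resource_configs
    # instead of a dedicated run_query() argument.
    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    bq_hook = BigQueryHook(bigquery_conn_id='bigquery_default',
                           use_legacy_sql=False)
    cursor = bq_hook.get_conn().cursor()

    cursor.run_query(
        sql='SELECT 1',
        # everything under 'query' is merged into configuration['query'];
        # a key that conflicts with an explicit argument raises ValueError
        api_resource_configs={'query': {'useQueryCache': False}},
    )
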
diff --git a/airflow/contrib/operators/bigquery_operator.py b/airflow/contrib/operators/bigquery_operator.py
index a8fdc66ec9..b0c0ce2d6e 100644
--- a/airflow/contrib/operators/bigquery_operator.py
+++ b/airflow/contrib/operators/bigquery_operator.py
@@ -75,6 +75,13 @@ class BigQueryOperator(BaseOperator):
         (without incurring a charge). If unspecified, this will be
         set to your project default.
     :type maximum_bytes_billed: float
+    :param api_resource_configs: a dictionary containing parameters for the
+        'configuration' object of the Google BigQuery Jobs API:
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs
+        For example, {'query': {'useQueryCache': False}}. Use it to pass
+        parameters that are not exposed as dedicated BigQueryOperator
+        arguments.
+    :type api_resource_configs: dict
     :param schema_update_options: Allows the schema of the destination
         table to be updated as a side effect of the load job.
     :type schema_update_options: tuple
@@ -118,7 +125,8 @@ def __init__(self,
                  query_params=None,
                  labels=None,
                  priority='INTERACTIVE',
-                 time_partitioning={},
+                 time_partitioning=None,
+                 api_resource_configs=None,
                  *args,
                  **kwargs):
         super(BigQueryOperator, self).__init__(*args, **kwargs)
@@ -140,7 +148,10 @@ def __init__(self,
         self.labels = labels
         self.bq_cursor = None
         self.priority = priority
-        self.time_partitioning = time_partitioning
+        self.time_partitioning = time_partitioning \
+            if time_partitioning is not None else {}
+        self.api_resource_configs = api_resource_configs \
+            if api_resource_configs is not None else {}
 
         # TODO remove `bql` in Airflow 2.0
         if self.bql:
@@ -179,7 +190,8 @@ def execute(self, context):
             labels=self.labels,
             schema_update_options=self.schema_update_options,
             priority=self.priority,
-            time_partitioning=self.time_partitioning
+            time_partitioning=self.time_partitioning,
+            api_resource_configs=self.api_resource_configs,
         )
 
     def on_kill(self):
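
The operator-level equivalent, again as a hedged sketch rather than part of
the change (the DAG id, task id and schedule below are placeholders, and the
keyword arguments are the ones shown in the diff):

    # Sketch only: BigQueryOperator forwards api_resource_configs to
    # BigQueryBaseCursor.run_query().
    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.bigquery_operator import BigQueryOperator

    with DAG(dag_id='example_bq_api_resource_configs',
             start_date=datetime(2018, 1, 1),
             schedule_interval=None) as dag:
        no_cache_query = BigQueryOperator(
            task_id='no_cache_query',
            sql='SELECT 1',
            use_legacy_sql=False,
            api_resource_configs={'query': {'useQueryCache': False}},
        )
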
diff --git a/tests/contrib/hooks/test_bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py
index d7e9491b8b..69a103bf06 100644
--- a/tests/contrib/hooks/test_bigquery_hook.py
+++ b/tests/contrib/hooks/test_bigquery_hook.py
@@ -25,7 +25,8 @@
 import mock
 
 from airflow.contrib.hooks import bigquery_hook as hook
-from airflow.contrib.hooks.bigquery_hook import _cleanse_time_partitioning
+from airflow.contrib.hooks.bigquery_hook import _cleanse_time_partitioning, \
+    _validate_value, _api_resource_configs_duplication_check
 
 bq_available = True
 
@@ -206,6 +207,16 @@ def mock_job_cancel(projectId, jobId):
 
 
 class TestBigQueryBaseCursor(unittest.TestCase):
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_bql_deprecation_warning(self, mock_rwc):
+        with warnings.catch_warnings(record=True) as w:
+            hook.BigQueryBaseCursor("test", "test").run_query(
+                bql='select * from test_table'
+            )
+        self.assertIn(
+            'Deprecated parameter `bql`',
+            w[0].message.args[0])
+
     def test_invalid_schema_update_options(self):
         with self.assertRaises(Exception) as context:
             hook.BigQueryBaseCursor("test", "test").run_load(
@@ -216,16 +227,6 @@ def test_invalid_schema_update_options(self):
             )
         self.assertIn("THIS IS NOT VALID", str(context.exception))
 
-    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
-    def test_bql_deprecation_warning(self, mock_rwc):
-        with warnings.catch_warnings(record=True) as w:
-            hook.BigQueryBaseCursor("test", "test").run_query(
-                bql='select * from test_table'
-            )
-        self.assertIn(
-            'Deprecated parameter `bql`',
-            w[0].message.args[0])
-
     def test_nobql_nosql_param_error(self):
         with self.assertRaises(TypeError) as context:
             hook.BigQueryBaseCursor("test", "test").run_query(
@@ -281,6 +282,39 @@ def test_run_query_sql_dialect_override(self, run_with_config):
             args, kwargs = run_with_config.call_args
             self.assertIs(args[0]['query']['useLegacySql'], bool_val)
 
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_api_resource_configs(self, run_with_config):
+        for bool_val in [True, False]:
+            cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id")
+            cursor.run_query('query',
+                             api_resource_configs={
+                                 'query': {'useQueryCache': bool_val}})
+            args, kwargs = run_with_config.call_args
+            self.assertIs(args[0]['query']['useQueryCache'], bool_val)
+            self.assertIs(args[0]['query']['useLegacySql'], True)
+
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_api_resource_configs_duplication_warning(self, run_with_config):
+        with self.assertRaises(ValueError):
+            cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id")
+            cursor.run_query('query',
+                             use_legacy_sql=True,
+                             api_resource_configs={
+                                 'query': {'useLegacySql': False}})
+
+    def test_validate_value(self):
+        with self.assertRaises(TypeError):
+            _validate_value("case_1", "a", dict)
+        self.assertIsNone(_validate_value("case_2", 0, int))
+
+    def test_duplication_check(self):
+        key_one = True
+        with self.assertRaises(ValueError):
+            _api_resource_configs_duplication_check(
+                "key_one", key_one, {"key_one": False})
+        self.assertIsNone(_api_resource_configs_duplication_check(
+            "key_one", key_one, {"key_one": True}))
+
 
 class TestLabelsInRunJob(unittest.TestCase):
     @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
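
To make the new duplication rule concrete, a small behavioural sketch that
mirrors the mocked-cursor setup used in the tests above (it assumes the
standalone mock package used by this test suite):

    # Sketch only: the same setting supplied both as an explicit argument and
    # inside api_resource_configs must not conflict.
    import mock
    from airflow.contrib.hooks import bigquery_hook as hook

    with mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration'):
        cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id")

        # identical values pass the duplication check
        cursor.run_query(
            sql='SELECT 1', use_legacy_sql=False,
            api_resource_configs={'query': {'useLegacySql': False}})

        # conflicting values raise ValueError
        try:
            cursor.run_query(
                sql='SELECT 1', use_legacy_sql=True,
                api_resource_configs={'query': {'useLegacySql': False}})
        except ValueError as err:
            print(err)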


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services