You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by fo...@apache.org on 2018/01/22 17:27:47 UTC

incubator-airflow git commit: [AIRFLOW-1267][AIRFLOW-1874] Add dialect parameter to BigQueryHook

Repository: incubator-airflow
Updated Branches:
  refs/heads/master 24bb2b7b6 -> 1021f6803


[AIRFLOW-1267][AIRFLOW-1874] Add dialect parameter to BigQueryHook

Allows a default BigQuery dialect to be specified
at the hook level, which is threaded through to
the underlying cursors.

This allows the standard SQL dialect to be used,
while maintaining compatibility with the
`DbApiHook` interface.

Addresses AIRFLOW-1267 and AIRFLOW-1874

Closes #2964 from ji-han/master


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/1021f680
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/1021f680
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/1021f680

Branch: refs/heads/master
Commit: 1021f680317305e4ceb3a3b06889a8742cc3f6f3
Parents: 24bb2b7
Author: Winston Huang <wi...@quizlet.com>
Authored: Mon Jan 22 18:27:40 2018 +0100
Committer: Fokko Driesprong <fo...@godatadriven.com>
Committed: Mon Jan 22 18:27:40 2018 +0100

----------------------------------------------------------------------
 airflow/contrib/hooks/bigquery_hook.py         | 36 ++++++++++++++-------
 airflow/contrib/operators/bigquery_operator.py |  2 +-
 tests/contrib/hooks/test_bigquery_hook.py      | 21 +++++++++++-
 3 files changed, 45 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1021f680/airflow/contrib/hooks/bigquery_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py
index fe51d50..67c6329 100644
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ b/airflow/contrib/hooks/bigquery_hook.py
@@ -44,9 +44,13 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook, LoggingMixin):
     """
     conn_name_attr = 'bigquery_conn_id'
 
-    def __init__(self, bigquery_conn_id='bigquery_default', delegate_to=None):
+    def __init__(self,
+                 bigquery_conn_id='bigquery_default',
+                 delegate_to=None,
+                 use_legacy_sql=True):
         super(BigQueryHook, self).__init__(
             conn_id=bigquery_conn_id, delegate_to=delegate_to)
+        self.use_legacy_sql = use_legacy_sql
 
     def get_conn(self):
         """
@@ -54,7 +58,10 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook, LoggingMixin):
         """
         service = self.get_service()
         project = self._get_field('project')
-        return BigQueryConnection(service=service, project_id=project)
+        return BigQueryConnection(
+            service=service,
+            project_id=project,
+            use_legacy_sql=self.use_legacy_sql)
 
     def get_service(self):
         """
@@ -71,7 +78,7 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook, LoggingMixin):
         """
         raise NotImplementedError()
 
-    def get_pandas_df(self, bql, parameters=None, dialect='legacy'):
+    def get_pandas_df(self, bql, parameters=None, dialect=None):
         """
         Returns a Pandas DataFrame for the results produced by a BigQuery
         query. The DbApiHook method must be overridden because Pandas
@@ -86,10 +93,15 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook, LoggingMixin):
             used, leave to override superclass method)
         :type parameters: mapping or iterable
         :param dialect: Dialect of BigQuery SQL – legacy SQL or standard SQL
-        :type dialect: string in {'legacy', 'standard'}, default 'legacy'
+            defaults to use `self.use_legacy_sql` if not specified
+        :type dialect: string in {'legacy', 'standard'}
         """
         service = self.get_service()
         project = self._get_field('project')
+
+        if dialect is None:
+            dialect = 'legacy' if self.use_legacy_sql else 'standard'
+
         connector = BigQueryPandasConnector(project, service, dialect=dialect)
         schema, pages = connector.run_query(bql)
         dataframe_list = []
@@ -188,9 +200,10 @@ class BigQueryBaseCursor(LoggingMixin):
     PEP 249 cursor isn't needed.
     """
 
-    def __init__(self, service, project_id):
+    def __init__(self, service, project_id, use_legacy_sql=True):
         self.service = service
         self.project_id = project_id
+        self.use_legacy_sql = use_legacy_sql
         self.running_job_id = None
 
     def run_query(self,
@@ -199,7 +212,6 @@ class BigQueryBaseCursor(LoggingMixin):
                   write_disposition='WRITE_EMPTY',
                   allow_large_results=False,
                   udf_config=False,
-                  use_legacy_sql=True,
                   maximum_billing_tier=None,
                   create_disposition='CREATE_IF_NEEDED',
                   query_params=None,
@@ -224,8 +236,6 @@ class BigQueryBaseCursor(LoggingMixin):
         :param udf_config: The User Defined Function configuration for the query.
             See https://cloud.google.com/bigquery/user-defined-functions for details.
         :type udf_config: list
-        :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false).
-        :type use_legacy_sql: boolean
         :param maximum_billing_tier: Positive integer that serves as a
             multiplier of the basic price.
         :type maximum_billing_tier: integer
@@ -257,7 +267,7 @@ class BigQueryBaseCursor(LoggingMixin):
         configuration = {
             'query': {
                 'query': bql,
-                'useLegacySql': use_legacy_sql,
+                'useLegacySql': self.use_legacy_sql,
                 'maximumBillingTier': maximum_billing_tier
             }
         }
@@ -290,7 +300,7 @@ class BigQueryBaseCursor(LoggingMixin):
             })
 
         if query_params:
-            if use_legacy_sql:
+            if self.use_legacy_sql:
                 raise ValueError("Query paramaters are not allowed when using "
                                  "legacy SQL")
             else:
@@ -942,9 +952,11 @@ class BigQueryCursor(BigQueryBaseCursor):
     https://github.com/dropbox/PyHive/blob/master/pyhive/common.py
     """
 
-    def __init__(self, service, project_id):
+    def __init__(self, service, project_id, use_legacy_sql=True):
         super(BigQueryCursor, self).__init__(
-            service=service, project_id=project_id)
+            service=service,
+            project_id=project_id,
+            use_legacy_sql=use_legacy_sql)
         self.buffersize = None
         self.page_token = None
         self.job_id = None

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1021f680/airflow/contrib/operators/bigquery_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/bigquery_operator.py b/airflow/contrib/operators/bigquery_operator.py
index 4b3ce75..94fa9b7 100644
--- a/airflow/contrib/operators/bigquery_operator.py
+++ b/airflow/contrib/operators/bigquery_operator.py
@@ -98,6 +98,7 @@ class BigQueryOperator(BaseOperator):
             self.log.info('Executing: %s', self.bql)
             hook = BigQueryHook(
                 bigquery_conn_id=self.bigquery_conn_id,
+                use_legacy_sql=self.use_legacy_sql,
                 delegate_to=self.delegate_to)
             conn = hook.get_conn()
             self.bq_cursor = conn.cursor()
@@ -107,7 +108,6 @@ class BigQueryOperator(BaseOperator):
             write_disposition=self.write_disposition,
             allow_large_results=self.allow_large_results,
             udf_config=self.udf_config,
-            use_legacy_sql=self.use_legacy_sql,
             maximum_billing_tier=self.maximum_billing_tier,
             create_disposition=self.create_disposition,
             query_params=self.query_params,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1021f680/tests/contrib/hooks/test_bigquery_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py
index 86268c4..4c0eefa 100644
--- a/tests/contrib/hooks/test_bigquery_hook.py
+++ b/tests/contrib/hooks/test_bigquery_hook.py
@@ -308,6 +308,25 @@ class TestTimePartitioningInRunJob(unittest.TestCase):
             )
 
 
+class TestBigQueryHookLegacySql(unittest.TestCase):
+    """Ensure `use_legacy_sql` param in `BigQueryHook` propagates properly."""
+
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_hook_uses_legacy_sql_by_default(self, run_with_config):
+        with mock.patch.object(hook.BigQueryHook, 'get_service'):
+            bq_hook = hook.BigQueryHook()
+            bq_hook.get_first('query')
+            args, kwargs = run_with_config.call_args
+            self.assertIs(args[0]['query']['useLegacySql'], True)
+
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_legacy_sql_override_propagates_properly(self, run_with_config):
+        with mock.patch.object(hook.BigQueryHook, 'get_service'):
+            bq_hook = hook.BigQueryHook(use_legacy_sql=False)
+            bq_hook.get_first('query')
+            args, kwargs = run_with_config.call_args
+            self.assertIs(args[0]['query']['useLegacySql'], False)
+
+
 if __name__ == '__main__':
     unittest.main()
-