Posted to commits@airflow.apache.org by cr...@apache.org on 2018/01/11 17:25:49 UTC

incubator-airflow git commit: [AIRFLOW-1688] Support load.time_partitioning in bigquery_hook

Repository: incubator-airflow
Updated Branches:
  refs/heads/master b75367bb5 -> 804710fda


[AIRFLOW-1688] Support load.time_partitioning in bigquery_hook

Closes #2820 from albertocalderari/master
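
For reference, a minimal usage sketch of the new argument on the operator side
(bucket, object, table and field names are placeholders and assume an existing
dag object; this snippet is not part of the commit itself):

    from airflow.contrib.operators.gcs_to_bq import GoogleCloudStorageToBigQueryOperator

    load_partitioned = GoogleCloudStorageToBigQueryOperator(
        task_id='gcs_to_bq_partitioned',
        bucket='my-bucket',
        source_objects=['data/2018-01-11/part-*.csv'],
        destination_project_dataset_table='my_dataset.my_table',
        schema_fields=[{'name': 'created_at', 'type': 'TIMESTAMP'}],
        # new in this change: forwarded straight to BigQueryBaseCursor.run_load
        time_partitioning={'type': 'DAY', 'field': 'created_at',
                           'expirationMs': 2592000000},
        dag=dag)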


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/804710fd
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/804710fd
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/804710fd

Branch: refs/heads/master
Commit: 804710fda54d0eb8dfef0385e518b04a35c8fed4
Parents: b75367b
Author: alberto.calderari <al...@just-eat.com>
Authored: Thu Jan 11 09:24:21 2018 -0800
Committer: Chris Riccomini <cr...@apache.org>
Committed: Thu Jan 11 09:24:31 2018 -0800

----------------------------------------------------------------------
 airflow/contrib/hooks/bigquery_hook.py    |  70 ++++++++++++-----
 airflow/contrib/operators/gcs_to_bq.py    |  10 ++-
 tests/contrib/hooks/test_bigquery_hook.py | 104 +++++++++++++++++++++++--
 3 files changed, 158 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/804710fd/airflow/contrib/hooks/bigquery_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py
index d64c2a1..fe51d50 100644
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ b/airflow/contrib/hooks/bigquery_hook.py
@@ -22,6 +22,7 @@ from builtins import range
 
 from past.builtins import basestring
 
+from airflow import AirflowException
 from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
 from airflow.hooks.dbapi_hook import DbApiHook
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -450,7 +451,8 @@ class BigQueryBaseCursor(LoggingMixin):
                  allow_quoted_newlines=False,
                  allow_jagged_rows=False,
                  schema_update_options=(),
-                 src_fmt_configs={}):
+                 src_fmt_configs={},
+                 time_partitioning={}):
         """
         Executes a BigQuery load command to load data from Google Cloud Storage
         to BigQuery. See here:
@@ -460,9 +462,11 @@ class BigQueryBaseCursor(LoggingMixin):
         For more details about these parameters.
 
         :param destination_project_dataset_table:
-            The dotted (<project>.|<project>:)<dataset>.<table> BigQuery table to load
-            data into. If <project> is not included, project will be the project defined
-            in the connection json.
+            The dotted (<project>.|<project>:)<dataset>.<table>($<partition>) BigQuery
+            table to load data into. If <project> is not included, project will be the
+            project defined in the connection json. If a partition is specified, the
+            operator will automatically append the data, create a new partition, or
+            create a new DAY-partitioned table.
         :type destination_project_dataset_table: string
         :param schema_fields: The schema field list as defined here:
             https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.load
@@ -484,20 +488,28 @@ class BigQueryBaseCursor(LoggingMixin):
         :param max_bad_records: The maximum number of bad records that BigQuery can
             ignore when running the job.
         :type max_bad_records: int
-        :param quote_character: The value that is used to quote data sections in a CSV file.
+        :param quote_character: The value that is used to quote data sections in a CSV
+            file.
         :type quote_character: string
-        :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not (false).
+        :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not
+            (false).
         :type allow_quoted_newlines: boolean
         :param allow_jagged_rows: Accept rows that are missing trailing optional columns.
-            The missing values are treated as nulls. If false, records with missing trailing columns
-            are treated as bad records, and if there are too many bad records, an invalid error is
-            returned in the job result. Only applicable when soure_format is CSV.
+            The missing values are treated as nulls. If false, records with missing
+            trailing columns are treated as bad records, and if there are too many bad
+            records, an invalid error is returned in the job result. Only applicable when
+            source_format is CSV.
         :type allow_jagged_rows: bool
         :param schema_update_options: Allows the schema of the destination
             table to be updated as a side effect of the load job.
         :type schema_update_options: tuple
         :param src_fmt_configs: configure optional fields specific to the source format
         :type src_fmt_configs: dict
+        :param time_partitioning: configure optional time partitioning fields i.e.
+            partition by field, type and expiration as per API specifications.
+            Note that 'field' cannot be used together with the
+            dataset.table$partition syntax.
+        :type time_partitioning: dict
         """
 
         # bigquery only allows certain source formats
@@ -518,7 +530,7 @@ class BigQueryBaseCursor(LoggingMixin):
         # bigquery also allows you to define how you want a table's schema to change
         # as a side effect of a load
         # for more details:
-        #   https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schemaUpdateOptions
+        # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schemaUpdateOptions
         allowed_schema_update_options = [
             'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION"
         ]
@@ -547,6 +559,23 @@ class BigQueryBaseCursor(LoggingMixin):
                 'writeDisposition': write_disposition,
             }
         }
+
+        # if it is a partitioned table ($ is in the table name) add partition load option
+        if '$' in destination_project_dataset_table:
+            if time_partitioning.get('field'):
+                raise AirflowException(
+                    "Cannot specify field partition and partition name "
+                    "(dataset.table$partition) at the same time"
+                )
+            configuration['load']['timePartitioning'] = dict(type='DAY')
+
+        # can specify custom time partitioning options based on a field, or by adding
+        # an expiration
+        if time_partitioning:
+            if not configuration.get('load', {}).get('timePartitioning'):
+                configuration['load']['timePartitioning'] = {}
+            configuration['load']['timePartitioning'].update(time_partitioning)
+
         if schema_fields:
             configuration['load']['schema'] = {'fields': schema_fields}
 
@@ -777,7 +806,7 @@ class BigQueryBaseCursor(LoggingMixin):
                              default_project_id=self.project_id)
 
         try:
-            tables_resource = self.service.tables() \
+            self.service.tables() \
                 .delete(projectId=deletion_project,
                         datasetId=deletion_dataset,
                         tableId=deletion_table) \
@@ -1011,13 +1040,14 @@ class BigQueryCursor(BigQueryBaseCursor):
 
     def fetchmany(self, size=None):
         """
-        Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
-        list of tuples). An empty sequence is returned when no more rows are available.
-        The number of rows to fetch per call is specified by the parameter. If it is not given, the
-        cursor's arraysize determines the number of rows to be fetched. The method should try to
-        fetch as many rows as indicated by the size parameter. If this is not possible due to the
-        specified number of rows not being available, fewer rows may be returned.
-        An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to
+        Fetch the next set of rows of a query result, returning a sequence of sequences
+        (e.g. a list of tuples). An empty sequence is returned when no more rows are
+        available. The number of rows to fetch per call is specified by the parameter.
+        If it is not given, the cursor's arraysize determines the number of rows to be
+        fetched. The method should try to fetch as many rows as indicated by the size
+        parameter. If this is not possible due to the specified number of rows not being
+        available, fewer rows may be returned. An :py:class:`~pyhive.exc.Error`
+        (or subclass) exception is raised if the previous call to
         :py:meth:`execute` did not produce any result set or no call was issued yet.
         """
         if size is None:
@@ -1033,8 +1063,8 @@ class BigQueryCursor(BigQueryBaseCursor):
 
     def fetchall(self):
         """
-        Fetch all (remaining) rows of a query result, returning them as a sequence of sequences
-        (e.g. a list of tuples).
+        Fetch all (remaining) rows of a query result, returning them as a sequence of
+        sequences (e.g. a list of tuples).
         """
         result = []
         while True:
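
For readers skimming the diff: the partitioning merge added to run_load above boils
down to the following standalone sketch (placeholder values; the AirflowException
guard against combining '$' with a 'field' entry is omitted here):

    time_partitioning = {'expirationMs': 2592000000}
    destination = 'my_dataset.my_table$20180111'

    load_config = {}
    if '$' in destination:
        # a partition decorator implies a DAY partitioned destination table
        load_config['timePartitioning'] = dict(type='DAY')
    if time_partitioning:
        # any explicitly passed options are merged on top
        load_config.setdefault('timePartitioning', {}).update(time_partitioning)

    print(load_config)  # {'timePartitioning': {'type': 'DAY', 'expirationMs': 2592000000}}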

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/804710fd/airflow/contrib/operators/gcs_to_bq.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py
index 730a3bc..75302b6 100644
--- a/airflow/contrib/operators/gcs_to_bq.py
+++ b/airflow/contrib/operators/gcs_to_bq.py
@@ -52,6 +52,7 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
         delegate_to=None,
         schema_update_options=(),
         src_fmt_configs={},
+        time_partitioning={},
         *args,
         **kwargs):
         """
@@ -119,6 +120,11 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
         :type schema_update_options: list
         :param src_fmt_configs: configure optional fields specific to the source format
         :type src_fmt_configs: dict
+        :param time_partitioning: configure optional time partitioning fields i.e.
+            partition by field, type and expiration as per API specifications.
+            Note that 'field' cannot be used together with the
+            dataset.table$partition syntax.
+        :type time_partitioning: dict
         """
         super(GoogleCloudStorageToBigQueryOperator, self).__init__(*args, **kwargs)
 
@@ -147,6 +153,7 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
 
         self.schema_update_options = schema_update_options
         self.src_fmt_configs = src_fmt_configs
+        self.time_partitioning = time_partitioning
 
     def execute(self, context):
         bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
@@ -181,7 +188,8 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
             allow_quoted_newlines=self.allow_quoted_newlines,
             allow_jagged_rows=self.allow_jagged_rows,
             schema_update_options=self.schema_update_options,
-            src_fmt_configs=self.src_fmt_configs)
+            src_fmt_configs=self.src_fmt_configs,
+            time_partitioning=self.time_partitioning)
 
         if self.max_id_key:
             cursor.execute('SELECT MAX({}) FROM {}'.format(
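
A hook-level sketch of the same call path the operator now takes (connection id,
schema, URIs and the table decorator below are placeholders; 'field' is left out of
time_partitioning on purpose because it cannot be combined with the $ decorator):

    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    bq_hook = BigQueryHook(bigquery_conn_id='bigquery_default')
    cursor = bq_hook.get_conn().cursor()
    cursor.run_load(
        destination_project_dataset_table='my_dataset.my_table$20180111',
        schema_fields=[{'name': 'created_at', 'type': 'TIMESTAMP'}],
        source_uris=['gs://my-bucket/data/2018-01-11/part-*.csv'],
        # adding 'field' alongside the $ decorator raises AirflowException,
        # as exercised by the tests below
        time_partitioning={'expirationMs': 2592000000})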

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/804710fd/tests/contrib/hooks/test_bigquery_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py
index 0365bba..86268c4 100644
--- a/tests/contrib/hooks/test_bigquery_hook.py
+++ b/tests/contrib/hooks/test_bigquery_hook.py
@@ -16,6 +16,7 @@
 import unittest
 import mock
 
+from airflow import AirflowException
 from airflow.contrib.hooks import bigquery_hook as hook
 from oauth2client.contrib.gce import HttpAccessTokenRefreshError
 
@@ -173,6 +174,7 @@ def mock_job_cancel(projectId, jobId):
     mock_canceled_jobs.append(jobId)
     return mock.Mock()
 
+
 class TestBigQueryBaseCursor(unittest.TestCase):
     def test_invalid_schema_update_options(self):
         with self.assertRaises(Exception) as context:
@@ -194,26 +196,118 @@ class TestBigQueryBaseCursor(unittest.TestCase):
                 write_disposition='WRITE_EMPTY'
             )
         self.assertIn("schema_update_options is only", str(context.exception))
-    
+
     @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
     @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
     def test_cancel_queries(self, mocked_logging, mocked_time):
         project_id = 12345
         running_job_id = 3
-        
+
         mock_jobs = mock.Mock()
         mock_jobs.cancel = mock.Mock(side_effect=mock_job_cancel)
         mock_service = mock.Mock()
         mock_service.jobs = mock.Mock(return_value=mock_jobs)
-        
+
         bq_hook = hook.BigQueryBaseCursor(mock_service, project_id)
         bq_hook.running_job_id = running_job_id
         bq_hook.poll_job_complete = mock.Mock(side_effect=mock_poll_job_complete)
-        
+
         bq_hook.cancel_query()
-        
+
         mock_jobs.cancel.assert_called_with(projectId=project_id, jobId=running_job_id)
 
+
+class TestTimePartitioningInRunJob(unittest.TestCase):
+
+    class BigQueryBaseCursorTest(hook.BigQueryBaseCursor):
+        """Use this class to verify the load configuration"""
+        def run_with_configuration(self, configuration):
+            return configuration
+
+    class Serv(object):
+        """Mocks the service side of a successful Job"""
+
+        class Job(object):
+            """Mocks the behaviour of a successful Job"""
+            def __getitem__(self, item=None):
+                return self
+
+            def get(self, projectId, jobId=None):
+                return self.__getitem__(projectId)
+
+            def insert(self, projectId, body=None):
+                return self.get(projectId, body)
+
+            def execute(self):
+                return {
+                    'status': {'state': 'DONE'},
+                    'jobReference': {'jobId': 0}
+                }
+
+        def __init__(self, job='mock_load'):
+            self.job = job
+
+        def jobs(self):
+            return self.Job()
+
+    def test_the_job_execution_wont_break(self):
+        s = self.Serv()
+        bqc = hook.BigQueryBaseCursor(s, 'str')
+        job = bqc.run_load(
+            destination_project_dataset_table='test.teast',
+            schema_fields=[],
+            source_uris=[],
+            time_partitioning={'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000}
+        )
+
+        self.assertEqual(job, 0)
+
+    def test_dollar_makes_partition(self):
+        s = self.Serv()
+        bqc = self.BigQueryBaseCursorTest(s, 'str')
+        cnfg = bqc.run_load(
+            destination_project_dataset_table='test.teast$20170101',
+            schema_fields=[],
+            source_uris=[],
+            src_fmt_configs={}
+        )
+        expect = {
+            'type': 'DAY'
+        }
+        self.assertEqual(cnfg['load'].get('timePartitioning'), expect)
+
+    def test_extra_time_partitioning_options(self):
+        s = self.Serv()
+        bqc = self.BigQueryBaseCursorTest(s, 'str')
+        cnfg = bqc.run_load(
+            destination_project_dataset_table='test.teast',
+            schema_fields=[],
+            source_uris=[],
+            time_partitioning={'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000}
+        )
+
+        expect = {
+            'type': 'DAY',
+            'field': 'test_field',
+            'expirationMs': 1000
+        }
+
+        self.assertEqual(cnfg['load'].get('timePartitioning'), expect)
+
+    def test_cant_add_dollar_and_field_name(self):
+        s = self.Serv()
+        bqc = self.BigQueryBaseCursorTest(s, 'str')
+
+        with self.assertRaises(AirflowException):
+            tp_dict = {'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000}
+            bqc.run_load(
+                destination_project_dataset_table='test.teast$20170101',
+                schema_fields=[],
+                source_uris=[],
+                time_partitioning=tp_dict
+            )
+
+
 if __name__ == '__main__':
     unittest.main()