Posted to commits@beam.apache.org by pa...@apache.org on 2020/10/02 19:15:41 UTC

[beam] branch master updated: [BEAM-9506] Evaluate gcs_location at runtime, not at pipeline construction time

This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 320a81c  [BEAM-9506] Evaluate gcs_location at runtime, not at pipeline construction time
     new 8ff9c89  Merge pull request #12939 from [BEAM-9506] Evaluate gcs_location at runtime, not at pipeline construction time
320a81c is described below

commit 320a81cac60b34911425b5093ee9acedfaef7d0c
Author: Kamil Wasilewski <ka...@polidea.com>
AuthorDate: Fri Sep 25 14:49:47 2020 +0200

    [BEAM-9506] Evaluate gcs_location at runtime, not at pipeline construction time
---
 sdks/python/apache_beam/io/gcp/bigquery.py      | 73 ++++++++++---------
 sdks/python/apache_beam/io/gcp/bigquery_test.py | 94 +++++++++++++++++--------
 2 files changed, 107 insertions(+), 60 deletions(-)
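
The heart of this change: the export destination is now computed by a static
ReadFromBigQuery.get_destination_uri helper that the source and the cleanup
DoFn call at execution time, instead of resolving the gcs_location
ValueProvider while the pipeline graph is being constructed. A minimal
standalone sketch of that resolution logic, with illustrative names and
posixpath standing in for FileSystems.join:

    import posixpath

    from apache_beam.options.value_provider import StaticValueProvider


    def resolve_export_uri(gcs_location_vp, temp_location, unique_id):
      """Resolves the export URI when the job runs, not when the DAG is built."""
      file_pattern = 'bigquery-table-dump-*.json'
      gcs_base = None
      if gcs_location_vp is not None:
        # ValueProvider.get() is only safe to call at runtime, which is why the
        # call was moved out of the constructor.
        gcs_base = gcs_location_vp.get()
      if gcs_base is None:
        # Fall back to the pipeline's temp_location, mirroring the real helper.
        gcs_base = temp_location
      if gcs_base is None:
        raise ValueError('Neither gcs_location nor --temp_location is set.')
      return posixpath.join(gcs_base, unique_id, file_pattern)


    print(resolve_export_uri(StaticValueProvider(str, 'gs://bucket'), None, 'abc123'))
    # gs://bucket/abc123/bigquery-table-dump-*.json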

diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py
index 4c42ed3..03d594b 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery.py
@@ -244,6 +244,7 @@ import time
 import uuid
 from builtins import object
 from builtins import zip
+from typing import Optional
 
 from future.utils import itervalues
 from past.builtins import unicode
@@ -683,7 +684,8 @@ class _CustomBigQuerySource(BoundedSource):
       bigquery_job_labels=None,
       use_json_exports=False,
       job_name=None,
-      step_name=None):
+      step_name=None,
+      unique_id=None):
     if table is not None and query is not None:
       raise ValueError(
           'Both a BigQuery table and a query were specified.'
@@ -716,7 +718,7 @@ class _CustomBigQuerySource(BoundedSource):
     self.use_json_exports = use_json_exports
     self._job_name = job_name or 'AUTOMATIC_JOB_NAME'
     self._step_name = step_name
-    self._source_uuid = str(uuid.uuid4())[0:10]
+    self._source_uuid = unique_id
 
   def _get_bq_metadata(self):
     if not self.bq_io_metadata:
@@ -869,8 +871,11 @@ class _CustomBigQuerySource(BoundedSource):
         self._source_uuid,
         bigquery_tools.BigQueryJobTypes.EXPORT,
         random.randint(0, 1000))
+    temp_location = self.options.view_as(GoogleCloudOptions).temp_location
+    gcs_location = ReadFromBigQuery.get_destination_uri(
+        self.gcs_location, temp_location, self._source_uuid)
     if self.use_json_exports:
-      job_ref = bq.perform_extract_job([self.gcs_location],
+      job_ref = bq.perform_extract_job([gcs_location],
                                        export_job_name,
                                        self.table_reference,
                                        bigquery_tools.FileFormat.JSON,
@@ -878,7 +883,7 @@ class _CustomBigQuerySource(BoundedSource):
                                        job_labels=job_labels,
                                        include_header=False)
     else:
-      job_ref = bq.perform_extract_job([self.gcs_location],
+      job_ref = bq.perform_extract_job([gcs_location],
                                        export_job_name,
                                        self.table_reference,
                                        bigquery_tools.FileFormat.AVRO,
@@ -887,7 +892,7 @@ class _CustomBigQuerySource(BoundedSource):
                                        job_labels=job_labels,
                                        use_avro_logical_types=True)
     bq.wait_for_bq_job(job_ref)
-    metadata_list = FileSystems.match([self.gcs_location])[0].metadata_list
+    metadata_list = FileSystems.match([gcs_location])[0].metadata_list
 
     if isinstance(self.table_reference, vp.ValueProvider):
       table_ref = bigquery_tools.parse_table_reference(
@@ -1907,7 +1912,7 @@ class ReadFromBigQuery(PTransform):
    """
   COUNTER = 0
 
-  def __init__(self, gcs_location=None, validate=False, *args, **kwargs):
+  def __init__(self, gcs_location=None, *args, **kwargs):
     if gcs_location:
       if not isinstance(gcs_location, (str, unicode, ValueProvider)):
         raise TypeError(
@@ -1919,53 +1924,57 @@ class ReadFromBigQuery(PTransform):
         gcs_location = StaticValueProvider(str, gcs_location)
 
     self.gcs_location = gcs_location
-    self.validate = validate
 
     self._args = args
     self._kwargs = kwargs
 
-  def _get_destination_uri(self, temp_location):
+  @staticmethod
+  def get_destination_uri(
+      gcs_location_vp,  # type: Optional[ValueProvider]
+      temp_location,  # type: Optional[str]
+      unique_id,  # type: str
+  ):
     """Returns the fully qualified Google Cloud Storage URI where the
     extracted table should be written.
     """
     file_pattern = 'bigquery-table-dump-*.json'
 
-    if self.gcs_location is not None:
-      gcs_base = self.gcs_location.get()
+    gcs_location = None
+    if gcs_location_vp is not None:
+      gcs_location = gcs_location_vp.get()
+
+    if gcs_location is not None:
+      gcs_base = gcs_location
     elif temp_location is not None:
       gcs_base = temp_location
-      logging.debug("gcs_location is empty, using temp_location instead")
+      _LOGGER.debug("gcs_location is empty, using temp_location instead")
     else:
       raise ValueError(
-          '{} requires a GCS location to be provided. Neither gcs_location in'
-          ' the constructor nor the fallback option --temp_location is set.'.
-          format(self.__class__.__name__))
-    if self.validate:
-      self._validate_gcs_location(gcs_base)
+          'ReadFromBigQuery requires a GCS location to be provided. Neither '
+          'gcs_location in the constructor nor the fallback option '
+          '--temp_location is set.')
 
-    job_id = uuid.uuid4().hex
-    return FileSystems.join(gcs_base, job_id, file_pattern)
-
-  @staticmethod
-  def _validate_gcs_location(gcs_location):
-    if not gcs_location.startswith('gs://'):
-      raise ValueError('Invalid GCS location: {}'.format(gcs_location))
+    return FileSystems.join(gcs_base, unique_id, file_pattern)
 
   def expand(self, pcoll):
-    class RemoveJsonFiles(beam.DoFn):
-      def __init__(self, gcs_location):
-        self._gcs_location = gcs_location
+    class RemoveExportedFiles(beam.DoFn):
+      def __init__(self, gcs_location_vp):
+        self._gcs_location_vp = gcs_location_vp
+        self._temp_location = temp_location
+        self._unique_id = unique_id
 
       def process(self, unused_element, signal):
-        match_result = FileSystems.match([self._gcs_location])[0].metadata_list
-        logging.debug(
+        gcs_location = ReadFromBigQuery.get_destination_uri(
+            self._gcs_location_vp, self._temp_location, self._unique_id)
+        match_result = FileSystems.match([gcs_location])[0].metadata_list
+        _LOGGER.debug(
             "%s: matched %s files", self.__class__.__name__, len(match_result))
         paths = [x.path for x in match_result]
         FileSystems.delete(paths)
 
+    unique_id = str(uuid.uuid4())[0:10]
     temp_location = pcoll.pipeline.options.view_as(
         GoogleCloudOptions).temp_location
-    gcs_location = self._get_destination_uri(temp_location)
     job_name = pcoll.pipeline.options.view_as(GoogleCloudOptions).job_name
 
     try:
@@ -1977,11 +1986,11 @@ class ReadFromBigQuery(PTransform):
         pcoll
         | beam.io.Read(
             _CustomBigQuerySource(
-                gcs_location=gcs_location,
-                validate=self.validate,
+                gcs_location=self.gcs_location,
                 pipeline_options=pcoll.pipeline.options,
                 job_name=job_name,
                 step_name=step_name,
+                unique_id=unique_id,
                 *self._args,
                 **self._kwargs))
-        | _PassThroughThenCleanup(RemoveJsonFiles(gcs_location)))
+        | _PassThroughThenCleanup(RemoveExportedFiles(self.gcs_location)))
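
Because the ValueProvider is now resolved on the worker, --gcs_location can be
declared as a deferred pipeline option (for example in a Dataflow template)
and handed straight to the transform. A hypothetical usage sketch, assuming
apache-beam[gcp] is installed; the option, project, dataset and table names
are placeholders, not part of this commit:

    import apache_beam as beam
    from apache_beam.options.pipeline_options import PipelineOptions


    class TemplateOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        # Deferred value: no .get() is needed at graph construction time; the
        # source resolves it on the worker when the export job is issued.
        parser.add_value_provider_argument('--export_gcs_location')


    def run(argv=None):
      options = TemplateOptions(argv)
      with beam.Pipeline(options=options) as p:
        _ = (
            p
            | 'Read' >> beam.io.ReadFromBigQuery(
                gcs_location=options.export_gcs_location,
                project='my_project',
                dataset='my_dataset',
                table='my_table')
            | 'Print' >> beam.Map(print))
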
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py
index fa3e84b..b399d83 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py
@@ -44,6 +44,7 @@ from apache_beam.internal import pickler
 from apache_beam.internal.gcp.json_value import to_json_value
 from apache_beam.io.filebasedsink_test import _TestCaseWithTempDirCleanUp
 from apache_beam.io.gcp import bigquery_tools
+from apache_beam.io.gcp.bigquery import ReadFromBigQuery
 from apache_beam.io.gcp.bigquery import TableRowJsonCoder
 from apache_beam.io.gcp.bigquery import WriteToBigQuery
 from apache_beam.io.gcp.bigquery import _JsonToDictCoder
@@ -58,8 +59,10 @@ from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher
 from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultStreamingMatcher
 from apache_beam.io.gcp.tests.bigquery_matcher import BigQueryTableMatcher
 from apache_beam.options import value_provider
-from apache_beam.options.pipeline_options import GoogleCloudOptions
+from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.options.value_provider import RuntimeValueProvider
+from apache_beam.options.value_provider import StaticValueProvider
 from apache_beam.runners.dataflow.test_dataflow_runner import TestDataflowRunner
 from apache_beam.runners.runner import PipelineState
 from apache_beam.testing import test_utils
@@ -370,34 +373,69 @@ class TestJsonToDictCoder(unittest.TestCase):
 
 @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
 class TestReadFromBigQuery(unittest.TestCase):
-  def test_exception_is_raised_when_gcs_location_cannot_be_specified(self):
-    with self.assertRaises(ValueError):
-      p = beam.Pipeline()
-      _ = p | beam.io.ReadFromBigQuery(
-          project='project', dataset='dataset', table='table')
-
-  @mock.patch('apache_beam.io.gcp.bigquery_tools.BigQueryWrapper')
-  def test_fallback_to_temp_location(self, BigQueryWrapper):
-    pipeline_options = beam.pipeline.PipelineOptions()
-    pipeline_options.view_as(GoogleCloudOptions).temp_location = 'gs://bucket'
-    try:
-      p = beam.Pipeline(options=pipeline_options)
-      _ = p | beam.io.ReadFromBigQuery(
-          project='project', dataset='dataset', table='table')
-    except ValueError:
-      self.fail('ValueError was raised unexpectedly')
-
-  def test_gcs_location_validation_works_properly(self):
-    with self.assertRaises(ValueError) as context:
-      p = beam.Pipeline()
-      _ = p | beam.io.ReadFromBigQuery(
-          project='project',
-          dataset='dataset',
-          table='table',
-          validate=True,
-          gcs_location='fs://bad_location')
+  @classmethod
+  def setUpClass(cls):
+    class UserDefinedOptions(PipelineOptions):
+      @classmethod
+      def _add_argparse_args(cls, parser):
+        parser.add_value_provider_argument('--gcs_location')
+
+    cls.UserDefinedOptions = UserDefinedOptions
+
+  def tearDown(self):
+    # Reset runtime options to avoid side-effects caused by other tests.
+    RuntimeValueProvider.set_runtime_options(None)
+
+  def test_get_destination_uri_empty_runtime_vp(self):
+    with self.assertRaisesRegex(ValueError,
+                                '^ReadFromBigQuery requires a GCS '
+                                'location to be provided'):
+      # Don't provide any runtime values.
+      RuntimeValueProvider.set_runtime_options({})
+      options = self.UserDefinedOptions()
+
+      ReadFromBigQuery.get_destination_uri(
+          options.gcs_location, None, uuid.uuid4().hex)
+
+  def test_get_destination_uri_none(self):
+    with self.assertRaisesRegex(ValueError,
+                                '^ReadFromBigQuery requires a GCS '
+                                'location to be provided'):
+      ReadFromBigQuery.get_destination_uri(None, None, uuid.uuid4().hex)
+
+  def test_get_destination_uri_runtime_vp(self):
+    # Provide values at job-execution time.
+    RuntimeValueProvider.set_runtime_options({'gcs_location': 'gs://bucket'})
+    options = self.UserDefinedOptions()
+    unique_id = uuid.uuid4().hex
+
+    uri = ReadFromBigQuery.get_destination_uri(
+        options.gcs_location, None, unique_id)
+    self.assertEqual(
+        uri, 'gs://bucket/' + unique_id + '/bigquery-table-dump-*.json')
+
+  def test_get_destination_uri_static_vp(self):
+    unique_id = uuid.uuid4().hex
+    uri = ReadFromBigQuery.get_destination_uri(
+        StaticValueProvider(str, 'gs://bucket'), None, unique_id)
     self.assertEqual(
-        'Invalid GCS location: fs://bad_location', str(context.exception))
+        uri, 'gs://bucket/' + unique_id + '/bigquery-table-dump-*.json')
+
+  def test_get_destination_uri_fallback_temp_location(self):
+    # Don't provide any runtime values.
+    RuntimeValueProvider.set_runtime_options({})
+    options = self.UserDefinedOptions()
+
+    with self.assertLogs('apache_beam.io.gcp.bigquery',
+                         level='DEBUG') as context:
+      ReadFromBigQuery.get_destination_uri(
+          options.gcs_location, 'gs://bucket', uuid.uuid4().hex)
+    self.assertEqual(
+        context.output,
+        [
+            'DEBUG:apache_beam.io.gcp.bigquery:gcs_location is empty, '
+            'using temp_location instead'
+        ])
 
 
 @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
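
The new tests above use RuntimeValueProvider.set_runtime_options to simulate
values that only become available at job-execution time. The same pattern,
condensed into a standalone snippet that exercises get_destination_uri without
running a pipeline (assumes the GCP extras are installed; the option class,
bucket and unique_id are illustrative):

    from apache_beam.io.gcp.bigquery import ReadFromBigQuery
    from apache_beam.options.pipeline_options import PipelineOptions
    from apache_beam.options.value_provider import RuntimeValueProvider


    class GcsLocationOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument('--gcs_location')


    try:
      # Pretend the job is running and --gcs_location was supplied at runtime.
      RuntimeValueProvider.set_runtime_options({'gcs_location': 'gs://bucket'})
      uri = ReadFromBigQuery.get_destination_uri(
          GcsLocationOptions().gcs_location, None, 'abc123')
      assert uri == 'gs://bucket/abc123/bigquery-table-dump-*.json'
    finally:
      # Reset process-wide runtime options so they do not leak elsewhere.
      RuntimeValueProvider.set_runtime_options(None)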