Posted to commits@beam.apache.org by pa...@apache.org on 2020/10/02 19:15:41 UTC
[beam] branch master updated: [BEAM-9506] Evaluate gcs_location at runtime, not at pipeline construction time
This is an automated email from the ASF dual-hosted git repository.
pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 320a81c [BEAM-9506] Evaluate gcs_location at runtime, not at pipeline construction time
new 8ff9c89 Merge pull request #12939 from [BEAM-9506] Evaluate gcs_location at runtime, not at pipeline construction time
320a81c is described below
commit 320a81cac60b34911425b5093ee9acedfaef7d0c
Author: Kamil Wasilewski <ka...@polidea.com>
AuthorDate: Fri Sep 25 14:49:47 2020 +0200
[BEAM-9506] Evaluate gcs_location at runtime, not at pipeline construction time
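For context: a ValueProvider wraps a pipeline option whose value may only become
available at job execution time, e.g. when a Dataflow template is launched with
runtime parameters. Calling .get() on a RuntimeValueProvider during pipeline
construction raises an error, which is why the export location must be resolved
on the worker rather than in the constructor. A minimal sketch of the
distinction, using only the public value_provider API (the option name and
values here are illustrative, not taken from the commit):

    from apache_beam.options.value_provider import (
        RuntimeValueProvider, StaticValueProvider)

    static_vp = StaticValueProvider(str, 'gs://bucket/path')
    print(static_vp.get())  # available immediately, at construction time

    runtime_vp = RuntimeValueProvider('gcs_location', str, None)
    print(runtime_vp.is_accessible())  # False until the runner supplies options
    # runtime_vp.get() would raise here; this commit defers that call from
    # pipeline construction into the export step that runs on the worker.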
---
sdks/python/apache_beam/io/gcp/bigquery.py | 73 ++++++++++---------
sdks/python/apache_beam/io/gcp/bigquery_test.py | 94 +++++++++++++++++--------
2 files changed, 107 insertions(+), 60 deletions(-)
diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py
index 4c42ed3..03d594b 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery.py
@@ -244,6 +244,7 @@ import time
import uuid
from builtins import object
from builtins import zip
+from typing import Optional
from future.utils import itervalues
from past.builtins import unicode
@@ -683,7 +684,8 @@ class _CustomBigQuerySource(BoundedSource):
bigquery_job_labels=None,
use_json_exports=False,
job_name=None,
- step_name=None):
+ step_name=None,
+ unique_id=None):
if table is not None and query is not None:
raise ValueError(
'Both a BigQuery table and a query were specified.'
@@ -716,7 +718,7 @@ class _CustomBigQuerySource(BoundedSource):
self.use_json_exports = use_json_exports
self._job_name = job_name or 'AUTOMATIC_JOB_NAME'
self._step_name = step_name
- self._source_uuid = str(uuid.uuid4())[0:10]
+ self._source_uuid = unique_id
def _get_bq_metadata(self):
if not self.bq_io_metadata:
@@ -869,8 +871,11 @@ class _CustomBigQuerySource(BoundedSource):
self._source_uuid,
bigquery_tools.BigQueryJobTypes.EXPORT,
random.randint(0, 1000))
+ temp_location = self.options.view_as(GoogleCloudOptions).temp_location
+ gcs_location = ReadFromBigQuery.get_destination_uri(
+ self.gcs_location, temp_location, self._source_uuid)
if self.use_json_exports:
- job_ref = bq.perform_extract_job([self.gcs_location],
+ job_ref = bq.perform_extract_job([gcs_location],
export_job_name,
self.table_reference,
bigquery_tools.FileFormat.JSON,
@@ -878,7 +883,7 @@ class _CustomBigQuerySource(BoundedSource):
job_labels=job_labels,
include_header=False)
else:
- job_ref = bq.perform_extract_job([self.gcs_location],
+ job_ref = bq.perform_extract_job([gcs_location],
export_job_name,
self.table_reference,
bigquery_tools.FileFormat.AVRO,
@@ -887,7 +892,7 @@ class _CustomBigQuerySource(BoundedSource):
job_labels=job_labels,
use_avro_logical_types=True)
bq.wait_for_bq_job(job_ref)
- metadata_list = FileSystems.match([self.gcs_location])[0].metadata_list
+ metadata_list = FileSystems.match([gcs_location])[0].metadata_list
if isinstance(self.table_reference, vp.ValueProvider):
table_ref = bigquery_tools.parse_table_reference(
@@ -1907,7 +1912,7 @@ class ReadFromBigQuery(PTransform):
"""
COUNTER = 0
- def __init__(self, gcs_location=None, validate=False, *args, **kwargs):
+ def __init__(self, gcs_location=None, *args, **kwargs):
if gcs_location:
if not isinstance(gcs_location, (str, unicode, ValueProvider)):
raise TypeError(
@@ -1919,53 +1924,57 @@ class ReadFromBigQuery(PTransform):
gcs_location = StaticValueProvider(str, gcs_location)
self.gcs_location = gcs_location
- self.validate = validate
self._args = args
self._kwargs = kwargs
- def _get_destination_uri(self, temp_location):
+ @staticmethod
+ def get_destination_uri(
+ gcs_location_vp, # type: Optional[ValueProvider]
+ temp_location, # type: Optional[str]
+ unique_id, # type: str
+ ):
"""Returns the fully qualified Google Cloud Storage URI where the
extracted table should be written.
"""
file_pattern = 'bigquery-table-dump-*.json'
- if self.gcs_location is not None:
- gcs_base = self.gcs_location.get()
+ gcs_location = None
+ if gcs_location_vp is not None:
+ gcs_location = gcs_location_vp.get()
+
+ if gcs_location is not None:
+ gcs_base = gcs_location
elif temp_location is not None:
gcs_base = temp_location
- logging.debug("gcs_location is empty, using temp_location instead")
+ _LOGGER.debug("gcs_location is empty, using temp_location instead")
else:
raise ValueError(
- '{} requires a GCS location to be provided. Neither gcs_location in'
- ' the constructor nor the fallback option --temp_location is set.'.
- format(self.__class__.__name__))
- if self.validate:
- self._validate_gcs_location(gcs_base)
+ 'ReadFromBigQuery requires a GCS location to be provided. Neither '
+ 'gcs_location in the constructor nor the fallback option '
+ '--temp_location is set.')
- job_id = uuid.uuid4().hex
- return FileSystems.join(gcs_base, job_id, file_pattern)
-
- @staticmethod
- def _validate_gcs_location(gcs_location):
- if not gcs_location.startswith('gs://'):
- raise ValueError('Invalid GCS location: {}'.format(gcs_location))
+ return FileSystems.join(gcs_base, unique_id, file_pattern)
def expand(self, pcoll):
- class RemoveJsonFiles(beam.DoFn):
- def __init__(self, gcs_location):
- self._gcs_location = gcs_location
+ class RemoveExportedFiles(beam.DoFn):
+ def __init__(self, gcs_location_vp):
+ self._gcs_location_vp = gcs_location_vp
+ self._temp_location = temp_location
+ self._unique_id = unique_id
def process(self, unused_element, signal):
- match_result = FileSystems.match([self._gcs_location])[0].metadata_list
- logging.debug(
+ gcs_location = ReadFromBigQuery.get_destination_uri(
+ self._gcs_location_vp, self._temp_location, self._unique_id)
+ match_result = FileSystems.match([gcs_location])[0].metadata_list
+ _LOGGER.debug(
"%s: matched %s files", self.__class__.__name__, len(match_result))
paths = [x.path for x in match_result]
FileSystems.delete(paths)
+ unique_id = str(uuid.uuid4())[0:10]
temp_location = pcoll.pipeline.options.view_as(
GoogleCloudOptions).temp_location
- gcs_location = self._get_destination_uri(temp_location)
job_name = pcoll.pipeline.options.view_as(GoogleCloudOptions).job_name
try:
@@ -1977,11 +1986,11 @@ class ReadFromBigQuery(PTransform):
pcoll
| beam.io.Read(
_CustomBigQuerySource(
- gcs_location=gcs_location,
- validate=self.validate,
+ gcs_location=self.gcs_location,
pipeline_options=pcoll.pipeline.options,
job_name=job_name,
step_name=step_name,
+ unique_id=unique_id,
*self._args,
**self._kwargs))
- | _PassThroughThenCleanup(RemoveJsonFiles(gcs_location)))
+ | _PassThroughThenCleanup(RemoveExportedFiles(self.gcs_location)))
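To make the new contract concrete: get_destination_uri joins the base location
(the gcs_location value provider if set, otherwise the --temp_location
fallback), the per-pipeline unique id, and the fixed file pattern. A hedged
sketch of the resulting URI, with the bucket and id values hypothetical:

    from apache_beam.io.gcp.bigquery import ReadFromBigQuery
    from apache_beam.options.value_provider import StaticValueProvider

    uri = ReadFromBigQuery.get_destination_uri(
        StaticValueProvider(str, 'gs://my-bucket'),  # hypothetical bucket
        None,      # temp_location fallback; unused when a location is given
        'abc123')  # unique_id; expand() derives it from uuid4()
    # uri == 'gs://my-bucket/abc123/bigquery-table-dump-*.json'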
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py
index fa3e84b..b399d83 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py
@@ -44,6 +44,7 @@ from apache_beam.internal import pickler
from apache_beam.internal.gcp.json_value import to_json_value
from apache_beam.io.filebasedsink_test import _TestCaseWithTempDirCleanUp
from apache_beam.io.gcp import bigquery_tools
+from apache_beam.io.gcp.bigquery import ReadFromBigQuery
from apache_beam.io.gcp.bigquery import TableRowJsonCoder
from apache_beam.io.gcp.bigquery import WriteToBigQuery
from apache_beam.io.gcp.bigquery import _JsonToDictCoder
@@ -58,8 +59,10 @@ from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher
from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultStreamingMatcher
from apache_beam.io.gcp.tests.bigquery_matcher import BigQueryTableMatcher
from apache_beam.options import value_provider
-from apache_beam.options.pipeline_options import GoogleCloudOptions
+from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.options.value_provider import RuntimeValueProvider
+from apache_beam.options.value_provider import StaticValueProvider
from apache_beam.runners.dataflow.test_dataflow_runner import TestDataflowRunner
from apache_beam.runners.runner import PipelineState
from apache_beam.testing import test_utils
@@ -370,34 +373,69 @@ class TestJsonToDictCoder(unittest.TestCase):
@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestReadFromBigQuery(unittest.TestCase):
- def test_exception_is_raised_when_gcs_location_cannot_be_specified(self):
- with self.assertRaises(ValueError):
- p = beam.Pipeline()
- _ = p | beam.io.ReadFromBigQuery(
- project='project', dataset='dataset', table='table')
-
- @mock.patch('apache_beam.io.gcp.bigquery_tools.BigQueryWrapper')
- def test_fallback_to_temp_location(self, BigQueryWrapper):
- pipeline_options = beam.pipeline.PipelineOptions()
- pipeline_options.view_as(GoogleCloudOptions).temp_location = 'gs://bucket'
- try:
- p = beam.Pipeline(options=pipeline_options)
- _ = p | beam.io.ReadFromBigQuery(
- project='project', dataset='dataset', table='table')
- except ValueError:
- self.fail('ValueError was raised unexpectedly')
-
- def test_gcs_location_validation_works_properly(self):
- with self.assertRaises(ValueError) as context:
- p = beam.Pipeline()
- _ = p | beam.io.ReadFromBigQuery(
- project='project',
- dataset='dataset',
- table='table',
- validate=True,
- gcs_location='fs://bad_location')
+ @classmethod
+ def setUpClass(cls):
+ class UserDefinedOptions(PipelineOptions):
+ @classmethod
+ def _add_argparse_args(cls, parser):
+ parser.add_value_provider_argument('--gcs_location')
+
+ cls.UserDefinedOptions = UserDefinedOptions
+
+ def tearDown(self):
+ # Reset runtime options to avoid side-effects caused by other tests.
+ RuntimeValueProvider.set_runtime_options(None)
+
+ def test_get_destination_uri_empty_runtime_vp(self):
+ with self.assertRaisesRegex(ValueError,
+ '^ReadFromBigQuery requires a GCS '
+ 'location to be provided'):
+ # Don't provide any runtime values.
+ RuntimeValueProvider.set_runtime_options({})
+ options = self.UserDefinedOptions()
+
+ ReadFromBigQuery.get_destination_uri(
+ options.gcs_location, None, uuid.uuid4().hex)
+
+ def test_get_destination_uri_none(self):
+ with self.assertRaisesRegex(ValueError,
+ '^ReadFromBigQuery requires a GCS '
+ 'location to be provided'):
+ ReadFromBigQuery.get_destination_uri(None, None, uuid.uuid4().hex)
+
+ def test_get_destination_uri_runtime_vp(self):
+ # Provide values at job-execution time.
+ RuntimeValueProvider.set_runtime_options({'gcs_location': 'gs://bucket'})
+ options = self.UserDefinedOptions()
+ unique_id = uuid.uuid4().hex
+
+ uri = ReadFromBigQuery.get_destination_uri(
+ options.gcs_location, None, unique_id)
+ self.assertEqual(
+ uri, 'gs://bucket/' + unique_id + '/bigquery-table-dump-*.json')
+
+ def test_get_destination_uri_static_vp(self):
+ unique_id = uuid.uuid4().hex
+ uri = ReadFromBigQuery.get_destination_uri(
+ StaticValueProvider(str, 'gs://bucket'), None, unique_id)
self.assertEqual(
- 'Invalid GCS location: fs://bad_location', str(context.exception))
+ uri, 'gs://bucket/' + unique_id + '/bigquery-table-dump-*.json')
+
+ def test_get_destination_uri_fallback_temp_location(self):
+ # Don't provide any runtime values.
+ RuntimeValueProvider.set_runtime_options({})
+ options = self.UserDefinedOptions()
+
+ with self.assertLogs('apache_beam.io.gcp.bigquery',
+ level='DEBUG') as context:
+ ReadFromBigQuery.get_destination_uri(
+ options.gcs_location, 'gs://bucket', uuid.uuid4().hex)
+ self.assertEqual(
+ context.output,
+ [
+ 'DEBUG:apache_beam.io.gcp.bigquery:gcs_location is empty, '
+ 'using temp_location instead'
+ ])
@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
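For users, the practical effect of this change is that gcs_location can now be
a runtime-provided template parameter. A hedged end-to-end sketch (the options
class, bucket, and table spec are hypothetical; as the constructor check above
shows, gcs_location may be a str or a ValueProvider):

    import apache_beam as beam
    from apache_beam.options.pipeline_options import PipelineOptions

    class MyOptions(PipelineOptions):  # hypothetical user options class
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument('--gcs_location')

    options = MyOptions()
    with beam.Pipeline(options=options) as p:
      rows = p | beam.io.ReadFromBigQuery(
          gcs_location=options.gcs_location,  # now resolved on the worker
          table='project:dataset.table')      # hypothetical table spec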