You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@beam.apache.org by ch...@apache.org on 2022/03/21 17:23:06 UTC

[beam] 01/01: Revert "[BEAM-14112] Avoid storing a generator in _CustomBigQuerySource (#17100)"

This is an automated email from the ASF dual-hosted git repository.

chamikara pushed a commit to branch revert-17100-cyang/rfbq-interactive-fix
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 3ccd12e69bda2816929a54cf3639738c82733782
Author: Chamikara Jayalath <ch...@google.com>
AuthorDate: Mon Mar 21 10:20:13 2022 -0700

    Revert "[BEAM-14112] Avoid storing a generator in _CustomBigQuerySource (#17100)"
    
    This reverts commit 62a661071b7db15e71d236abe68e15582e8997c9.
---
 sdks/python/apache_beam/io/gcp/bigquery.py         | 38 +++++++++++++---------
 .../apache_beam/io/gcp/bigquery_read_it_test.py    | 12 -------
 2 files changed, 22 insertions(+), 28 deletions(-)

diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py
index fc0f4b7..0503de8 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery.py
@@ -705,7 +705,7 @@ class _CustomBigQuerySource(BoundedSource):
     self.flatten_results = flatten_results
     self.coder = coder or _JsonToDictCoder
     self.kms_key = kms_key
-    self.export_result = None
+    self.split_result = None
     self.options = pipeline_options
     self.bq_io_metadata = None  # Populate in setup, as it may make an RPC
     self.bigquery_job_labels = bigquery_job_labels or {}
@@ -789,26 +789,19 @@ class _CustomBigQuerySource(BoundedSource):
       project = self.project
     return project
 
-  def _create_source(self, path, bq):
+  def _create_source(self, path, schema):
     if not self.use_json_exports:
       return create_avro_source(path)
     else:
-      if isinstance(self.table_reference, vp.ValueProvider):
-        table_ref = bigquery_tools.parse_table_reference(
-            self.table_reference.get(), project=self.project)
-      else:
-        table_ref = self.table_reference
-      table = bq.get_table(
-          table_ref.projectId, table_ref.datasetId, table_ref.tableId)
       return TextSource(
           path,
           min_bundle_size=0,
           compression_type=CompressionTypes.UNCOMPRESSED,
           strip_trailing_newlines=True,
-          coder=self.coder(table.schema))
+          coder=self.coder(schema))
 
   def split(self, desired_bundle_size, start_position=None, stop_position=None):
-    if self.export_result is None:
+    if self.split_result is None:
       bq = bigquery_tools.BigQueryWrapper(
           temp_dataset_id=(
               self.temp_dataset.datasetId if self.temp_dataset else None))
@@ -820,13 +813,16 @@ class _CustomBigQuerySource(BoundedSource):
       if not self.table_reference.projectId:
         self.table_reference.projectId = self._get_project()
 
-      self.export_result = self._export_files(bq)
+      schema, metadata_list = self._export_files(bq)
+      # Sources to be created lazily within a generator as they're output.
+      self.split_result = (
+          self._create_source(metadata.path, schema)
+          for metadata in metadata_list)
 
       if self.query is not None:
         bq.clean_up_temporary_dataset(self._get_project())
 
-    for metadata in self.export_result:
-      source = self._create_source(metadata.path, bq)
+    for source in self.split_result:
       yield SourceBundle(
           weight=1.0, source=source, start_position=None, stop_position=None)
 
@@ -878,7 +874,7 @@ class _CustomBigQuerySource(BoundedSource):
     """Runs a BigQuery export job.
 
     Returns:
-      a list of FileMetadata instances
+      bigquery.TableSchema instance, a list of FileMetadata instances
     """
     job_labels = self._get_bq_metadata().add_additional_bq_job_labels(
         self.bigquery_job_labels)
@@ -908,7 +904,17 @@ class _CustomBigQuerySource(BoundedSource):
                                        job_labels=job_labels,
                                        use_avro_logical_types=True)
     bq.wait_for_bq_job(job_ref)
-    return FileSystems.match([gcs_location])[0].metadata_list
+    metadata_list = FileSystems.match([gcs_location])[0].metadata_list
+
+    if isinstance(self.table_reference, vp.ValueProvider):
+      table_ref = bigquery_tools.parse_table_reference(
+          self.table_reference.get(), project=self.project)
+    else:
+      table_ref = self.table_reference
+    table = bq.get_table(
+        table_ref.projectId, table_ref.datasetId, table_ref.tableId)
+
+    return table.schema, metadata_list
 
 
 class _CustomBigQueryStorageSource(BoundedSource):
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py
index e47754d..9101039 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py
@@ -37,8 +37,6 @@ from apache_beam.io.gcp import bigquery_tools
 from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
 from apache_beam.io.gcp.internal.clients import bigquery
 from apache_beam.options.value_provider import StaticValueProvider
-from apache_beam.runners.interactive import interactive_beam
-from apache_beam.runners.interactive.interactive_runner import InteractiveRunner
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
@@ -675,16 +673,6 @@ class ReadAllBQTests(BigQueryReadIntegrationTests):
           equal_to(self.TABLE_DATA_1 + self.TABLE_DATA_2 + self.TABLE_DATA_3))
 
 
-class ReadInteractiveRunnerTests(BigQueryReadIntegrationTests):
-  @skip(['PortableRunner', 'FlinkRunner'])
-  @pytest.mark.it_postcommit
-  def test_read_in_interactive_runner(self):
-    p = beam.Pipeline(InteractiveRunner(), argv=self.args)
-    pcoll = p | beam.io.ReadFromBigQuery(query="SELECT 1")
-    result = interactive_beam.collect(pcoll)
-    assert result.iloc[0, 0] == 1
-
-
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()