Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2021/08/02 07:49:51 UTC

[GitHub] [beam] chamikaramj commented on a change in pull request #15185: [BEAM-10917] Add support for BigQuery Read API in Python BEAM

chamikaramj commented on a change in pull request #15185:
URL: https://github.com/apache/beam/pull/15185#discussion_r680718528



##########
File path: sdks/python/apache_beam/io/gcp/bigquery.py
##########
@@ -883,6 +895,221 @@ def _export_files(self, bq):
     return table.schema, metadata_list
 
 
+class _CustomBigQueryStorageSourceBase(BoundedSource):
+  """A base class for BoundedSource implementations which read from BigQuery
+  using the BigQuery Storage API.
+
+  Args:
+    table (str, TableReference): The ID of the table. The ID must contain only
+      letters ``a-z``, ``A-Z``, numbers ``0-9``, or underscores ``_``. If the
+      **dataset** argument is :data:`None`, then the table argument must
+      contain the entire table reference specified as:
+      ``'PROJECT:DATASET.TABLE'`` or must specify a TableReference.
+    dataset (str): The ID of the dataset containing this table or
+      :data:`None` if the table argument specifies a TableReference.
+    project (str): The ID of the project containing this table or
+      :data:`None` if the table argument specifies a TableReference.
+    selected_fields (List[str]): Names of the fields in the table that should be
+      read. If empty, all fields will be read. If the specified field is a
+      nested field, all the sub-fields in the field will be selected. The output
+      field order is unrelated to the order of fields in selected_fields.
+    row_restriction (str): SQL text filtering statement, similar to a WHERE
+      clause in a query. Aggregates are not supported. Restricted to a maximum
+      length of 1 MB.
+  """
+
+  # The maximum number of streams which will be requested when creating a read
+  # session, regardless of the desired bundle size.
+  MAX_SPLIT_COUNT = 10000
+  # The minimum number of streams which will be requested when creating a read
+  # session, regardless of the desired bundle size. Note that the server may
+  # still choose to return fewer than ten streams based on the layout of the
+  # table.
+  MIN_SPLIT_COUNT = 10
+
+  def __init__(
+      self,
+      table: Union[str, TableReference],
+      dataset: str = None,
+      project: str = None,
+      selected_fields: List[str] = None,
+      row_restriction: str = None,
+      pipeline_options: GoogleCloudOptions = None):
+
+    self.table_reference = bigquery_tools.parse_table_reference(
+        table, dataset, project)
+    self.project = self.table_reference.projectId
+    self.dataset = self.table_reference.datasetId
+    self.table = self.table_reference.tableId
+    self.selected_fields = selected_fields
+    self.row_restriction = row_restriction
+    self.pipeline_options = pipeline_options
+    self.split_result = None
+
+  def _get_parent_project(self):
+    """Returns the project that will be billed."""
+    project = self.pipeline_options.view_as(GoogleCloudOptions).project
+    if isinstance(project, vp.ValueProvider):
+      project = project.get()
+    if not project:
+      project = self.project
+    return project
+
+  def _get_table_size(self, table, dataset, project):
+    if project is None:
+      project = self._get_parent_project()
+
+    bq = bigquery_tools.BigQueryWrapper()
+    table = bq.get_table(project, dataset, table)
+    return table.numBytes
+
+  def display_data(self):
+    return {
+        'project': str(self.project),
+        'dataset': str(self.dataset),
+        'table': str(self.table),
+        'selected_fields': str(self.selected_fields),
+        'row_restriction': str(self.row_restriction)
+    }
+
+  def estimate_size(self):
+    # The size of a stream source cannot be estimated due to server-side liquid
+    # sharding.
+    return None
+
+  def split(self, desired_bundle_size, start_position=None, stop_position=None):
+    requested_session = bq_storage.types.ReadSession()
+    requested_session.table = 'projects/{}/datasets/{}/tables/{}'.format(
+        self.project, self.dataset, self.table)
+    requested_session.data_format = bq_storage.types.DataFormat.AVRO
+    if self.selected_fields is not None:
+      requested_session.read_options.selected_fields = self.selected_fields
+    if self.row_restriction is not None:
+      requested_session.read_options.row_restriction = self.row_restriction
+
+    storage_client = bq_storage.BigQueryReadClient()
+    stream_count = 0
+    if (desired_bundle_size > 0):
+      table_size = self._get_table_size(self.table, self.dataset, self.project)
+      stream_count = min(
+          int(table_size / desired_bundle_size),
+          _CustomBigQueryStorageSourceBase.MAX_SPLIT_COUNT)
+    stream_count = max(
+        stream_count, _CustomBigQueryStorageSourceBase.MIN_SPLIT_COUNT)
+
+    parent = 'projects/{}'.format(self.project)
+    read_session = storage_client.create_read_session(
+        parent=parent,
+        read_session=requested_session,
+        max_stream_count=stream_count)
+
+    self.split_result = [
+        _CustomBigQueryStorageStreamSource(stream.name)
+        for stream in read_session.streams
+    ]
+
+    for source in self.split_result:
+      yield SourceBundle(
+          weight=1.0, source=source, start_position=None, stop_position=None)
+
+  def get_range_tracker(self, start_position, stop_position):
+    class NonePositionRangeTracker(RangeTracker):
+      """A RangeTracker that always returns positions as None. Prevents the
+      BigQuery Storage source from being read() before being split()."""
+      def start_position(self):
+        return None
+
+      def stop_position(self):
+        return None
+
+    return NonePositionRangeTracker()
+
+  def read(self, range_tracker):
+    raise NotImplementedError(
+        'BigQuery storage source must be split before being read')
+
+
+class _CustomBigQueryStorageStreamSource(BoundedSource):
+  """A source representing a single stream in a read session."""
+  def __init__(self, read_stream_name: str):
+    self.read_stream_name = read_stream_name
+
+  def display_data(self):
+    return {
+        'read_stream': str(self.read_stream_name),
+    }
+
+  def estimate_size(self):
+    # The size of a stream source cannot be estimated due to server-side liquid

Review comment:
       Are we hoping to implement this at some point? Seems like there's a TODO for the corresponding Java implementation: https://github.com/apache/beam/blob/dce846b36a4fb9140c4c5d14e10b72f835f03d98/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java#L124
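
A minimal sketch of one way such an estimate could be filled in later, assuming the parent source passed the table size and stream count down to each stream source (neither parameter exists in this PR as written):

    def estimate_size(self):
      # Hypothetical: '_table_size_bytes' and '_stream_count' are assumed
      # constructor arguments, not part of the PR.
      if not self._table_size_bytes or not self._stream_count:
        return None
      # Liquid sharding makes per-stream sizes uneven, so an even division of
      # the table size is only a rough approximation.
      return self._table_size_bytes // self._stream_count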

##########
File path: sdks/python/apache_beam/io/gcp/bigquery.py
##########
@@ -883,6 +895,221 @@ def _export_files(self, bq):
     return table.schema, metadata_list
 
 
+class _CustomBigQueryStorageSourceBase(BoundedSource):
+  """A base class for BoundedSource implementations which read from BigQuery
+  using the BigQuery Storage API.
+
+  Args:
+    table (str, TableReference): The ID of the table. The ID must contain only
+      letters ``a-z``, ``A-Z``, numbers ``0-9``, or underscores ``_``. If the
+      **dataset** argument is :data:`None`, then the table argument must
+      contain the entire table reference specified as:
+      ``'PROJECT:DATASET.TABLE'`` or must specify a TableReference.
+    dataset (str): The ID of the dataset containing this table or
+      :data:`None` if the table argument specifies a TableReference.
+    project (str): The ID of the project containing this table or
+      :data:`None` if the table argument specifies a TableReference.
+    selected_fields (List[str]): Names of the fields in the table that should be
+      read. If empty, all fields will be read. If the specified field is a
+      nested field, all the sub-fields in the field will be selected. The output
+      field order is unrelated to the order of fields in selected_fields.
+    row_restriction (str): SQL text filtering statement, similar to a WHERE
+      clause in a query. Aggregates are not supported. Restricted to a maximum
+      length of 1 MB.
+  """
+
+  # The maximum number of streams which will be requested when creating a read
+  # session, regardless of the desired bundle size.
+  MAX_SPLIT_COUNT = 10000
+  # The minimum number of streams which will be requested when creating a read
+  # session, regardless of the desired bundle size. Note that the server may
+  # still choose to return fewer than ten streams based on the layout of the
+  # table.
+  MIN_SPLIT_COUNT = 10
+
+  def __init__(
+      self,
+      table: Union[str, TableReference],
+      dataset: str = None,
+      project: str = None,
+      selected_fields: List[str] = None,
+      row_restriction: str = None,
+      pipeline_options: GoogleCloudOptions = None):
+
+    self.table_reference = bigquery_tools.parse_table_reference(
+        table, dataset, project)
+    self.project = self.table_reference.projectId
+    self.dataset = self.table_reference.datasetId
+    self.table = self.table_reference.tableId
+    self.selected_fields = selected_fields
+    self.row_restriction = row_restriction
+    self.pipeline_options = pipeline_options
+    self.split_result = None
+
+  def _get_parent_project(self):
+    """Returns the project that will be billed."""
+    project = self.pipeline_options.view_as(GoogleCloudOptions).project
+    if isinstance(project, vp.ValueProvider):
+      project = project.get()
+    if not project:
+      project = self.project
+    return project
+
+  def _get_table_size(self, table, dataset, project):
+    if project is None:
+      project = self._get_parent_project()
+
+    bq = bigquery_tools.BigQueryWrapper()
+    table = bq.get_table(project, dataset, table)
+    return table.numBytes
+
+  def display_data(self):
+    return {
+        'project': str(self.project),
+        'dataset': str(self.dataset),
+        'table': str(self.table),
+        'selected_fields': str(self.selected_fields),
+        'row_restriction': str(self.row_restriction)
+    }
+
+  def estimate_size(self):
+    # The size of a stream source cannot be estimated due to server-side liquid
+    # sharding.
+    return None
+
+  def split(self, desired_bundle_size, start_position=None, stop_position=None):
+    requested_session = bq_storage.types.ReadSession()
+    requested_session.table = 'projects/{}/datasets/{}/tables/{}'.format(
+        self.project, self.dataset, self.table)
+    requested_session.data_format = bq_storage.types.DataFormat.AVRO
+    if self.selected_fields is not None:
+      requested_session.read_options.selected_fields = self.selected_fields
+    if self.row_restriction is not None:
+      requested_session.read_options.row_restriction = self.row_restriction
+
+    storage_client = bq_storage.BigQueryReadClient()
+    stream_count = 0
+    if (desired_bundle_size > 0):
+      table_size = self._get_table_size(self.table, self.dataset, self.project)
+      stream_count = min(
+          int(table_size / desired_bundle_size),
+          _CustomBigQueryStorageSourceBase.MAX_SPLIT_COUNT)
+    stream_count = max(
+        stream_count, _CustomBigQueryStorageSourceBase.MIN_SPLIT_COUNT)
+
+    parent = 'projects/{}'.format(self.project)
+    read_session = storage_client.create_read_session(
+        parent=parent,
+        read_session=requested_session,
+        max_stream_count=stream_count)
+
+    self.split_result = [
+        _CustomBigQueryStorageStreamSource(stream.name)
+        for stream in read_session.streams
+    ]
+
+    for source in self.split_result:
+      yield SourceBundle(
+          weight=1.0, source=source, start_position=None, stop_position=None)
+
+  def get_range_tracker(self, start_position, stop_position):
+    class NonePositionRangeTracker(RangeTracker):
+      """A RangeTracker that always returns positions as None. Prevents the
+      BigQuery Storage source from being read() before being split()."""
+      def start_position(self):
+        return None
+
+      def stop_position(self):
+        return None
+
+    return NonePositionRangeTracker()
+
+  def read(self, range_tracker):
+    raise NotImplementedError(
+        'BigQuery storage source must be split before being read')
+
+
+class _CustomBigQueryStorageStreamSource(BoundedSource):
+  """A source representing a single stream in a read session."""
+  def __init__(self, read_stream_name: str):
+    self.read_stream_name = read_stream_name
+
+  def display_data(self):
+    return {
+        'read_stream': str(self.read_stream_name),
+    }
+
+  def estimate_size(self):
+    # The size of a stream source cannot be estimated due to server-side liquid
+    # sharding.
+    return None
+
+  def split(self, desired_bundle_size, start_position=None, stop_position=None):
+    # A stream source can't be split without reading from it due to
+    # server-side liquid sharding.
+    raise NotImplementedError('BigQuery storage stream source cannot be split.')

Review comment:
       I don't think you can prevent a runner from trying to split the source (and reading the splits). Maybe just return a single split that is the same as the current source?
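
A minimal sketch of that suggestion, as an illustration rather than the PR's code: instead of raising, report the whole stream as a single bundle.

    def split(self, desired_bundle_size, start_position=None, stop_position=None):
      # Streams cannot be subdivided ahead of time because of server-side
      # liquid sharding, so return the current source as the only split and
      # let the runner schedule it as-is.
      yield SourceBundle(
          weight=1.0, source=self, start_position=None, stop_position=None)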

##########
File path: sdks/python/apache_beam/io/gcp/bigquery.py
##########
@@ -883,6 +895,221 @@ def _export_files(self, bq):
     return table.schema, metadata_list
 
 
+class _CustomBigQueryStorageSourceBase(BoundedSource):
+  """A base class for BoundedSource implementations which read from BigQuery
+  using the BigQuery Storage API.
+
+  Args:
+    table (str, TableReference): The ID of the table. The ID must contain only
+      letters ``a-z``, ``A-Z``, numbers ``0-9``, or underscores ``_``. If the
+      **dataset** argument is :data:`None`, then the table argument must
+      contain the entire table reference specified as:
+      ``'PROJECT:DATASET.TABLE'`` or must specify a TableReference.
+    dataset (str): The ID of the dataset containing this table or
+      :data:`None` if the table argument specifies a TableReference.
+    project (str): The ID of the project containing this table or
+      :data:`None` if the table argument specifies a TableReference.
+    selected_fields (List[str]): Names of the fields in the table that should be
+      read. If empty, all fields will be read. If the specified field is a
+      nested field, all the sub-fields in the field will be selected. The output
+      field order is unrelated to the order of fields in selected_fields.
+    row_restriction (str): SQL text filtering statement, similar to a WHERE
+      clause in a query. Aggregates are not supported. Restricted to a maximum
+      length of 1 MB.
+  """
+
+  # The maximum number of streams which will be requested when creating a read
+  # session, regardless of the desired bundle size.
+  MAX_SPLIT_COUNT = 10000
+  # The minimum number of streams which will be requested when creating a read
+  # session, regardless of the desired bundle size. Note that the server may
+  # still choose to return fewer than ten streams based on the layout of the
+  # table.
+  MIN_SPLIT_COUNT = 10
+
+  def __init__(
+      self,
+      table: Union[str, TableReference],
+      dataset: str = None,
+      project: str = None,
+      selected_fields: List[str] = None,
+      row_restriction: str = None,
+      pipeline_options: GoogleCloudOptions = None):
+
+    self.table_reference = bigquery_tools.parse_table_reference(
+        table, dataset, project)
+    self.project = self.table_reference.projectId
+    self.dataset = self.table_reference.datasetId
+    self.table = self.table_reference.tableId
+    self.selected_fields = selected_fields
+    self.row_restriction = row_restriction
+    self.pipeline_options = pipeline_options
+    self.split_result = None
+
+  def _get_parent_project(self):
+    """Returns the project that will be billed."""
+    project = self.pipeline_options.view_as(GoogleCloudOptions).project
+    if isinstance(project, vp.ValueProvider):
+      project = project.get()
+    if not project:
+      project = self.project
+    return project
+
+  def _get_table_size(self, table, dataset, project):
+    if project is None:
+      project = self._get_parent_project()
+
+    bq = bigquery_tools.BigQueryWrapper()
+    table = bq.get_table(project, dataset, table)
+    return table.numBytes
+
+  def display_data(self):
+    return {
+        'project': str(self.project),
+        'dataset': str(self.dataset),
+        'table': str(self.table),
+        'selected_fields': str(self.selected_fields),
+        'row_restriction': str(self.row_restriction)
+    }
+
+  def estimate_size(self):
+    # The size of a stream source cannot be estimated due to server-side liquid
+    # sharding.
+    return None
+
+  def split(self, desired_bundle_size, start_position=None, stop_position=None):
+    requested_session = bq_storage.types.ReadSession()
+    requested_session.table = 'projects/{}/datasets/{}/tables/{}'.format(
+        self.project, self.dataset, self.table)
+    requested_session.data_format = bq_storage.types.DataFormat.AVRO
+    if self.selected_fields is not None:
+      requested_session.read_options.selected_fields = self.selected_fields
+    if self.row_restriction is not None:
+      requested_session.read_options.row_restriction = self.row_restriction
+
+    storage_client = bq_storage.BigQueryReadClient()
+    stream_count = 0
+    if (desired_bundle_size > 0):
+      table_size = self._get_table_size(self.table, self.dataset, self.project)
+      stream_count = min(
+          int(table_size / desired_bundle_size),
+          _CustomBigQueryStorageSourceBase.MAX_SPLIT_COUNT)
+    stream_count = max(
+        stream_count, _CustomBigQueryStorageSourceBase.MIN_SPLIT_COUNT)
+
+    parent = 'projects/{}'.format(self.project)
+    read_session = storage_client.create_read_session(
+        parent=parent,
+        read_session=requested_session,
+        max_stream_count=stream_count)
+
+    self.split_result = [
+        _CustomBigQueryStorageStreamSource(stream.name)
+        for stream in read_session.streams
+    ]
+
+    for source in self.split_result:
+      yield SourceBundle(
+          weight=1.0, source=source, start_position=None, stop_position=None)
+
+  def get_range_tracker(self, start_position, stop_position):
+    class NonePositionRangeTracker(RangeTracker):
+      """A RangeTracker that always returns positions as None. Prevents the
+      BigQuery Storage source from being read() before being split()."""
+      def start_position(self):
+        return None
+
+      def stop_position(self):
+        return None
+
+    return NonePositionRangeTracker()
+
+  def read(self, range_tracker):
+    raise NotImplementedError(
+        'BigQuery storage source must be split before being read')
+
+
+class _CustomBigQueryStorageStreamSource(BoundedSource):
+  """A source representing a single stream in a read session."""
+  def __init__(self, read_stream_name: str):
+    self.read_stream_name = read_stream_name
+
+  def display_data(self):
+    return {
+        'read_stream': str(self.read_stream_name),
+    }
+
+  def estimate_size(self):
+    # The size of a stream source cannot be estimated due to server-side liquid
+    # sharding.
+    return None
+
+  def split(self, desired_bundle_size, start_position=None, stop_position=None):
+    # A stream source can't be split without reading from it due to
+    # server-side liquid sharding.
+    raise NotImplementedError('BigQuery storage stream source cannot be split.')
+
+  def get_range_tracker(self, start_position, stop_position):
+    if start_position is None:

Review comment:
       Seems like this will always be None, right? It may be cleaner to assert and always set start_position to zero.
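
A small sketch of the suggested tightening (illustration only), reusing the tracker setup from the PR:

    def get_range_tracker(self, start_position, stop_position):
      # The parent source always hands out bundles with start_position=None,
      # so make that assumption explicit instead of silently accepting other
      # values.
      assert start_position is None, (
          'Stream sources do not support custom start offsets.')
      start_position = 0
      # Streams are unsplittable, so read to the end of the stream.
      stop_position = range_trackers.OffsetRangeTracker.OFFSET_INFINITY
      range_tracker = range_trackers.OffsetRangeTracker(
          start_position, stop_position)
      return range_trackers.UnsplittableRangeTracker(range_tracker)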

##########
File path: sdks/python/apache_beam/io/gcp/bigquery.py
##########
@@ -883,6 +895,221 @@ def _export_files(self, bq):
     return table.schema, metadata_list
 
 
+class _CustomBigQueryStorageSourceBase(BoundedSource):
+  """A base class for BoundedSource implementations which read from BigQuery
+  using the BigQuery Storage API.
+
+  Args:
+    table (str, TableReference): The ID of the table. The ID must contain only
+      letters ``a-z``, ``A-Z``, numbers ``0-9``, or underscores ``_``. If the
+      **dataset** argument is :data:`None`, then the table argument must
+      contain the entire table reference specified as:
+      ``'PROJECT:DATASET.TABLE'`` or must specify a TableReference.
+    dataset (str): The ID of the dataset containing this table or
+      :data:`None` if the table argument specifies a TableReference.
+    project (str): The ID of the project containing this table or
+      :data:`None` if the table argument specifies a TableReference.
+    selected_fields (List[str]): Names of the fields in the table that should be
+      read. If empty, all fields will be read. If the specified field is a
+      nested field, all the sub-fields in the field will be selected. The output
+      field order is unrelated to the order of fields in selected_fields.
+    row_restriction (str): SQL text filtering statement, similar to a WHERE
+      clause in a query. Aggregates are not supported. Restricted to a maximum
+      length of 1 MB.
+  """
+
+  # The maximum number of streams which will be requested when creating a read
+  # session, regardless of the desired bundle size.
+  MAX_SPLIT_COUNT = 10000
+  # The minimum number of streams which will be requested when creating a read
+  # session, regardless of the desired bundle size. Note that the server may
+  # still choose to return fewer than ten streams based on the layout of the
+  # table.
+  MIN_SPLIT_COUNT = 10
+
+  def __init__(
+      self,
+      table: Union[str, TableReference],
+      dataset: str = None,
+      project: str = None,
+      selected_fields: List[str] = None,
+      row_restriction: str = None,
+      pipeline_options: GoogleCloudOptions = None):
+
+    self.table_reference = bigquery_tools.parse_table_reference(
+        table, dataset, project)
+    self.project = self.table_reference.projectId
+    self.dataset = self.table_reference.datasetId
+    self.table = self.table_reference.tableId
+    self.selected_fields = selected_fields
+    self.row_restriction = row_restriction
+    self.pipeline_options = pipeline_options
+    self.split_result = None
+
+  def _get_parent_project(self):
+    """Returns the project that will be billed."""
+    project = self.pipeline_options.view_as(GoogleCloudOptions).project
+    if isinstance(project, vp.ValueProvider):
+      project = project.get()
+    if not project:
+      project = self.project
+    return project
+
+  def _get_table_size(self, table, dataset, project):
+    if project is None:
+      project = self._get_parent_project()
+
+    bq = bigquery_tools.BigQueryWrapper()
+    table = bq.get_table(project, dataset, table)
+    return table.numBytes
+
+  def display_data(self):
+    return {
+        'project': str(self.project),
+        'dataset': str(self.dataset),
+        'table': str(self.table),
+        'selected_fields': str(self.selected_fields),
+        'row_restriction': str(self.row_restriction)
+    }
+
+  def estimate_size(self):
+    # The size of a stream source cannot be estimated due to server-side liquid
+    # sharding.
+    return None
+
+  def split(self, desired_bundle_size, start_position=None, stop_position=None):
+    requested_session = bq_storage.types.ReadSession()
+    requested_session.table = 'projects/{}/datasets/{}/tables/{}'.format(
+        self.project, self.dataset, self.table)
+    requested_session.data_format = bq_storage.types.DataFormat.AVRO
+    if self.selected_fields is not None:
+      requested_session.read_options.selected_fields = self.selected_fields
+    if self.row_restriction is not None:
+      requested_session.read_options.row_restriction = self.row_restriction
+
+    storage_client = bq_storage.BigQueryReadClient()
+    stream_count = 0
+    if (desired_bundle_size > 0):
+      table_size = self._get_table_size(self.table, self.dataset, self.project)
+      stream_count = min(
+          int(table_size / desired_bundle_size),
+          _CustomBigQueryStorageSourceBase.MAX_SPLIT_COUNT)
+    stream_count = max(
+        stream_count, _CustomBigQueryStorageSourceBase.MIN_SPLIT_COUNT)
+
+    parent = 'projects/{}'.format(self.project)
+    read_session = storage_client.create_read_session(
+        parent=parent,
+        read_session=requested_session,
+        max_stream_count=stream_count)
+
+    self.split_result = [
+        _CustomBigQueryStorageStreamSource(stream.name)
+        for stream in read_session.streams
+    ]
+
+    for source in self.split_result:
+      yield SourceBundle(
+          weight=1.0, source=source, start_position=None, stop_position=None)
+
+  def get_range_tracker(self, start_position, stop_position):
+    class NonePositionRangeTracker(RangeTracker):
+      """A RangeTracker that always returns positions as None. Prevents the
+      BigQuery Storage source from being read() before being split()."""
+      def start_position(self):
+        return None
+
+      def stop_position(self):
+        return None
+
+    return NonePositionRangeTracker()
+
+  def read(self, range_tracker):
+    raise NotImplementedError(
+        'BigQuery storage source must be split before being read')
+
+
+class _CustomBigQueryStorageStreamSource(BoundedSource):
+  """A source representing a single stream in a read session."""
+  def __init__(self, read_stream_name: str):
+    self.read_stream_name = read_stream_name
+
+  def display_data(self):
+    return {
+        'read_stream': str(self.read_stream_name),
+    }
+
+  def estimate_size(self):
+    # The size of a stream source cannot be estimated due to server-side liquid
+    # sharding.
+    return None
+
+  def split(self, desired_bundle_size, start_position=None, stop_position=None):
+    # A stream source can't be split without reading from it due to
+    # server-side liquid sharding.
+    raise NotImplementedError('BigQuery storage stream source cannot be split.')
+
+  def get_range_tracker(self, start_position, stop_position):
+    if start_position is None:
+      # Defaulting to the start of the stream.
+      start_position = 0
+    # Since the streams are unsplittable we choose OFFSET_INFINITY as the
+    # default end offset so that all data of the source gets read.
+    stop_position = range_trackers.OffsetRangeTracker.OFFSET_INFINITY
+    range_tracker = range_trackers.OffsetRangeTracker(
+        start_position, stop_position)
+    # Ensuring that all try_split() calls will be ignored by the RangeTracker.
+    range_tracker = range_trackers.UnsplittableRangeTracker(range_tracker)

Review comment:
       This disables dynamic work rebalancing, but the Java Storage API source does support it: https://github.com/apache/beam/blob/dce846b36a4fb9140c4c5d14e10b72f835f03d98/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java#L293
   
   Can we support dynamic work rebalancing for Python as well?
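
For reference, the server-side primitive that dynamic work rebalancing would build on is SplitReadStream. A rough sketch of the call, assuming the google-cloud-bigquery-storage v1 client and a split fraction supplied by the runner; this only illustrates the API, a full RangeTracker integration is more involved:

    from google.cloud import bigquery_storage as bq_storage

    def split_stream_at_fraction(stream_name, fraction):
      """Asks the BigQuery Storage service to split an existing read stream."""
      client = bq_storage.BigQueryReadClient()
      response = client.split_read_stream(
          request={'name': stream_name, 'fraction': fraction})
      # 'primary_stream' covers roughly the first `fraction` of the remaining
      # rows; 'remainder_stream' would become the residual source handed back
      # to the runner.
      return response.primary_stream, response.remainder_stream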

##########
File path: sdks/python/apache_beam/io/gcp/bigquery.py
##########
@@ -883,6 +895,221 @@ def _export_files(self, bq):
     return table.schema, metadata_list
 
 
+class _CustomBigQueryStorageSourceBase(BoundedSource):
+  """A base class for BoundedSource implementations which read from BigQuery
+  using the BigQuery Storage API.
+
+  Args:
+    table (str, TableReference): The ID of the table. The ID must contain only
+      letters ``a-z``, ``A-Z``, numbers ``0-9``, or underscores ``_``. If the
+      **dataset** argument is :data:`None`, then the table argument must
+      contain the entire table reference specified as:
+      ``'PROJECT:DATASET.TABLE'`` or must specify a TableReference.
+    dataset (str): The ID of the dataset containing this table or
+      :data:`None` if the table argument specifies a TableReference.
+    project (str): The ID of the project containing this table or
+      :data:`None` if the table argument specifies a TableReference.
+    selected_fields (List[str]): Names of the fields in the table that should be
+      read. If empty, all fields will be read. If the specified field is a
+      nested field, all the sub-fields in the field will be selected. The output
+      field order is unrelated to the order of fields in selected_fields.
+    row_restriction (str): SQL text filtering statement, similar to a WHERE
+      clause in a query. Aggregates are not supported. Restricted to a maximum
+      length of 1 MB.
+  """
+
+  # The maximum number of streams which will be requested when creating a read
+  # session, regardless of the desired bundle size.
+  MAX_SPLIT_COUNT = 10000
+  # The minimum number of streams which will be requested when creating a read
+  # session, regardless of the desired bundle size. Note that the server may
+  # still choose to return fewer than ten streams based on the layout of the
+  # table.
+  MIN_SPLIT_COUNT = 10
+
+  def __init__(
+      self,
+      table: Union[str, TableReference],
+      dataset: str = None,
+      project: str = None,
+      selected_fields: List[str] = None,
+      row_restriction: str = None,
+      pipeline_options: GoogleCloudOptions = None):
+
+    self.table_reference = bigquery_tools.parse_table_reference(
+        table, dataset, project)
+    self.project = self.table_reference.projectId
+    self.dataset = self.table_reference.datasetId
+    self.table = self.table_reference.tableId
+    self.selected_fields = selected_fields
+    self.row_restriction = row_restriction
+    self.pipeline_options = pipeline_options
+    self.split_result = None
+
+  def _get_parent_project(self):
+    """Returns the project that will be billed."""
+    project = self.pipeline_options.view_as(GoogleCloudOptions).project
+    if isinstance(project, vp.ValueProvider):
+      project = project.get()
+    if not project:
+      project = self.project
+    return project
+
+  def _get_table_size(self, table, dataset, project):
+    if project is None:
+      project = self._get_parent_project()
+
+    bq = bigquery_tools.BigQueryWrapper()
+    table = bq.get_table(project, dataset, table)
+    return table.numBytes
+
+  def display_data(self):
+    return {
+        'project': str(self.project),
+        'dataset': str(self.dataset),
+        'table': str(self.table),
+        'selected_fields': str(self.selected_fields),
+        'row_restriction': str(self.row_restriction)
+    }
+
+  def estimate_size(self):
+    # The size of a stream source cannot be estimated due to server-side liquid
+    # sharding.
+    return None
+
+  def split(self, desired_bundle_size, start_position=None, stop_position=None):
+    requested_session = bq_storage.types.ReadSession()
+    requested_session.table = 'projects/{}/datasets/{}/tables/{}'.format(
+        self.project, self.dataset, self.table)
+    requested_session.data_format = bq_storage.types.DataFormat.AVRO
+    if self.selected_fields is not None:
+      requested_session.read_options.selected_fields = self.selected_fields
+    if self.row_restriction is not None:
+      requested_session.read_options.row_restriction = self.row_restriction
+
+    storage_client = bq_storage.BigQueryReadClient()
+    stream_count = 0
+    if (desired_bundle_size > 0):
+      table_size = self._get_table_size(self.table, self.dataset, self.project)
+      stream_count = min(
+          int(table_size / desired_bundle_size),
+          _CustomBigQueryStorageSourceBase.MAX_SPLIT_COUNT)
+    stream_count = max(
+        stream_count, _CustomBigQueryStorageSourceBase.MIN_SPLIT_COUNT)
+
+    parent = 'projects/{}'.format(self.project)
+    read_session = storage_client.create_read_session(
+        parent=parent,
+        read_session=requested_session,
+        max_stream_count=stream_count)
+
+    self.split_result = [
+        _CustomBigQueryStorageStreamSource(stream.name)
+        for stream in read_session.streams
+    ]
+
+    for source in self.split_result:
+      yield SourceBundle(
+          weight=1.0, source=source, start_position=None, stop_position=None)
+
+  def get_range_tracker(self, start_position, stop_position):
+    class NonePositionRangeTracker(RangeTracker):
+      """A RangeTracker that always returns positions as None. Prevents the
+      BigQuery Storage source from being read() before being split()."""
+      def start_position(self):
+        return None
+
+      def stop_position(self):
+        return None
+
+    return NonePositionRangeTracker()
+
+  def read(self, range_tracker):
+    raise NotImplementedError(
+        'BigQuery storage source must be split before being read')
+
+
+class _CustomBigQueryStorageStreamSource(BoundedSource):
+  """A source representing a single stream in a read session."""
+  def __init__(self, read_stream_name: str):
+    self.read_stream_name = read_stream_name
+
+  def display_data(self):
+    return {
+        'read_stream': str(self.read_stream_name),
+    }
+
+  def estimate_size(self):
+    # The size of a stream source cannot be estimated due to server-side liquid
+    # sharding.
+    return None
+
+  def split(self, desired_bundle_size, start_position=None, stop_position=None):
+    # A stream source can't be split without reading from it due to
+    # server-side liquid sharding.
+    raise NotImplementedError('BigQuery storage stream source cannot be split.')
+
+  def get_range_tracker(self, start_position, stop_position):
+    if start_position is None:
+      # Defaulting to the start of the stream.
+      start_position = 0
+    # Since the streams are unsplittable we choose OFFSET_INFINITY as the
+    # default end offset so that all data of the source gets read.
+    stop_position = range_trackers.OffsetRangeTracker.OFFSET_INFINITY
+    range_tracker = range_trackers.OffsetRangeTracker(
+        start_position, stop_position)
+    # Ensuring that all try_split() calls will be ignored by the RangeTracker.
+    range_tracker = range_trackers.UnsplittableRangeTracker(range_tracker)
+
+    return range_tracker
+
+  def read(self, range_tracker):
+    storage_client = bq_storage.BigQueryReadClient()
+    read_rows_iterator = iter(storage_client.read_rows(self.read_stream_name))
+    # Handling the case where the user might provide very selective filters
+    # which can result in read_rows_response being empty.
+    first_read_rows_response = next(read_rows_iterator, None)
+    if first_read_rows_response is None:
+      return iter([])
+    row_reader = _ReadRowsResponseReader(

Review comment:
       Might be simpler to implement this as a generator instead of introducing a new class.
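
A minimal sketch of the generator-based alternative; decode_rows below is a hypothetical callable standing in for whatever per-response decoding _ReadRowsResponseReader performs (it is not part of the PR or of apache_beam):

    import itertools

    def _rows_from_read_rows_responses(
        first_response, read_rows_iterator, decode_rows):
      """Yields decoded rows from the first response and the remaining ones."""
      for response in itertools.chain([first_response], read_rows_iterator):
        # decode_rows turns one ReadRowsResponse into an iterable of rows.
        for row in decode_rows(response):
          yield row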




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org