Posted to commits@beam.apache.org by jb...@apache.org on 2017/07/20 17:09:31 UTC
[04/28] beam git commit: Revert "[BEAM-2610] This closes #3553"
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/io/gcp/pubsub.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py
index 32d388a..1ba8ac0 100644
--- a/sdks/python/apache_beam/io/gcp/pubsub.py
+++ b/sdks/python/apache_beam/io/gcp/pubsub.py
@@ -24,35 +24,29 @@ This API is currently under development and is subject to change.
from __future__ import absolute_import
-import re
-
from apache_beam import coders
from apache_beam.io.iobase import Read
from apache_beam.io.iobase import Write
from apache_beam.runners.dataflow.native_io import iobase as dataflow_io
-from apache_beam.transforms import core
from apache_beam.transforms import PTransform
-from apache_beam.transforms import Map
-from apache_beam.transforms import window
+from apache_beam.transforms import ParDo
from apache_beam.transforms.display import DisplayDataItem
-__all__ = ['ReadStringsFromPubSub', 'WriteStringsToPubSub']
+__all__ = ['ReadStringsFromPubSub', 'WriteStringsToPubSub',
+ 'PubSubSource', 'PubSubSink']
class ReadStringsFromPubSub(PTransform):
"""A ``PTransform`` for reading utf-8 string payloads from Cloud Pub/Sub."""
- def __init__(self, topic=None, subscription=None, id_label=None):
+ def __init__(self, topic, subscription=None, id_label=None):
"""Initializes ``ReadStringsFromPubSub``.
Attributes:
- topic: Cloud Pub/Sub topic in the form "projects/<project>/topics/
- <topic>". If provided, subscription must be None.
- subscription: Existing Cloud Pub/Sub subscription to use in the
- form "projects/<project>/subscriptions/<subscription>". If not
- specified, a temporary subscription will be created from the specified
- topic. If provided, topic must be None.
+ topic: Cloud Pub/Sub topic in the form "/topics/<project>/<topic>".
+ subscription: Optional existing Cloud Pub/Sub subscription to use in the
+ form "projects/<project>/subscriptions/<subscription>".
id_label: The attribute on incoming Pub/Sub messages to use as a unique
record identifier. When specified, the value of this attribute (which
can be any string that uniquely identifies the record) will be used for
@@ -66,13 +60,10 @@ class ReadStringsFromPubSub(PTransform):
subscription=subscription,
id_label=id_label)
- def get_windowing(self, unused_inputs):
- return core.Windowing(window.GlobalWindows())
-
def expand(self, pvalue):
pcoll = pvalue.pipeline | Read(self._source)
pcoll.element_type = bytes
- pcoll = pcoll | 'DecodeString' >> Map(lambda b: b.decode('utf-8'))
+ pcoll = pcoll | 'decode string' >> ParDo(_decodeUtf8String)
pcoll.element_type = unicode
return pcoll
@@ -90,50 +81,18 @@ class WriteStringsToPubSub(PTransform):
self._sink = _PubSubPayloadSink(topic)
def expand(self, pcoll):
- pcoll = pcoll | 'EncodeString' >> Map(lambda s: s.encode('utf-8'))
+ pcoll = pcoll | 'encode string' >> ParDo(_encodeUtf8String)
pcoll.element_type = bytes
return pcoll | Write(self._sink)
-PROJECT_ID_REGEXP = '[a-z][-a-z0-9:.]{4,61}[a-z0-9]'
-SUBSCRIPTION_REGEXP = 'projects/([^/]+)/subscriptions/(.+)'
-TOPIC_REGEXP = 'projects/([^/]+)/topics/(.+)'
-
-
-def parse_topic(full_topic):
- match = re.match(TOPIC_REGEXP, full_topic)
- if not match:
- raise ValueError(
- 'PubSub topic must be in the form "projects/<project>/topics'
- '/<topic>" (got %r).' % full_topic)
- project, topic_name = match.group(1), match.group(2)
- if not re.match(PROJECT_ID_REGEXP, project):
- raise ValueError('Invalid PubSub project name: %r.' % project)
- return project, topic_name
-
-
-def parse_subscription(full_subscription):
- match = re.match(SUBSCRIPTION_REGEXP, full_subscription)
- if not match:
- raise ValueError(
- 'PubSub subscription must be in the form "projects/<project>'
- '/subscriptions/<subscription>" (got %r).' % full_subscription)
- project, subscription_name = match.group(1), match.group(2)
- if not re.match(PROJECT_ID_REGEXP, project):
- raise ValueError('Invalid PubSub project name: %r.' % project)
- return project, subscription_name
-
-
class _PubSubPayloadSource(dataflow_io.NativeSource):
"""Source for the payload of a message as bytes from a Cloud Pub/Sub topic.
Attributes:
- topic: Cloud Pub/Sub topic in the form "projects/<project>/topics/<topic>".
- If provided, subscription must be None.
- subscription: Existing Cloud Pub/Sub subscription to use in the
- form "projects/<project>/subscriptions/<subscription>". If not specified,
- a temporary subscription will be created from the specified topic. If
- provided, topic must be None.
+ topic: Cloud Pub/Sub topic in the form "/topics/<project>/<topic>".
+ subscription: Optional existing Cloud Pub/Sub subscription to use in the
+ form "projects/<project>/subscriptions/<subscription>".
id_label: The attribute on incoming Pub/Sub messages to use as a unique
record identifier. When specified, the value of this attribute (which can
be any string that uniquely identifies the record) will be used for
@@ -142,27 +101,11 @@ class _PubSubPayloadSource(dataflow_io.NativeSource):
case, deduplication of the stream will be strictly best effort.
"""
- def __init__(self, topic=None, subscription=None, id_label=None):
- # We are using this coder explicitly for portability reasons of PubsubIO
- # across implementations in languages.
- self.coder = coders.BytesCoder()
- self.full_topic = topic
- self.full_subscription = subscription
- self.topic_name = None
- self.subscription_name = None
+ def __init__(self, topic, subscription=None, id_label=None):
+ self.topic = topic
+ self.subscription = subscription
self.id_label = id_label
- # Perform some validation on the topic and subscription.
- if not (topic or subscription):
- raise ValueError('Either a topic or subscription must be provided.')
- if topic and subscription:
- raise ValueError('Only one of topic or subscription should be provided.')
-
- if topic:
- self.project, self.topic_name = parse_topic(topic)
- if subscription:
- self.project, self.subscription_name = parse_subscription(subscription)
-
@property
def format(self):
"""Source format name required for remote execution."""
@@ -173,10 +116,10 @@ class _PubSubPayloadSource(dataflow_io.NativeSource):
DisplayDataItem(self.id_label,
label='ID Label Attribute').drop_if_none(),
'topic':
- DisplayDataItem(self.full_topic,
- label='Pubsub Topic').drop_if_none(),
+ DisplayDataItem(self.topic,
+ label='Pubsub Topic'),
'subscription':
- DisplayDataItem(self.full_subscription,
+ DisplayDataItem(self.subscription,
label='Pubsub Subscription').drop_if_none()}
def reader(self):
@@ -188,12 +131,7 @@ class _PubSubPayloadSink(dataflow_io.NativeSink):
"""Sink for the payload of a message as bytes to a Cloud Pub/Sub topic."""
def __init__(self, topic):
- # we are using this coder explicitly for portability reasons of PubsubIO
- # across implementations in languages.
- self.coder = coders.BytesCoder()
- self.full_topic = topic
-
- self.project, self.topic_name = parse_topic(topic)
+ self.topic = topic
@property
def format(self):
@@ -201,8 +139,86 @@ class _PubSubPayloadSink(dataflow_io.NativeSink):
return 'pubsub'
def display_data(self):
- return {'topic': DisplayDataItem(self.full_topic, label='Pubsub Topic')}
+ return {'topic': DisplayDataItem(self.topic, label='Pubsub Topic')}
def writer(self):
raise NotImplementedError(
'PubSubPayloadSink is not supported in local execution.')
+
+
+def _decodeUtf8String(encoded_value):
+ """Decodes a string in utf-8 format from bytes"""
+ return encoded_value.decode('utf-8')
+
+
+def _encodeUtf8String(value):
+ """Encodes a string in utf-8 format to bytes"""
+ return value.encode('utf-8')
+
+
+class PubSubSource(dataflow_io.NativeSource):
+ """Deprecated: do not use.
+
+ Source for reading from a given Cloud Pub/Sub topic.
+
+ Attributes:
+ topic: Cloud Pub/Sub topic in the form "/topics/<project>/<topic>".
+ subscription: Optional existing Cloud Pub/Sub subscription to use in the
+ form "projects/<project>/subscriptions/<subscription>".
+ id_label: The attribute on incoming Pub/Sub messages to use as a unique
+ record identifier. When specified, the value of this attribute (which can
+ be any string that uniquely identifies the record) will be used for
+ deduplication of messages. If not provided, Dataflow cannot guarantee
+ that no duplicate data will be delivered on the Pub/Sub stream. In this
+ case, deduplication of the stream will be strictly best effort.
+ coder: The Coder to use for decoding incoming Pub/Sub messages.
+ """
+
+ def __init__(self, topic, subscription=None, id_label=None,
+ coder=coders.StrUtf8Coder()):
+ self.topic = topic
+ self.subscription = subscription
+ self.id_label = id_label
+ self.coder = coder
+
+ @property
+ def format(self):
+ """Source format name required for remote execution."""
+ return 'pubsub'
+
+ def display_data(self):
+ return {'id_label':
+ DisplayDataItem(self.id_label,
+ label='ID Label Attribute').drop_if_none(),
+ 'topic':
+ DisplayDataItem(self.topic,
+ label='Pubsub Topic'),
+ 'subscription':
+ DisplayDataItem(self.subscription,
+ label='Pubsub Subscription').drop_if_none()}
+
+ def reader(self):
+ raise NotImplementedError(
+ 'PubSubSource is not supported in local execution.')
+
+
+class PubSubSink(dataflow_io.NativeSink):
+ """Deprecated: do not use.
+
+ Sink for writing to a given Cloud Pub/Sub topic."""
+
+ def __init__(self, topic, coder=coders.StrUtf8Coder()):
+ self.topic = topic
+ self.coder = coder
+
+ @property
+ def format(self):
+ """Sink format name required for remote execution."""
+ return 'pubsub'
+
+ def display_data(self):
+ return {'topic': DisplayDataItem(self.topic, label='Pubsub Topic')}
+
+ def writer(self):
+ raise NotImplementedError(
+ 'PubSubSink is not supported in local execution.')
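For orientation, a minimal sketch of how the two public transforms in this module are meant to be composed after the revert; the project, topic names, and options below are placeholders, and per the reader()/writer() methods above these sources and sinks only run on a remote runner (local execution raises NotImplementedError):

    import apache_beam as beam
    from apache_beam.io.gcp.pubsub import ReadStringsFromPubSub, WriteStringsToPubSub
    from apache_beam.options.pipeline_options import PipelineOptions

    # Hypothetical topic names; the reverted docstrings expect the
    # "/topics/<project>/<topic>" form.
    options = PipelineOptions(['--streaming'])
    with beam.Pipeline(options=options) as p:
      (p
       | 'Read' >> ReadStringsFromPubSub('/topics/my-project/input')
       | 'Upper' >> beam.Map(lambda s: s.upper())
       | 'Write' >> WriteStringsToPubSub('/topics/my-project/output'))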
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/io/gcp/pubsub_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py
index 0dcc3c3..322d08a 100644
--- a/sdks/python/apache_beam/io/gcp/pubsub_test.py
+++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py
@@ -22,6 +22,8 @@ import unittest
import hamcrest as hc
+from apache_beam.io.gcp.pubsub import _decodeUtf8String
+from apache_beam.io.gcp.pubsub import _encodeUtf8String
from apache_beam.io.gcp.pubsub import _PubSubPayloadSink
from apache_beam.io.gcp.pubsub import _PubSubPayloadSource
from apache_beam.io.gcp.pubsub import ReadStringsFromPubSub
@@ -31,112 +33,77 @@ from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher
-# Protect against environments where the PubSub library is not available.
-# pylint: disable=wrong-import-order, wrong-import-position
-try:
- from google.cloud import pubsub
-except ImportError:
- pubsub = None
-# pylint: enable=wrong-import-order, wrong-import-position
-
-
-@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
class TestReadStringsFromPubSub(unittest.TestCase):
- def test_expand_with_topic(self):
+ def test_expand(self):
p = TestPipeline()
- pcoll = p | ReadStringsFromPubSub('projects/fakeprj/topics/a_topic',
- None, 'a_label')
+ pcoll = p | ReadStringsFromPubSub('a_topic', 'a_subscription', 'a_label')
# Ensure that the output type is str
self.assertEqual(unicode, pcoll.element_type)
- # Ensure that the properties passed through correctly
- source = pcoll.producer.transform._source
- self.assertEqual('a_topic', source.topic_name)
- self.assertEqual('a_label', source.id_label)
-
- def test_expand_with_subscription(self):
- p = TestPipeline()
- pcoll = p | ReadStringsFromPubSub(
- None, 'projects/fakeprj/subscriptions/a_subscription', 'a_label')
- # Ensure that the output type is str
- self.assertEqual(unicode, pcoll.element_type)
+ # Ensure that the type on the intermediate read output PCollection is bytes
+ read_pcoll = pcoll.producer.inputs[0]
+ self.assertEqual(bytes, read_pcoll.element_type)
# Ensure that the properties passed through correctly
- source = pcoll.producer.transform._source
- self.assertEqual('a_subscription', source.subscription_name)
+ source = read_pcoll.producer.transform.source
+ self.assertEqual('a_topic', source.topic)
+ self.assertEqual('a_subscription', source.subscription)
self.assertEqual('a_label', source.id_label)
- def test_expand_with_no_topic_or_subscription(self):
- with self.assertRaisesRegexp(
- ValueError, "Either a topic or subscription must be provided."):
- ReadStringsFromPubSub(None, None, 'a_label')
-
- def test_expand_with_both_topic_and_subscription(self):
- with self.assertRaisesRegexp(
- ValueError, "Only one of topic or subscription should be provided."):
- ReadStringsFromPubSub('a_topic', 'a_subscription', 'a_label')
-
-@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
class TestWriteStringsToPubSub(unittest.TestCase):
def test_expand(self):
p = TestPipeline()
- pdone = (p
- | ReadStringsFromPubSub('projects/fakeprj/topics/baz')
- | WriteStringsToPubSub('projects/fakeprj/topics/a_topic'))
+ pdone = p | ReadStringsFromPubSub('baz') | WriteStringsToPubSub('a_topic')
# Ensure that the properties passed through correctly
- self.assertEqual('a_topic', pdone.producer.transform.dofn.topic_name)
+ sink = pdone.producer.transform.sink
+ self.assertEqual('a_topic', sink.topic)
+ # Ensure that the type on the intermediate payload transformer output
+ # PCollection is bytes
+ write_pcoll = pdone.producer.inputs[0]
+ self.assertEqual(bytes, write_pcoll.element_type)
-@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
-class TestPubSubSource(unittest.TestCase):
- def test_display_data_topic(self):
- source = _PubSubPayloadSource(
- 'projects/fakeprj/topics/a_topic',
- None,
- 'a_label')
- dd = DisplayData.create_from(source)
- expected_items = [
- DisplayDataItemMatcher(
- 'topic', 'projects/fakeprj/topics/a_topic'),
- DisplayDataItemMatcher('id_label', 'a_label')]
-
- hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
- def test_display_data_subscription(self):
- source = _PubSubPayloadSource(
- None,
- 'projects/fakeprj/subscriptions/a_subscription',
- 'a_label')
+class TestPubSubSource(unittest.TestCase):
+ def test_display_data(self):
+ source = _PubSubPayloadSource('a_topic', 'a_subscription', 'a_label')
dd = DisplayData.create_from(source)
expected_items = [
- DisplayDataItemMatcher(
- 'subscription', 'projects/fakeprj/subscriptions/a_subscription'),
+ DisplayDataItemMatcher('topic', 'a_topic'),
+ DisplayDataItemMatcher('subscription', 'a_subscription'),
DisplayDataItemMatcher('id_label', 'a_label')]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
def test_display_data_no_subscription(self):
- source = _PubSubPayloadSource('projects/fakeprj/topics/a_topic')
+ source = _PubSubPayloadSource('a_topic')
dd = DisplayData.create_from(source)
expected_items = [
- DisplayDataItemMatcher('topic', 'projects/fakeprj/topics/a_topic')]
+ DisplayDataItemMatcher('topic', 'a_topic')]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
-@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
class TestPubSubSink(unittest.TestCase):
def test_display_data(self):
- sink = _PubSubPayloadSink('projects/fakeprj/topics/a_topic')
+ sink = _PubSubPayloadSink('a_topic')
dd = DisplayData.create_from(sink)
expected_items = [
- DisplayDataItemMatcher('topic', 'projects/fakeprj/topics/a_topic')]
+ DisplayDataItemMatcher('topic', 'a_topic')]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
+class TestEncodeDecodeUtf8String(unittest.TestCase):
+ def test_encode(self):
+ self.assertEqual(b'test_data', _encodeUtf8String('test_data'))
+
+ def test_decode(self):
+ self.assertEqual('test_data', _decodeUtf8String(b'test_data'))
+
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py
index d6f0e97..844cbc5 100644
--- a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py
+++ b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py
@@ -92,9 +92,9 @@ class BigqueryMatcher(BaseMatcher):
page_token = None
results = []
while True:
- for row in query.fetch_data(page_token=page_token):
- results.append(row)
- if results:
+ rows, _, page_token = query.fetch_data(page_token=page_token)
+ results.extend(rows)
+ if not page_token:
break
return results
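The loop above relies on the older paging contract in which fetch_data(page_token=...) returns a (rows, total_rows, page_token) tuple and a falsy token marks the last page. A self-contained sketch of that contract using a made-up stand-in object (not the real google-cloud-bigquery API):

    class _FakeQuery(object):
      # Stand-in for a paged query result: fetch_data(page_token=...) returns
      # (rows, total_rows, page_token), mirroring the contract used above.
      def __init__(self, pages):
        self._pages = pages

      def fetch_data(self, page_token=None):
        index = page_token or 0
        rows = self._pages[index]
        next_token = index + 1 if index + 1 < len(self._pages) else None
        return rows, sum(len(p) for p in self._pages), next_token

    query = _FakeQuery([[('a',)], [('b',), ('c',)]])
    page_token, results = None, []
    while True:
      rows, _, page_token = query.fetch_data(page_token=page_token)
      results.extend(rows)
      if not page_token:
        break
    assert results == [('a',), ('b',), ('c',)]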
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py
index 5b72285..f12293e 100644
--- a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py
+++ b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py
@@ -53,7 +53,7 @@ class BigqueryMatcherTest(unittest.TestCase):
matcher = bq_verifier.BigqueryMatcher(
'mock_project',
'mock_query',
- '59f9d6bdee30d67ea73b8aded121c3a0280f9cd8')
+ 'da39a3ee5e6b4b0d3255bfef95601890afd80709')
hc_assert_that(self._mock_result, matcher)
@patch.object(bigquery, 'Client')
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/io/range_trackers.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py
index bef77d4..9cb36e7 100644
--- a/sdks/python/apache_beam/io/range_trackers.py
+++ b/sdks/python/apache_beam/io/range_trackers.py
@@ -193,6 +193,136 @@ class OffsetRangeTracker(iobase.RangeTracker):
self._split_points_unclaimed_callback = callback
+class GroupedShuffleRangeTracker(iobase.RangeTracker):
+ """For internal use only; no backwards-compatibility guarantees.
+
+ A 'RangeTracker' for positions used by 'GroupedShuffleReader'.
+
+ These positions roughly correspond to hashes of keys. In case of hash
+ collisions, multiple groups can have the same position. In that case, the
+ first group at a particular position is considered a split point (because
+ it is the first to be returned when reading a position range starting at this
+ position), others are not.
+ """
+
+ def __init__(self, decoded_start_pos, decoded_stop_pos):
+ super(GroupedShuffleRangeTracker, self).__init__()
+ self._decoded_start_pos = decoded_start_pos
+ self._decoded_stop_pos = decoded_stop_pos
+ self._decoded_last_group_start = None
+ self._last_group_was_at_a_split_point = False
+ self._split_points_seen = 0
+ self._lock = threading.Lock()
+
+ def start_position(self):
+ return self._decoded_start_pos
+
+ def stop_position(self):
+ return self._decoded_stop_pos
+
+ def last_group_start(self):
+ return self._decoded_last_group_start
+
+ def _validate_decoded_group_start(self, decoded_group_start, split_point):
+ if self.start_position() and decoded_group_start < self.start_position():
+ raise ValueError('Trying to return record at %r which is before the'
+ ' starting position at %r' %
+ (decoded_group_start, self.start_position()))
+
+ if (self.last_group_start() and
+ decoded_group_start < self.last_group_start()):
+ raise ValueError('Trying to return group at %r which is before the'
+ ' last-returned group at %r' %
+ (decoded_group_start, self.last_group_start()))
+ if (split_point and self.last_group_start() and
+ self.last_group_start() == decoded_group_start):
+ raise ValueError('Trying to return a group at a split point with '
+ 'same position as the previous group: both at %r, '
+ 'last group was %sat a split point.' %
+ (decoded_group_start,
+ ('' if self._last_group_was_at_a_split_point
+ else 'not ')))
+ if not split_point:
+ if self.last_group_start() is None:
+ raise ValueError('The first group [at %r] must be at a split point' %
+ decoded_group_start)
+ if self.last_group_start() != decoded_group_start:
+ # This case is not a violation of general RangeTracker semantics, but it
+ # is contrary to how GroupingShuffleReader in particular works. Hitting
+ # it would mean it's behaving unexpectedly.
+ raise ValueError('Trying to return a group not at a split point, but '
+ 'with a different position than the previous group: '
+ 'last group was %r at %r, current at a %s split'
+ ' point.' %
+ (self.last_group_start()
+ , decoded_group_start
+ , ('' if self._last_group_was_at_a_split_point
+ else 'non-')))
+
+ def try_claim(self, decoded_group_start):
+ with self._lock:
+ self._validate_decoded_group_start(decoded_group_start, True)
+ if (self.stop_position()
+ and decoded_group_start >= self.stop_position()):
+ return False
+
+ self._decoded_last_group_start = decoded_group_start
+ self._last_group_was_at_a_split_point = True
+ self._split_points_seen += 1
+ return True
+
+ def set_current_position(self, decoded_group_start):
+ with self._lock:
+ self._validate_decoded_group_start(decoded_group_start, False)
+ self._decoded_last_group_start = decoded_group_start
+ self._last_group_was_at_a_split_point = False
+
+ def try_split(self, decoded_split_position):
+ with self._lock:
+ if self.last_group_start() is None:
+ logging.info('Refusing to split %r at %r: unstarted'
+ , self, decoded_split_position)
+ return
+
+ if decoded_split_position <= self.last_group_start():
+ logging.info('Refusing to split %r at %r: already past proposed split '
+ 'position'
+ , self, decoded_split_position)
+ return
+
+ if ((self.stop_position()
+ and decoded_split_position >= self.stop_position())
+ or (self.start_position()
+ and decoded_split_position <= self.start_position())):
+ logging.error('Refusing to split %r at %r: proposed split position out '
+ 'of range', self, decoded_split_position)
+ return
+
+ logging.debug('Agreeing to split %r at %r'
+ , self, decoded_split_position)
+ self._decoded_stop_pos = decoded_split_position
+
+ # Since GroupedShuffleRangeTracker cannot determine relative sizes of the
+ # two splits, returning 0.5 as the fraction below so that the framework
+ # assumes the splits to be of the same size.
+ return self._decoded_stop_pos, 0.5
+
+ def fraction_consumed(self):
+ # GroupingShuffle sources have special support on the service and the
+ # service will estimate progress from positions for us.
+ raise RuntimeError('GroupedShuffleRangeTracker does not measure fraction'
+ ' consumed due to positions being opaque strings'
+ ' that are interpreted by the service')
+
+ def split_points(self):
+ with self._lock:
+ splits_points_consumed = (
+ 0 if self._split_points_seen <= 1 else (self._split_points_seen - 1))
+
+ return (splits_points_consumed,
+ iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
+
+
class OrderedPositionRangeTracker(iobase.RangeTracker):
"""
An abstract base class for range trackers whose positions are comparable.
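A minimal sketch of the claim/split protocol the new GroupedShuffleRangeTracker enforces, using literal byte strings as opaque positions (illustrative only; in practice the positions come from the service and this class is internal):

    from apache_beam.io import range_trackers

    tracker = range_trackers.GroupedShuffleRangeTracker(
        b'\x01\x00\x00', b'\x05\x00\x00')

    # The first group at a new position must be claimed as a split point.
    assert tracker.try_claim(b'\x01\x02\x03')
    # Subsequent groups at the same position (hash collisions) are not
    # split points and are reported via set_current_position().
    tracker.set_current_position(b'\x01\x02\x03')
    # A later position can be claimed as a new split point...
    assert tracker.try_claim(b'\x03\x04\x05')
    # ...and the range can be split at a position past the last claimed one,
    # which tightens the stop position to the split point.
    assert tracker.try_split(b'\x04\x00\x00')
    # Positions at or beyond the new stop position are now rejected.
    assert not tracker.try_claim(b'\x04\x00\x00')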
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/io/range_trackers_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/range_trackers_test.py b/sdks/python/apache_beam/io/range_trackers_test.py
index 3e92663..edb6386 100644
--- a/sdks/python/apache_beam/io/range_trackers_test.py
+++ b/sdks/python/apache_beam/io/range_trackers_test.py
@@ -17,11 +17,14 @@
"""Unit tests for the range_trackers module."""
+import array
import copy
import logging
import math
import unittest
+
+from apache_beam.io import iobase
from apache_beam.io import range_trackers
@@ -186,6 +189,189 @@ class OffsetRangeTrackerTest(unittest.TestCase):
(3, 41))
+class GroupedShuffleRangeTrackerTest(unittest.TestCase):
+
+ def bytes_to_position(self, bytes_array):
+ return array.array('B', bytes_array).tostring()
+
+ def test_try_return_record_in_infinite_range(self):
+ tracker = range_trackers.GroupedShuffleRangeTracker('', '')
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([1, 2, 3])))
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([1, 2, 5])))
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([3, 6, 8, 10])))
+
+ def test_try_return_record_finite_range(self):
+ tracker = range_trackers.GroupedShuffleRangeTracker(
+ self.bytes_to_position([1, 0, 0]), self.bytes_to_position([5, 0, 0]))
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([1, 2, 3])))
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([1, 2, 5])))
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([3, 6, 8, 10])))
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([4, 255, 255, 255])))
+ # Should fail for positions that are lexicographically equal to or larger
+ # than the defined stop position.
+ self.assertFalse(copy.copy(tracker).try_claim(
+ self.bytes_to_position([5, 0, 0])))
+ self.assertFalse(copy.copy(tracker).try_claim(
+ self.bytes_to_position([5, 0, 1])))
+ self.assertFalse(copy.copy(tracker).try_claim(
+ self.bytes_to_position([6, 0, 0])))
+
+ def test_try_return_record_with_non_split_point(self):
+ tracker = range_trackers.GroupedShuffleRangeTracker(
+ self.bytes_to_position([1, 0, 0]), self.bytes_to_position([5, 0, 0]))
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([1, 2, 3])))
+ tracker.set_current_position(self.bytes_to_position([1, 2, 3]))
+ tracker.set_current_position(self.bytes_to_position([1, 2, 3]))
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([1, 2, 5])))
+ tracker.set_current_position(self.bytes_to_position([1, 2, 5]))
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([3, 6, 8, 10])))
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([4, 255, 255, 255])))
+
+ def test_first_record_non_split_point(self):
+ tracker = range_trackers.GroupedShuffleRangeTracker(
+ self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0]))
+ with self.assertRaises(ValueError):
+ tracker.set_current_position(self.bytes_to_position([3, 4, 5]))
+
+ def test_non_split_point_record_with_different_position(self):
+ tracker = range_trackers.GroupedShuffleRangeTracker(
+ self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0]))
+ self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 4, 5])))
+ with self.assertRaises(ValueError):
+ tracker.set_current_position(self.bytes_to_position([3, 4, 6]))
+
+ def test_try_return_record_before_start(self):
+ tracker = range_trackers.GroupedShuffleRangeTracker(
+ self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0]))
+ with self.assertRaises(ValueError):
+ tracker.try_claim(self.bytes_to_position([1, 2, 3]))
+
+ def test_try_return_non_monotonic(self):
+ tracker = range_trackers.GroupedShuffleRangeTracker(
+ self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0]))
+ self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 4, 5])))
+ self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 4, 6])))
+ with self.assertRaises(ValueError):
+ tracker.try_claim(self.bytes_to_position([3, 2, 1]))
+
+ def test_try_return_identical_positions(self):
+ tracker = range_trackers.GroupedShuffleRangeTracker(
+ self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0]))
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([3, 4, 5])))
+ with self.assertRaises(ValueError):
+ tracker.try_claim(self.bytes_to_position([3, 4, 5]))
+
+ def test_try_split_at_position_infinite_range(self):
+ tracker = range_trackers.GroupedShuffleRangeTracker('', '')
+ # Should fail before first record is returned.
+ self.assertFalse(tracker.try_split(
+ self.bytes_to_position([3, 4, 5, 6])))
+
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([1, 2, 3])))
+
+ # Should now succeed.
+ self.assertIsNotNone(tracker.try_split(
+ self.bytes_to_position([3, 4, 5, 6])))
+ # Should not split at same or larger position.
+ self.assertIsNone(tracker.try_split(
+ self.bytes_to_position([3, 4, 5, 6])))
+ self.assertIsNone(tracker.try_split(
+ self.bytes_to_position([3, 4, 5, 6, 7])))
+ self.assertIsNone(tracker.try_split(
+ self.bytes_to_position([4, 5, 6, 7])))
+
+ # Should split at smaller position.
+ self.assertIsNotNone(tracker.try_split(
+ self.bytes_to_position([3, 2, 1])))
+
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([2, 3, 4])))
+
+ # Should not split at a position we're already past.
+ self.assertIsNone(tracker.try_split(
+ self.bytes_to_position([2, 3, 4])))
+ self.assertIsNone(tracker.try_split(
+ self.bytes_to_position([2, 3, 3])))
+
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([3, 2, 0])))
+ self.assertFalse(tracker.try_claim(
+ self.bytes_to_position([3, 2, 1])))
+
+ def test_try_test_split_at_position_finite_range(self):
+ tracker = range_trackers.GroupedShuffleRangeTracker(
+ self.bytes_to_position([0, 0, 0]),
+ self.bytes_to_position([10, 20, 30]))
+ # Should fail before first record is returned.
+ self.assertFalse(tracker.try_split(
+ self.bytes_to_position([0, 0, 0])))
+ self.assertFalse(tracker.try_split(
+ self.bytes_to_position([3, 4, 5, 6])))
+
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([1, 2, 3])))
+
+ # Should now succeed.
+ self.assertTrue(tracker.try_split(
+ self.bytes_to_position([3, 4, 5, 6])))
+ # Should not split at same or larger position.
+ self.assertFalse(tracker.try_split(
+ self.bytes_to_position([3, 4, 5, 6])))
+ self.assertFalse(tracker.try_split(
+ self.bytes_to_position([3, 4, 5, 6, 7])))
+ self.assertFalse(tracker.try_split(
+ self.bytes_to_position([4, 5, 6, 7])))
+
+ # Should split at smaller position.
+ self.assertTrue(tracker.try_split(
+ self.bytes_to_position([3, 2, 1])))
+ # But not at a position at or before last returned record.
+ self.assertFalse(tracker.try_split(
+ self.bytes_to_position([1, 2, 3])))
+
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([2, 3, 4])))
+ self.assertTrue(tracker.try_claim(
+ self.bytes_to_position([3, 2, 0])))
+ self.assertFalse(tracker.try_claim(
+ self.bytes_to_position([3, 2, 1])))
+
+ def test_split_points(self):
+ tracker = range_trackers.GroupedShuffleRangeTracker(
+ self.bytes_to_position([1, 0, 0]),
+ self.bytes_to_position([5, 0, 0]))
+ self.assertEqual(tracker.split_points(),
+ (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+ self.assertTrue(tracker.try_claim(self.bytes_to_position([1, 2, 3])))
+ self.assertEqual(tracker.split_points(),
+ (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+ self.assertTrue(tracker.try_claim(self.bytes_to_position([1, 2, 5])))
+ self.assertEqual(tracker.split_points(),
+ (1, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+ self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 6, 8])))
+ self.assertEqual(tracker.split_points(),
+ (2, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+ self.assertTrue(tracker.try_claim(self.bytes_to_position([4, 255, 255])))
+ self.assertEqual(tracker.split_points(),
+ (3, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+ self.assertFalse(tracker.try_claim(self.bytes_to_position([5, 1, 0])))
+ self.assertEqual(tracker.split_points(),
+ (3, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+
+
class OrderedPositionRangeTrackerTest(unittest.TestCase):
class DoubleRangeTracker(range_trackers.OrderedPositionRangeTracker):
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/options/pipeline_options.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index ea996a3..daef3a7 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -18,6 +18,7 @@
"""Pipeline options obtained from command line parsing."""
import argparse
+import warnings
from apache_beam.transforms.display import HasDisplayData
from apache_beam.options.value_provider import StaticValueProvider
@@ -278,6 +279,14 @@ class StandardOptions(PipelineOptions):
action='store_true',
help='Whether to enable streaming mode.')
+ # TODO(BEAM-1265): Remove this warning, once at least one runner supports
+ # streaming pipelines.
+ def validate(self, validator):
+ errors = []
+ if self.view_as(StandardOptions).streaming:
+ warnings.warn('Streaming pipelines are not supported.')
+ return errors
+
class TypeOptions(PipelineOptions):
@@ -465,14 +474,7 @@ class WorkerOptions(PipelineOptions):
parser.add_argument(
'--use_public_ips',
default=None,
- action='store_true',
- help='Whether to assign public IP addresses to the worker VMs.')
- parser.add_argument(
- '--no_use_public_ips',
- dest='use_public_ips',
- default=None,
- action='store_false',
- help='Whether to assign only private IP addresses to the worker VMs.')
+ help='Whether to assign public IP addresses to the worker machines.')
def validate(self, validator):
errors = []
@@ -552,18 +554,6 @@ class SetupOptions(PipelineOptions):
'worker will install the resulting package before running any custom '
'code.'))
parser.add_argument(
- '--beam_plugin', '--beam_plugin',
- dest='beam_plugins',
- action='append',
- default=None,
- help=
- ('Bootstrap the python process before executing any code by importing '
- 'all the plugins used in the pipeline. Please pass a comma separated'
- 'list of import paths to be included. This is currently an '
- 'experimental flag and provides no stability. Multiple '
- '--beam_plugin options can be specified if more than one plugin '
- 'is needed.'))
- parser.add_argument(
'--save_main_session',
default=False,
action='store_true',
@@ -609,11 +599,6 @@ class TestOptions(PipelineOptions):
help=('Verify state/output of e2e test pipeline. This is pickled '
'version of the matcher which should extends '
'hamcrest.core.base_matcher.BaseMatcher.'))
- parser.add_argument(
- '--dry_run',
- default=False,
- help=('Used in unit testing runners without submitting the '
- 'actual job.'))
def validate(self, validator):
errors = []
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/options/pipeline_options_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py
index f4dd4d9..1a644b4 100644
--- a/sdks/python/apache_beam/options/pipeline_options_test.py
+++ b/sdks/python/apache_beam/options/pipeline_options_test.py
@@ -192,52 +192,47 @@ class PipelineOptionsTest(unittest.TestCase):
options = PipelineOptions(['--redefined_flag'])
self.assertTrue(options.get_all_options()['redefined_flag'])
- # TODO(BEAM-1319): Require unique names only within a test.
- # For now, <file name acronym>_vp_arg<number> will be the convention
- # to name value-provider arguments in tests, as opposed to
- # <file name acronym>_non_vp_arg<number> for non-value-provider arguments.
- # The number will grow per file as tests are added.
def test_value_provider_options(self):
class UserOptions(PipelineOptions):
@classmethod
def _add_argparse_args(cls, parser):
parser.add_value_provider_argument(
- '--pot_vp_arg1',
+ '--vp_arg',
help='This flag is a value provider')
parser.add_value_provider_argument(
- '--pot_vp_arg2',
+ '--vp_arg2',
default=1,
type=int)
parser.add_argument(
- '--pot_non_vp_arg1',
+ '--non_vp_arg',
default=1,
type=int
)
# Provide values: if not provided, the option becomes of the type runtime vp
- options = UserOptions(['--pot_vp_arg1', 'hello'])
- self.assertIsInstance(options.pot_vp_arg1, StaticValueProvider)
- self.assertIsInstance(options.pot_vp_arg2, RuntimeValueProvider)
- self.assertIsInstance(options.pot_non_vp_arg1, int)
+ options = UserOptions(['--vp_arg', 'hello'])
+ self.assertIsInstance(options.vp_arg, StaticValueProvider)
+ self.assertIsInstance(options.vp_arg2, RuntimeValueProvider)
+ self.assertIsInstance(options.non_vp_arg, int)
# Values can be overwritten
- options = UserOptions(pot_vp_arg1=5,
- pot_vp_arg2=StaticValueProvider(value_type=str,
- value='bye'),
- pot_non_vp_arg1=RuntimeValueProvider(
+ options = UserOptions(vp_arg=5,
+ vp_arg2=StaticValueProvider(value_type=str,
+ value='bye'),
+ non_vp_arg=RuntimeValueProvider(
option_name='foo',
value_type=int,
default_value=10))
- self.assertEqual(options.pot_vp_arg1, 5)
- self.assertTrue(options.pot_vp_arg2.is_accessible(),
- '%s is not accessible' % options.pot_vp_arg2)
- self.assertEqual(options.pot_vp_arg2.get(), 'bye')
- self.assertFalse(options.pot_non_vp_arg1.is_accessible())
+ self.assertEqual(options.vp_arg, 5)
+ self.assertTrue(options.vp_arg2.is_accessible(),
+ '%s is not accessible' % options.vp_arg2)
+ self.assertEqual(options.vp_arg2.get(), 'bye')
+ self.assertFalse(options.non_vp_arg.is_accessible())
with self.assertRaises(RuntimeError):
- options.pot_non_vp_arg1.get()
+ options.non_vp_arg.get()
if __name__ == '__main__':
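The renamed flags above exercise the value-provider plumbing; a compact sketch of the pattern these tests cover (the option names here are arbitrary):

    from apache_beam.options.pipeline_options import PipelineOptions
    from apache_beam.options.value_provider import StaticValueProvider

    class MyOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        # Deferred option: if unset it becomes a RuntimeValueProvider.
        parser.add_value_provider_argument('--vp_arg', help='a value provider flag')
        # Ordinary option: parsed immediately to a plain int.
        parser.add_argument('--non_vp_arg', default=1, type=int)

    options = MyOptions(['--vp_arg', 'hello'])
    assert isinstance(options.vp_arg, StaticValueProvider)
    assert options.vp_arg.get() == 'hello'
    assert options.non_vp_arg == 1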
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/options/value_provider_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/options/value_provider_test.py b/sdks/python/apache_beam/options/value_provider_test.py
index 17e9590..3a45e8b 100644
--- a/sdks/python/apache_beam/options/value_provider_test.py
+++ b/sdks/python/apache_beam/options/value_provider_test.py
@@ -24,77 +24,72 @@ from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.options.value_provider import StaticValueProvider
-# TODO(BEAM-1319): Require unique names only within a test.
-# For now, <file name acronym>_vp_arg<number> will be the convention
-# to name value-provider arguments in tests, as opposed to
-# <file name acronym>_non_vp_arg<number> for non-value-provider arguments.
-# The number will grow per file as tests are added.
class ValueProviderTests(unittest.TestCase):
def test_static_value_provider_keyword_argument(self):
class UserDefinedOptions(PipelineOptions):
@classmethod
def _add_argparse_args(cls, parser):
parser.add_value_provider_argument(
- '--vpt_vp_arg1',
+ '--vp_arg',
help='This keyword argument is a value provider',
default='some value')
- options = UserDefinedOptions(['--vpt_vp_arg1', 'abc'])
- self.assertTrue(isinstance(options.vpt_vp_arg1, StaticValueProvider))
- self.assertTrue(options.vpt_vp_arg1.is_accessible())
- self.assertEqual(options.vpt_vp_arg1.get(), 'abc')
+ options = UserDefinedOptions(['--vp_arg', 'abc'])
+ self.assertTrue(isinstance(options.vp_arg, StaticValueProvider))
+ self.assertTrue(options.vp_arg.is_accessible())
+ self.assertEqual(options.vp_arg.get(), 'abc')
def test_runtime_value_provider_keyword_argument(self):
class UserDefinedOptions(PipelineOptions):
@classmethod
def _add_argparse_args(cls, parser):
parser.add_value_provider_argument(
- '--vpt_vp_arg2',
+ '--vp_arg',
help='This keyword argument is a value provider')
options = UserDefinedOptions()
- self.assertTrue(isinstance(options.vpt_vp_arg2, RuntimeValueProvider))
- self.assertFalse(options.vpt_vp_arg2.is_accessible())
+ self.assertTrue(isinstance(options.vp_arg, RuntimeValueProvider))
+ self.assertFalse(options.vp_arg.is_accessible())
with self.assertRaises(RuntimeError):
- options.vpt_vp_arg2.get()
+ options.vp_arg.get()
def test_static_value_provider_positional_argument(self):
class UserDefinedOptions(PipelineOptions):
@classmethod
def _add_argparse_args(cls, parser):
parser.add_value_provider_argument(
- 'vpt_vp_arg3',
+ 'vp_pos_arg',
help='This positional argument is a value provider',
default='some value')
options = UserDefinedOptions(['abc'])
- self.assertTrue(isinstance(options.vpt_vp_arg3, StaticValueProvider))
- self.assertTrue(options.vpt_vp_arg3.is_accessible())
- self.assertEqual(options.vpt_vp_arg3.get(), 'abc')
+ self.assertTrue(isinstance(options.vp_pos_arg, StaticValueProvider))
+ self.assertTrue(options.vp_pos_arg.is_accessible())
+ self.assertEqual(options.vp_pos_arg.get(), 'abc')
def test_runtime_value_provider_positional_argument(self):
class UserDefinedOptions(PipelineOptions):
@classmethod
def _add_argparse_args(cls, parser):
parser.add_value_provider_argument(
- 'vpt_vp_arg4',
+ 'vp_pos_arg',
help='This positional argument is a value provider')
options = UserDefinedOptions([])
- self.assertTrue(isinstance(options.vpt_vp_arg4, RuntimeValueProvider))
- self.assertFalse(options.vpt_vp_arg4.is_accessible())
+ self.assertTrue(isinstance(options.vp_pos_arg, RuntimeValueProvider))
+ self.assertFalse(options.vp_pos_arg.is_accessible())
with self.assertRaises(RuntimeError):
- options.vpt_vp_arg4.get()
+ options.vp_pos_arg.get()
def test_static_value_provider_type_cast(self):
class UserDefinedOptions(PipelineOptions):
@classmethod
def _add_argparse_args(cls, parser):
parser.add_value_provider_argument(
- '--vpt_vp_arg5',
+ '--vp_arg',
type=int,
help='This flag is a value provider')
- options = UserDefinedOptions(['--vpt_vp_arg5', '123'])
- self.assertTrue(isinstance(options.vpt_vp_arg5, StaticValueProvider))
- self.assertTrue(options.vpt_vp_arg5.is_accessible())
- self.assertEqual(options.vpt_vp_arg5.get(), 123)
+ options = UserDefinedOptions(['--vp_arg', '123'])
+ self.assertTrue(isinstance(options.vp_arg, StaticValueProvider))
+ self.assertTrue(options.vp_arg.is_accessible())
+ self.assertEqual(options.vp_arg.get(), 123)
def test_set_runtime_option(self):
# define ValueProvider ptions, with and without default values
@@ -102,25 +97,25 @@ class ValueProviderTests(unittest.TestCase):
@classmethod
def _add_argparse_args(cls, parser):
parser.add_value_provider_argument(
- '--vpt_vp_arg6',
+ '--vp_arg',
help='This keyword argument is a value provider') # set at runtime
parser.add_value_provider_argument( # not set, had default int
- '-v', '--vpt_vp_arg7', # with short form
+ '-v', '--vp_arg2', # with short form
default=123,
type=int)
parser.add_value_provider_argument( # not set, had default str
- '--vpt_vp-arg8', # with dash in name
+ '--vp-arg3', # with dash in name
default='123',
type=str)
parser.add_value_provider_argument( # not set and no default
- '--vpt_vp_arg9',
+ '--vp_arg4',
type=float)
parser.add_value_provider_argument( # positional argument set
- 'vpt_vp_arg10', # default & runtime ignored
+ 'vp_pos_arg', # default & runtime ignored
help='This positional argument is a value provider',
type=float,
default=5.4)
@@ -128,23 +123,23 @@ class ValueProviderTests(unittest.TestCase):
# provide values at graph-construction time
# (options not provided here become of the type RuntimeValueProvider)
options = UserDefinedOptions1(['1.2'])
- self.assertFalse(options.vpt_vp_arg6.is_accessible())
- self.assertFalse(options.vpt_vp_arg7.is_accessible())
- self.assertFalse(options.vpt_vp_arg8.is_accessible())
- self.assertFalse(options.vpt_vp_arg9.is_accessible())
- self.assertTrue(options.vpt_vp_arg10.is_accessible())
+ self.assertFalse(options.vp_arg.is_accessible())
+ self.assertFalse(options.vp_arg2.is_accessible())
+ self.assertFalse(options.vp_arg3.is_accessible())
+ self.assertFalse(options.vp_arg4.is_accessible())
+ self.assertTrue(options.vp_pos_arg.is_accessible())
# provide values at job-execution time
# (options not provided here will use their default, if they have one)
- RuntimeValueProvider.set_runtime_options({'vpt_vp_arg6': 'abc',
- 'vpt_vp_arg10':'3.2'})
- self.assertTrue(options.vpt_vp_arg6.is_accessible())
- self.assertEqual(options.vpt_vp_arg6.get(), 'abc')
- self.assertTrue(options.vpt_vp_arg7.is_accessible())
- self.assertEqual(options.vpt_vp_arg7.get(), 123)
- self.assertTrue(options.vpt_vp_arg8.is_accessible())
- self.assertEqual(options.vpt_vp_arg8.get(), '123')
- self.assertTrue(options.vpt_vp_arg9.is_accessible())
- self.assertIsNone(options.vpt_vp_arg9.get())
- self.assertTrue(options.vpt_vp_arg10.is_accessible())
- self.assertEqual(options.vpt_vp_arg10.get(), 1.2)
+ RuntimeValueProvider.set_runtime_options({'vp_arg': 'abc',
+ 'vp_pos_arg':'3.2'})
+ self.assertTrue(options.vp_arg.is_accessible())
+ self.assertEqual(options.vp_arg.get(), 'abc')
+ self.assertTrue(options.vp_arg2.is_accessible())
+ self.assertEqual(options.vp_arg2.get(), 123)
+ self.assertTrue(options.vp_arg3.is_accessible())
+ self.assertEqual(options.vp_arg3.get(), '123')
+ self.assertTrue(options.vp_arg4.is_accessible())
+ self.assertIsNone(options.vp_arg4.get())
+ self.assertTrue(options.vp_pos_arg.is_accessible())
+ self.assertEqual(options.vp_pos_arg.get(), 1.2)
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/pipeline.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index fe36d85..9093abf 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -45,7 +45,6 @@ Typical usage:
from __future__ import absolute_import
-import abc
import collections
import logging
import os
@@ -54,7 +53,6 @@ import tempfile
from apache_beam import pvalue
from apache_beam.internal import pickler
-from apache_beam.pvalue import PCollection
from apache_beam.runners import create_runner
from apache_beam.runners import PipelineRunner
from apache_beam.transforms import ptransform
@@ -159,157 +157,6 @@ class Pipeline(object):
"""Returns the root transform of the transform stack."""
return self.transforms_stack[0]
- def _remove_labels_recursively(self, applied_transform):
- for part in applied_transform.parts:
- if part.full_label in self.applied_labels:
- self.applied_labels.remove(part.full_label)
- if part.parts:
- for part2 in part.parts:
- self._remove_labels_recursively(part2)
-
- def _replace(self, override):
-
- assert isinstance(override, PTransformOverride)
- matcher = override.get_matcher()
-
- output_map = {}
- output_replacements = {}
- input_replacements = {}
-
- class TransformUpdater(PipelineVisitor): # pylint: disable=used-before-assignment
- """"A visitor that replaces the matching PTransforms."""
-
- def __init__(self, pipeline):
- self.pipeline = pipeline
-
- def _replace_if_needed(self, transform_node):
- if matcher(transform_node):
- replacement_transform = override.get_replacement_transform(
- transform_node.transform)
- inputs = transform_node.inputs
- # TODO: Support replacing PTransforms with multiple inputs.
- if len(inputs) > 1:
- raise NotImplementedError(
- 'PTransform overriding is only supported for PTransforms that '
- 'have a single input. Tried to replace input of '
- 'AppliedPTransform %r that has %d inputs',
- transform_node, len(inputs))
- transform_node.transform = replacement_transform
- self.pipeline.transforms_stack.append(transform_node)
-
- # Keeping the same label for the replaced node but recursively
- # removing labels of child transforms since they will be replaced
- # during the expand below.
- self.pipeline._remove_labels_recursively(transform_node)
-
- new_output = replacement_transform.expand(inputs[0])
- if new_output.producer is None:
- # When current transform is a primitive, we set the producer here.
- new_output.producer = transform_node
-
- # We only support replacing transforms with a single output with
- # another transform that produces a single output.
- # TODO: Support replacing PTransforms with multiple outputs.
- if (len(transform_node.outputs) > 1 or
- not isinstance(transform_node.outputs[None], PCollection) or
- not isinstance(new_output, PCollection)):
- raise NotImplementedError(
- 'PTransform overriding is only supported for PTransforms that '
- 'have a single output. Tried to replace output of '
- 'AppliedPTransform %r with %r.'
- , transform_node, new_output)
-
- # Recording updated outputs. This cannot be done in the same visitor
- # since if we dynamically update output type here, we'll run into
- # errors when visiting child nodes.
- output_map[transform_node.outputs[None]] = new_output
-
- self.pipeline.transforms_stack.pop()
-
- def enter_composite_transform(self, transform_node):
- self._replace_if_needed(transform_node)
-
- def visit_transform(self, transform_node):
- self._replace_if_needed(transform_node)
-
- self.visit(TransformUpdater(self))
-
- # Adjusting inputs and outputs
- class InputOutputUpdater(PipelineVisitor): # pylint: disable=used-before-assignment
- """"A visitor that records input and output values to be replaced.
-
- Input and output values that should be updated are recorded in maps
- input_replacements and output_replacements respectively.
-
- We cannot update input and output values while visiting since that results
- in validation errors.
- """
-
- def __init__(self, pipeline):
- self.pipeline = pipeline
-
- def enter_composite_transform(self, transform_node):
- self.visit_transform(transform_node)
-
- def visit_transform(self, transform_node):
- if (None in transform_node.outputs and
- transform_node.outputs[None] in output_map):
- output_replacements[transform_node] = (
- output_map[transform_node.outputs[None]])
-
- replace_input = False
- for input in transform_node.inputs:
- if input in output_map:
- replace_input = True
- break
-
- if replace_input:
- new_input = [
- input if not input in output_map else output_map[input]
- for input in transform_node.inputs]
- input_replacements[transform_node] = new_input
-
- self.visit(InputOutputUpdater(self))
-
- for transform in output_replacements:
- transform.replace_output(output_replacements[transform])
-
- for transform in input_replacements:
- transform.inputs = input_replacements[transform]
-
- def _check_replacement(self, override):
- matcher = override.get_matcher()
-
- class ReplacementValidator(PipelineVisitor):
- def visit_transform(self, transform_node):
- if matcher(transform_node):
- raise RuntimeError('Transform node %r was not replaced as expected.',
- transform_node)
-
- self.visit(ReplacementValidator())
-
- def replace_all(self, replacements):
- """ Dynamically replaces PTransforms in the currently populated hierarchy.
-
- Currently this only works for replacements where input and output types
- are exactly the same.
- TODO: Update this to also work for transform overrides where input and
- output types are different.
-
- Args:
- replacements a list of PTransformOverride objects.
- """
- for override in replacements:
- assert isinstance(override, PTransformOverride)
- self._replace(override)
-
- # Checking if the PTransforms have been successfully replaced. This will
- # result in a failure if a PTransform that was replaced in a given override
- # gets re-added in a subsequent override. This is not allowed and ordering
- # of PTransformOverride objects in 'replacements' is important.
- for override in replacements:
- self._check_replacement(override)
-
def run(self, test_runner_api=True):
"""Runs the pipeline. Returns whatever our runner returns after running."""
@@ -466,20 +313,10 @@ class Pipeline(object):
self.transforms_stack.pop()
return pvalueish_result
- def __reduce__(self):
- # Some transforms contain a reference to their enclosing pipeline,
- # which in turn reference all other transforms (resulting in quadratic
- # time/space to pickle each transform individually). As we don't
- # require pickled pipelines to be executable, break the chain here.
- return str, ('Pickled pipeline stub.',)
-
def _verify_runner_api_compatible(self):
class Visitor(PipelineVisitor): # pylint: disable=used-before-assignment
ok = True # Really a nonlocal.
- def enter_composite_transform(self, transform_node):
- self.visit_transform(transform_node)
-
def visit_transform(self, transform_node):
if transform_node.side_inputs:
# No side inputs (yet).
@@ -502,7 +339,7 @@ class Pipeline(object):
def to_runner_api(self):
"""For internal use only; no backwards-compatibility guarantees."""
from apache_beam.runners import pipeline_context
- from apache_beam.portability.api import beam_runner_api_pb2
+ from apache_beam.runners.api import beam_runner_api_pb2
context = pipeline_context.PipelineContext()
# Mutates context; placing inline would force dependence on
# argument evaluation order.
@@ -525,18 +362,7 @@ class Pipeline(object):
p.applied_labels = set([
t.unique_name for t in proto.components.transforms.values()])
for id in proto.components.pcollections:
- pcollection = context.pcollections.get_by_id(id)
- pcollection.pipeline = p
-
- # Inject PBegin input where necessary.
- from apache_beam.io.iobase import Read
- from apache_beam.transforms.core import Create
- has_pbegin = [Read, Create]
- for id in proto.components.transforms:
- transform = context.transforms.get_by_id(id)
- if not transform.inputs and transform.transform.__class__ in has_pbegin:
- transform.inputs = (pvalue.PBegin(p),)
-
+ context.pcollections.get_by_id(id).pipeline = p
return p
@@ -558,7 +384,7 @@ class PipelineVisitor(object):
pass
def visit_transform(self, transform_node):
- """Callback for visiting a transform leaf node in the pipeline DAG."""
+ """Callback for visiting a transform node in the pipeline DAG."""
pass
def enter_composite_transform(self, transform_node):
@@ -615,20 +441,6 @@ class AppliedPTransform(object):
for side_input in self.side_inputs:
real_producer(side_input.pvalue).refcounts[side_input.pvalue.tag] += 1
- def replace_output(self, output, tag=None):
- """Replaces the output defined by the given tag with the given output.
-
- Args:
- output: replacement output
- tag: tag of the output to be replaced.
- """
- if isinstance(output, pvalue.DoOutputsTuple):
- self.replace_output(output[output._main_tag])
- elif isinstance(output, pvalue.PValue):
- self.outputs[tag] = output
- else:
- raise TypeError("Unexpected output type: %s" % output)
-
def add_output(self, output, tag=None):
if isinstance(output, pvalue.DoOutputsTuple):
self.add_output(output[output._main_tag])
@@ -713,7 +525,7 @@ class AppliedPTransform(object):
if isinstance(output, pvalue.PCollection)}
def to_runner_api(self, context):
- from apache_beam.portability.api import beam_runner_api_pb2
+ from apache_beam.runners.api import beam_runner_api_pb2
def transform_to_runner_api(transform, context):
if transform is None:
@@ -752,37 +564,3 @@ class AppliedPTransform(object):
pc.tag = tag
result.update_input_refcounts()
return result
-
-
-class PTransformOverride(object):
- """For internal use only; no backwards-compatibility guarantees.
-
- Gives a matcher and replacements for matching PTransforms.
-
- TODO: Update this to support cases where input and/our output types are
- different.
- """
- __metaclass__ = abc.ABCMeta
-
- @abc.abstractmethod
- def get_matcher(self):
- """Gives a matcher that will be used to to perform this override.
-
- Returns:
- a callable that takes an AppliedPTransform as a parameter and returns a
- boolean as a result.
- """
- raise NotImplementedError
-
- @abc.abstractmethod
- def get_replacement_transform(self, ptransform):
- """Provides a runner specific override for a given PTransform.
-
- Args:
- ptransform: PTransform to be replaced.
- Returns:
- A PTransform that will be the replacement for the PTransform given as an
- argument.
- """
- # Returns a PTransformReplacement
- raise NotImplementedError
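Both the removed override machinery and the remaining validation code use the Pipeline.visit() protocol defined by PipelineVisitor in this file; a small read-only sketch that counts leaf transforms (illustrative only):

    import apache_beam as beam
    from apache_beam.pipeline import PipelineVisitor

    class CountingVisitor(PipelineVisitor):
      def __init__(self):
        self.count = 0

      def visit_transform(self, transform_node):
        # Called once for every leaf transform in the pipeline DAG.
        self.count += 1

    p = beam.Pipeline()
    p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x + 1)  # pylint: disable=expression-not-assigned
    visitor = CountingVisitor()
    p.visit(visitor)
    print(visitor.count)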
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/pipeline_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index aad0143..e0775d1 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -28,11 +28,9 @@ import apache_beam as beam
from apache_beam.io import Read
from apache_beam.metrics import Metrics
from apache_beam.pipeline import Pipeline
-from apache_beam.pipeline import PTransformOverride
from apache_beam.pipeline import PipelineOptions
from apache_beam.pipeline import PipelineVisitor
from apache_beam.pvalue import AsSingleton
-from apache_beam.runners import DirectRunner
from apache_beam.runners.dataflow.native_io.iobase import NativeSource
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
@@ -77,18 +75,6 @@ class FakeSource(NativeSource):
return FakeSource._Reader(self._vals)
-class DoubleParDo(beam.PTransform):
- def expand(self, input):
- return input | 'Inner' >> beam.Map(lambda a: a * 2)
-
-
-class TripleParDo(beam.PTransform):
- def expand(self, input):
- # Keeping labels the same intentionally to make sure that there is no label
- # conflict due to replacement.
- return input | 'Inner' >> beam.Map(lambda a: a * 3)
-
-
class PipelineTest(unittest.TestCase):
@staticmethod
@@ -299,27 +285,6 @@ class PipelineTest(unittest.TestCase):
# p = Pipeline('EagerRunner')
# self.assertEqual([1, 4, 9], p | Create([1, 2, 3]) | Map(lambda x: x*x))
- def test_ptransform_overrides(self):
-
- def my_par_do_matcher(applied_ptransform):
- return isinstance(applied_ptransform.transform, DoubleParDo)
-
- class MyParDoOverride(PTransformOverride):
-
- def get_matcher(self):
- return my_par_do_matcher
-
- def get_replacement_transform(self, ptransform):
- if isinstance(ptransform, DoubleParDo):
- return TripleParDo()
- raise ValueError('Unsupported type of transform: %r', ptransform)
-
- # Using following private variable for testing.
- DirectRunner._PTRANSFORM_OVERRIDES.append(MyParDoOverride())
- with Pipeline() as p:
- pcoll = p | beam.Create([1, 2, 3]) | 'Multiply' >> DoubleParDo()
- assert_that(pcoll, equal_to([3, 6, 9]))
-
class DoFnTest(unittest.TestCase):
@@ -480,24 +445,6 @@ class RunnerApiTest(unittest.TestCase):
p2 = Pipeline.from_runner_api(proto, p.runner, p._options)
p2.run()
- def test_pickling(self):
- class MyPTransform(beam.PTransform):
- pickle_count = [0]
-
- def expand(self, p):
- self.p = p
- return p | beam.Create([None])
-
- def __reduce__(self):
- self.pickle_count[0] += 1
- return str, ()
-
- p = beam.Pipeline()
- for k in range(20):
- p | 'Iter%s' % k >> MyPTransform() # pylint: disable=expression-not-assigned
- p.to_runner_api()
- self.assertEqual(MyPTransform.pickle_count[0], 20)
-
if __name__ == '__main__':
logging.getLogger().setLevel(logging.DEBUG)
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/portability/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/portability/__init__.py b/sdks/python/apache_beam/portability/__init__.py
deleted file mode 100644
index 0bce5d6..0000000
--- a/sdks/python/apache_beam/portability/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""For internal use only; no backwards-compatibility guarantees."""
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/portability/api/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/portability/api/__init__.py b/sdks/python/apache_beam/portability/api/__init__.py
deleted file mode 100644
index 2750859..0000000
--- a/sdks/python/apache_beam/portability/api/__init__.py
+++ /dev/null
@@ -1,21 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""For internal use only; no backwards-compatibility guarantees.
-
-Automatically generated when running setup.py sdist or build[_py].
-"""
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/pvalue.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py
index 34a483e..7385e82 100644
--- a/sdks/python/apache_beam/pvalue.py
+++ b/sdks/python/apache_beam/pvalue.py
@@ -128,7 +128,7 @@ class PCollection(PValue):
return _InvalidUnpickledPCollection, ()
def to_runner_api(self, context):
- from apache_beam.portability.api import beam_runner_api_pb2
+ from apache_beam.runners.api import beam_runner_api_pb2
from apache_beam.internal import pickler
return beam_runner_api_pb2.PCollection(
unique_name='%d%s.%s' % (
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/api/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/api/__init__.py b/sdks/python/apache_beam/runners/api/__init__.py
new file mode 100644
index 0000000..2750859
--- /dev/null
+++ b/sdks/python/apache_beam/runners/api/__init__.py
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""For internal use only; no backwards-compatibility guarantees.
+
+Automatically generated when running setup.py sdist or build[_py].
+"""
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index 059e139..3fc8983 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -39,17 +39,13 @@ from apache_beam.runners.dataflow.internal import names
from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api
from apache_beam.runners.dataflow.internal.names import PropertyNames
from apache_beam.runners.dataflow.internal.names import TransformNames
-from apache_beam.runners.dataflow.ptransform_overrides import CreatePTransformOverride
from apache_beam.runners.runner import PValueCache
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineRunner
from apache_beam.runners.runner import PipelineState
from apache_beam.transforms.display import DisplayData
from apache_beam.typehints import typehints
-from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import StandardOptions
-from apache_beam.options.pipeline_options import TestOptions
-from apache_beam.utils.plugin import BeamPlugin
__all__ = ['DataflowRunner']
@@ -65,15 +61,11 @@ class DataflowRunner(PipelineRunner):
if blocking is set to False.
"""
- # A list of PTransformOverride objects to be applied before running a pipeline
- # using DataflowRunner.
- # Currently this only works for overrides where the input and output types do
- # not change.
- # For internal SDK use only. This should not be updated by Beam pipeline
- # authors.
- _PTRANSFORM_OVERRIDES = [
- CreatePTransformOverride(),
- ]
+ # Environment version information. It is passed to the service during a
+ # job submission and is used by the service to establish what features
+ # are expected by the workers.
+ BATCH_ENVIRONMENT_MAJOR_VERSION = '6'
+ STREAMING_ENVIRONMENT_MAJOR_VERSION = '1'
def __init__(self, cache=None):
# Cache of CloudWorkflowStep protos generated while the runner
@@ -223,6 +215,7 @@ class DataflowRunner(PipelineRunner):
return FlattenInputVisitor()
+ # TODO(mariagh): Make this method take pipeline_options
def run(self, pipeline):
"""Remotely executes entire pipeline or parts reachable from node."""
# Import here to avoid adding the dependency for local running scenarios.
@@ -233,17 +226,6 @@ class DataflowRunner(PipelineRunner):
raise ImportError(
'Google Cloud Dataflow runner not available, '
'please install apache_beam[gcp]')
-
- # Performing configured PTransform overrides.
- pipeline.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES)
-
- # Add setup_options for all the BeamPlugin imports
- setup_options = pipeline._options.view_as(SetupOptions)
- plugins = BeamPlugin.get_all_plugin_paths()
- if setup_options.beam_plugins is not None:
- plugins = list(set(plugins + setup_options.beam_plugins))
- setup_options.beam_plugins = plugins
-
self.job = apiclient.Job(pipeline._options)
# Dataflow runner requires a KV type for GBK inputs, hence we enforce that
@@ -257,14 +239,15 @@ class DataflowRunner(PipelineRunner):
# The superclass's run will trigger a traversal of all reachable nodes.
super(DataflowRunner, self).run(pipeline)
- test_options = pipeline._options.view_as(TestOptions)
- # If it is a dry run, return without submitting the job.
- if test_options.dry_run:
- return None
+ standard_options = pipeline._options.view_as(StandardOptions)
+ if standard_options.streaming:
+ job_version = DataflowRunner.STREAMING_ENVIRONMENT_MAJOR_VERSION
+ else:
+ job_version = DataflowRunner.BATCH_ENVIRONMENT_MAJOR_VERSION
# Get a Dataflow API client and set its options
self.dataflow_client = apiclient.DataflowApplicationClient(
- pipeline._options)
+ pipeline._options, job_version)
# Create the job
result = DataflowPipelineResult(
@@ -377,26 +360,6 @@ class DataflowRunner(PipelineRunner):
PropertyNames.OUTPUT_NAME: PropertyNames.OUT}])
return step
- def run_Impulse(self, transform_node):
- standard_options = (
- transform_node.outputs[None].pipeline._options.view_as(StandardOptions))
- if standard_options.streaming:
- step = self._add_step(
- TransformNames.READ, transform_node.full_label, transform_node)
- step.add_property(PropertyNames.FORMAT, 'pubsub')
- step.add_property(PropertyNames.PUBSUB_SUBSCRIPTION, '_starting_signal/')
-
- step.encoding = self._get_encoded_output_coder(transform_node)
- step.add_property(
- PropertyNames.OUTPUT_INFO,
- [{PropertyNames.USER_NAME: (
- '%s.%s' % (
- transform_node.full_label, PropertyNames.OUT)),
- PropertyNames.ENCODING: step.encoding,
- PropertyNames.OUTPUT_NAME: PropertyNames.OUT}])
- else:
- raise ValueError('Impulse source for batch pipelines has not been defined.')
-
def run_Flatten(self, transform_node):
step = self._add_step(TransformNames.FLATTEN,
transform_node.full_label, transform_node)
@@ -655,13 +618,10 @@ class DataflowRunner(PipelineRunner):
if not standard_options.streaming:
raise ValueError('PubSubPayloadSource is currently available for use '
'only in streaming pipelines.')
- # Only one of topic or subscription should be set.
- if transform.source.full_subscription:
+ step.add_property(PropertyNames.PUBSUB_TOPIC, transform.source.topic)
+ if transform.source.subscription:
step.add_property(PropertyNames.PUBSUB_SUBSCRIPTION,
- transform.source.full_subscription)
- elif transform.source.full_topic:
- step.add_property(PropertyNames.PUBSUB_TOPIC,
- transform.source.full_topic)
+ transform.source.topic)
if transform.source.id_label:
step.add_property(PropertyNames.PUBSUB_ID_LABEL,
transform.source.id_label)
@@ -679,12 +639,7 @@ class DataflowRunner(PipelineRunner):
# step should be the type of value outputted by each step. Read steps
# automatically wrap output values in a WindowedValue wrapper, if necessary.
# This is also necessary for proper encoding for size estimation.
- # Using a GlobalWindowCoder as a placeholder instead of the default
- # PickleCoder because GlobalWindowCoder is a known coder.
- # TODO(robertwb): Query the collection for the windowfn to extract the
- # correct coder.
- coder = coders.WindowedValueCoder(transform._infer_output_coder(),
- coders.coders.GlobalWindowCoder()) # pylint: disable=protected-access
+ coder = coders.WindowedValueCoder(transform._infer_output_coder()) # pylint: disable=protected-access
step.encoding = self._get_cloud_encoding(coder)
step.add_property(
@@ -745,7 +700,7 @@ class DataflowRunner(PipelineRunner):
if not standard_options.streaming:
raise ValueError('PubSubPayloadSink is currently available for use '
'only in streaming pipelines.')
- step.add_property(PropertyNames.PUBSUB_TOPIC, transform.sink.full_topic)
+ step.add_property(PropertyNames.PUBSUB_TOPIC, transform.sink.topic)
else:
raise ValueError(
'Sink %r has unexpected format %s.' % (
@@ -753,12 +708,8 @@ class DataflowRunner(PipelineRunner):
step.add_property(PropertyNames.FORMAT, transform.sink.format)
# Wrap coder in WindowedValueCoder: this is necessary for proper encoding
- # for size estimation. Using a GlobalWindowCoder as a placeholder instead
- # of the default PickleCoder because GlobalWindowCoder is a known coder.
- # TODO(robertwb): Query the collection for the windowfn to extract the
- # correct coder.
- coder = coders.WindowedValueCoder(transform.sink.coder,
- coders.coders.GlobalWindowCoder())
+ # for size estimation.
+ coder = coders.WindowedValueCoder(transform.sink.coder)
step.encoding = self._get_cloud_encoding(coder)
step.add_property(PropertyNames.ENCODING, step.encoding)
step.add_property(
@@ -770,7 +721,7 @@ class DataflowRunner(PipelineRunner):
@classmethod
def serialize_windowing_strategy(cls, windowing):
from apache_beam.runners import pipeline_context
- from apache_beam.portability.api import beam_runner_api_pb2
+ from apache_beam.runners.api import beam_runner_api_pb2
context = pipeline_context.PipelineContext()
windowing_proto = windowing.to_runner_api(context)
return cls.byte_array_to_json_string(
@@ -783,7 +734,7 @@ class DataflowRunner(PipelineRunner):
# Imported here to avoid circular dependencies.
# pylint: disable=wrong-import-order, wrong-import-position
from apache_beam.runners import pipeline_context
- from apache_beam.portability.api import beam_runner_api_pb2
+ from apache_beam.runners.api import beam_runner_api_pb2
from apache_beam.transforms.core import Windowing
proto = beam_runner_api_pb2.MessageWithComponents()
proto.ParseFromString(cls.json_string_to_byte_array(serialized_data))
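As context for the import swap in serialize_windowing_strategy, here is a minimal usage sketch of the round trip through the runner API protos. It assumes the companion deserialize_windowing_strategy classmethod that the final hunk above belongs to; treat it as an illustration rather than a supported API.

    from apache_beam.runners.dataflow.dataflow_runner import DataflowRunner
    from apache_beam.transforms import window
    from apache_beam.transforms.core import Windowing

    # Encode a simple global-windows strategy to the JSON-wrapped proto bytes
    # carried in the Dataflow job description, then decode it back.
    windowing = Windowing(window.GlobalWindows())
    serialized = DataflowRunner.serialize_windowing_strategy(windowing)
    restored = DataflowRunner.deserialize_windowing_strategy(serialized)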
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
index a9b8fdb..74fd01d 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
@@ -59,8 +59,7 @@ class DataflowRunnerTest(unittest.TestCase):
'--project=test-project',
'--staging_location=ignored',
'--temp_location=/dev/null',
- '--no_auth=True',
- '--dry_run=True']
+ '--no_auth=True']
@mock.patch('time.sleep', return_value=None)
def test_wait_until_finish(self, patched_time_sleep):
@@ -109,22 +108,8 @@ class DataflowRunnerTest(unittest.TestCase):
(p | ptransform.Create([1, 2, 3]) # pylint: disable=expression-not-assigned
| 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
| ptransform.GroupByKey())
- p.run()
-
- def test_streaming_create_translation(self):
- remote_runner = DataflowRunner()
- self.default_properties.append("--streaming")
- p = Pipeline(remote_runner, PipelineOptions(self.default_properties))
- p | ptransform.Create([1]) # pylint: disable=expression-not-assigned
- p.run()
- job_dict = json.loads(str(remote_runner.job))
- self.assertEqual(len(job_dict[u'steps']), 2)
-
- self.assertEqual(job_dict[u'steps'][0][u'kind'], u'ParallelRead')
- self.assertEqual(
- job_dict[u'steps'][0][u'properties'][u'pubsub_subscription'],
- '_starting_signal/')
- self.assertEqual(job_dict[u'steps'][1][u'kind'], u'ParallelDo')
+ remote_runner.job = apiclient.Job(p._options)
+ super(DataflowRunner, remote_runner).run(p)
def test_remote_runner_display_data(self):
remote_runner = DataflowRunner()
@@ -157,7 +142,8 @@ class DataflowRunnerTest(unittest.TestCase):
(p | ptransform.Create([1, 2, 3, 4, 5])
| 'Do' >> SpecialParDo(SpecialDoFn(), now))
- p.run()
+ remote_runner.job = apiclient.Job(p._options)
+ super(DataflowRunner, remote_runner).run(p)
job_dict = json.loads(str(remote_runner.job))
steps = [step
for step in job_dict['steps']
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
index 33dfe19..df1a3f2 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
@@ -38,6 +38,7 @@ from apache_beam.io.filesystems import FileSystems
from apache_beam.io.gcp.internal.clients import storage
from apache_beam.runners.dataflow.internal import dependency
from apache_beam.runners.dataflow.internal.clients import dataflow
+from apache_beam.runners.dataflow.internal.dependency import get_required_container_version
from apache_beam.runners.dataflow.internal.dependency import get_sdk_name_and_version
from apache_beam.runners.dataflow.internal.names import PropertyNames
from apache_beam.transforms import cy_combiners
@@ -49,13 +50,6 @@ from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import WorkerOptions
-# Environment version information. It is passed to the service during a
-# job submission and is used by the service to establish what features
-# are expected by the workers.
-_LEGACY_ENVIRONMENT_MAJOR_VERSION = '6'
-_FNAPI_ENVIRONMENT_MAJOR_VERSION = '1'
-
-
class Step(object):
"""Wrapper for a dataflow Step protobuf."""
@@ -153,10 +147,7 @@ class Environment(object):
if self.standard_options.streaming:
job_type = 'FNAPI_STREAMING'
else:
- if _use_fnapi(options):
- job_type = 'FNAPI_BATCH'
- else:
- job_type = 'PYTHON_BATCH'
+ job_type = 'PYTHON_BATCH'
self.proto.version.additionalProperties.extend([
dataflow.Environment.VersionValue.AdditionalProperty(
key='job_type',
@@ -214,8 +205,11 @@ class Environment(object):
pool.workerHarnessContainerImage = (
self.worker_options.worker_harness_container_image)
else:
+ # Default to using the worker harness container image for the current SDK
+ # version.
pool.workerHarnessContainerImage = (
- dependency.get_default_container_image_for_current_sdk(job_type))
+ 'dataflow.gcr.io/v1beta3/python:%s' %
+ get_required_container_version())
if self.worker_options.use_public_ips is not None:
if self.worker_options.use_public_ips:
pool.ipConfiguration = (
@@ -370,16 +364,11 @@ class Job(object):
class DataflowApplicationClient(object):
"""A Dataflow API client used by application code to create and query jobs."""
- def __init__(self, options):
+ def __init__(self, options, environment_version):
"""Initializes a Dataflow API client object."""
self.standard_options = options.view_as(StandardOptions)
self.google_cloud_options = options.view_as(GoogleCloudOptions)
-
- if _use_fnapi(options):
- self.environment_version = _FNAPI_ENVIRONMENT_MAJOR_VERSION
- else:
- self.environment_version = _LEGACY_ENVIRONMENT_MAJOR_VERSION
-
+ self.environment_version = environment_version
if self.google_cloud_options.no_auth:
credentials = None
else:
@@ -721,14 +710,6 @@ def translate_mean(accumulator, metric_update):
metric_update.kind = None
-def _use_fnapi(pipeline_options):
- standard_options = pipeline_options.view_as(StandardOptions)
- debug_options = pipeline_options.view_as(DebugOptions)
-
- return standard_options.streaming or (
- debug_options.experiments and 'beam_fn_api' in debug_options.experiments)
-
-
# To enable a counter on the service, add it to this dictionary.
metric_translations = {
cy_combiners.CountCombineFn: ('sum', translate_scalar),
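The removed _use_fnapi helper above encodes a small decision rule; the standalone sketch below restates it outside apiclient.py purely for illustration (the function name uses_fn_api is invented for this note).

    from apache_beam.options.pipeline_options import (
        DebugOptions, PipelineOptions, StandardOptions)

    def uses_fn_api(pipeline_options):
        # Same rule as the removed helper: streaming pipelines, or pipelines run
        # with the 'beam_fn_api' experiment, get the Fn API environment version.
        standard_options = pipeline_options.view_as(StandardOptions)
        debug_options = pipeline_options.view_as(DebugOptions)
        return bool(standard_options.streaming or (
            debug_options.experiments and
            'beam_fn_api' in debug_options.experiments))

    print(uses_fn_api(PipelineOptions(['--streaming'])))                   # True
    print(uses_fn_api(PipelineOptions(['--experiments', 'beam_fn_api'])))  # True
    print(uses_fn_api(PipelineOptions([])))                                # False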
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
index 407ffcf..67cf77f 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
@@ -22,6 +22,7 @@ from mock import Mock
from apache_beam.metrics.cells import DistributionData
from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.runners.dataflow.dataflow_runner import DataflowRunner
from apache_beam.runners.dataflow.internal.clients import dataflow
# Protect against environments where apitools library is not available.
@@ -39,7 +40,9 @@ class UtilTest(unittest.TestCase):
@unittest.skip("Enable once BEAM-1080 is fixed.")
def test_create_application_client(self):
pipeline_options = PipelineOptions()
- apiclient.DataflowApplicationClient(pipeline_options)
+ apiclient.DataflowApplicationClient(
+ pipeline_options,
+ DataflowRunner.BATCH_ENVIRONMENT_MAJOR_VERSION)
def test_set_network(self):
pipeline_options = PipelineOptions(
@@ -119,30 +122,6 @@ class UtilTest(unittest.TestCase):
self.assertEqual(
metric_update.floatingPointMean.count.lowBits, accumulator.count)
- def test_default_ip_configuration(self):
- pipeline_options = PipelineOptions(
- ['--temp_location', 'gs://any-location/temp'])
- env = apiclient.Environment([], pipeline_options, '2.0.0')
- self.assertEqual(env.proto.workerPools[0].ipConfiguration, None)
-
- def test_public_ip_configuration(self):
- pipeline_options = PipelineOptions(
- ['--temp_location', 'gs://any-location/temp',
- '--use_public_ips'])
- env = apiclient.Environment([], pipeline_options, '2.0.0')
- self.assertEqual(
- env.proto.workerPools[0].ipConfiguration,
- dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PUBLIC)
-
- def test_private_ip_configuration(self):
- pipeline_options = PipelineOptions(
- ['--temp_location', 'gs://any-location/temp',
- '--no_use_public_ips'])
- env = apiclient.Environment([], pipeline_options, '2.0.0')
- self.assertEqual(
- env.proto.workerPools[0].ipConfiguration,
- dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PRIVATE)
-
if __name__ == '__main__':
unittest.main()