You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@beam.apache.org by jr...@apache.org on 2023/07/06 13:42:36 UTC

[beam] branch master updated: Inherit Generic for TimestampedValue (#26290)

This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 5bc80aa2e6d Inherit Generic for TimestampedValue (#26290)
5bc80aa2e6d is described below

commit 5bc80aa2e6dfd3316fd34e8f47c539921f8212a8
Author: liferoad <hu...@gmail.com>
AuthorDate: Thu Jul 6 09:42:29 2023 -0400

    Inherit Generic for TimestampedValue (#26290)
    
    * Inherit Generic for TimestampedValue
    
    * add one test
    
    * change the generic only for value
    
    * polished the test
    
    * first try with type check
    
    * polish the type checks
    
    * linting
    
    * fix link
    
    * fix sorting
    
    * ignore mypy errors
    
    * normalize to beam type
    
    * fix the test
    
    * update the doc
    
    * added one typevar test
    
    ---------
    
    Co-authored-by: xqhu <xq...@google.com>
---
 sdks/python/apache_beam/transforms/core.py         |   3 +-
 .../transforms/timestamped_value_type_test.py      | 139 +++++++++++++++++++++
 sdks/python/apache_beam/transforms/window.py       |   9 +-
 sdks/python/apache_beam/typehints/typecheck.py     |  18 ++-
 4 files changed, 163 insertions(+), 6 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 026625d7805..6fd8cd2e03b 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -3781,7 +3781,8 @@ def _strip_output_annotations(type_hint):
   contains_annotation = False
 
   def visitor(t, unused_args):
-    if t in annotations:
+    if t in annotations or (hasattr(t, '__name__') and
+                            t.__name__ == TimestampedValue.__name__):
       raise StopIteration
 
   try:
diff --git a/sdks/python/apache_beam/transforms/timestamped_value_type_test.py b/sdks/python/apache_beam/transforms/timestamped_value_type_test.py
new file mode 100644
index 00000000000..46449bb1ef7
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/timestamped_value_type_test.py
@@ -0,0 +1,139 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import TypeVar
+
+import apache_beam as beam
+from apache_beam.transforms.window import TimestampedValue
+from apache_beam.typehints.decorators import TypeCheckError
+
+T = TypeVar("T")
+
+
+def ConvertToTimestampedValue(plant: Dict[str, Any]) -> TimestampedValue[str]:
+  return TimestampedValue[str](plant["name"], plant["season"])
+
+
+def ConvertToTimestampedValue_1(plant: Dict[str, Any]) -> TimestampedValue:
+  return TimestampedValue(plant["name"], plant["season"])
+
+
+def ConvertToTimestampedValue_2(
+    plant: Dict[str, Any]) -> TimestampedValue[List[str]]:
+  return TimestampedValue[List[str]](plant["name"], plant["season"])
+
+
+def ConvertToTimestampedValue_3(plant: Dict[str, Any]) -> TimestampedValue[T]:
+  return TimestampedValue[T](plant["name"], plant["season"])
+
+
+class TypeCheckTimestampedValueTestCase(unittest.TestCase):
+  def setUp(self):
+    self.opts = beam.options.pipeline_options.PipelineOptions(
+        runtime_type_check=True)
+    self.data = [
+        {
+            "name": "Strawberry", "season": 1585699200
+        },  # April, 2020
+    ]
+    self.data_1 = [
+        {
+            "name": 1234, "season": 1585699200
+        },  # April, 2020
+    ]
+    self.data_2 = [
+        {
+            "name": ["abc", "cde"], "season": 1585699200
+        },  # April, 2020
+    ]
+    self.data_3 = [
+        {
+            "name": [123, "cde"], "season": 1585699200
+        },  # April, 2020
+    ]
+
+  def test_pcoll_default_hints(self):
+    for fn in (ConvertToTimestampedValue, ConvertToTimestampedValue_1):
+      pc = beam.Map(fn)
+      ht = pc.default_type_hints()
+      assert len(ht) == 3
+      assert ht.output_types[0][0]
+
+  def test_pcoll_with_output_hints(self):
+    pc = beam.Map(ConvertToTimestampedValue).with_output_types(str)
+    ht = pc.get_type_hints()
+    assert len(ht) == 3
+    assert ht.output_types[0][0] == str
+
+  def test_opts_with_check(self):
+    with beam.Pipeline(options=self.opts) as p:
+      _ = (
+          p
+          | "Garden plants" >> beam.Create(self.data)
+          | "With timestamps" >> beam.Map(ConvertToTimestampedValue)
+          | beam.Map(print))
+
+  def test_opts_with_check_list_str(self):
+    with beam.Pipeline(options=self.opts) as p:
+      _ = (
+          p
+          | "Garden plants" >> beam.Create(self.data_2)
+          | "With timestamps" >> beam.Map(ConvertToTimestampedValue_2)
+          | beam.Map(print))
+
+  def test_opts_with_check_wrong_data(self):
+    with self.assertRaises(TypeCheckError):
+      with beam.Pipeline(options=self.opts) as p:
+        _ = (
+            p
+            | "Garden plants" >> beam.Create(self.data_1)
+            | "With timestamps" >> beam.Map(ConvertToTimestampedValue)
+            | beam.Map(print))
+
+  def test_opts_with_check_wrong_data_list_str(self):
+    with self.assertRaises(TypeCheckError):
+      with beam.Pipeline(options=self.opts) as p:
+        _ = (
+            p
+            | "Garden plants" >> beam.Create(self.data_1)
+            | "With timestamps" >> beam.Map(ConvertToTimestampedValue_2)
+            | beam.Map(print))
+
+    with self.assertRaises(TypeCheckError):
+      with beam.Pipeline(options=self.opts) as p:
+        _ = (
+            p
+            | "Garden plants" >> beam.Create(self.data_3)
+            | "With timestamps" >> beam.Map(ConvertToTimestampedValue_2)
+            | beam.Map(print))
+
+  def test_opts_with_check_typevar(self):
+    with self.assertRaises(RuntimeError):
+      with beam.Pipeline(options=self.opts) as p:
+        _ = (
+            p
+            | "Garden plants" >> beam.Create(self.data_2)
+            | "With timestamps" >> beam.Map(ConvertToTimestampedValue_3)
+            | beam.Map(print))
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py
index 3b8c3bb44e2..eaf9550820c 100644
--- a/sdks/python/apache_beam/transforms/window.py
+++ b/sdks/python/apache_beam/transforms/window.py
@@ -52,9 +52,11 @@ WindowFn.
 import abc
 from functools import total_ordering
 from typing import Any
+from typing import Generic
 from typing import Iterable
 from typing import List
 from typing import Optional
+from typing import TypeVar
 
 from google.protobuf import duration_pb2
 from google.protobuf import timestamp_pb2
@@ -278,8 +280,11 @@ class IntervalWindow(windowed_value._IntervalWindowBase, BoundedWindow):
         min(self.start, other.start), max(self.end, other.end))
 
 
+V = TypeVar("V")
+
+
 @total_ordering
-class TimestampedValue(object):
+class TimestampedValue(Generic[V]):
   """A timestamped value having a value and a timestamp.
 
   Attributes:
@@ -287,7 +292,7 @@ class TimestampedValue(object):
     timestamp: Timestamp associated with the value as seconds since Unix epoch.
   """
   def __init__(self, value, timestamp):
-    # type: (Any, TimestampTypes) -> None
+    # type: (V, TimestampTypes) -> None
     self.value = value
     self.timestamp = Timestamp.of(timestamp)
 
diff --git a/sdks/python/apache_beam/typehints/typecheck.py b/sdks/python/apache_beam/typehints/typecheck.py
index 2f202a14ab7..37216066bca 100644
--- a/sdks/python/apache_beam/typehints/typecheck.py
+++ b/sdks/python/apache_beam/typehints/typecheck.py
@@ -31,6 +31,7 @@ from apache_beam import pipeline
 from apache_beam.pvalue import TaggedOutput
 from apache_beam.transforms import core
 from apache_beam.transforms.core import DoFn
+from apache_beam.transforms.window import TimestampedValue
 from apache_beam.transforms.window import WindowedValue
 from apache_beam.typehints.decorators import GeneratorWrapper
 from apache_beam.typehints.decorators import TypeCheckError
@@ -39,6 +40,7 @@ from apache_beam.typehints.decorators import getcallargs_forhints
 from apache_beam.typehints.typehints import CompositeTypeHintError
 from apache_beam.typehints.typehints import SimpleTypeHintError
 from apache_beam.typehints.typehints import check_constraint
+from apache_beam.typehints.typehints import normalize
 
 
 class AbstractDoFnWrapper(DoFn):
@@ -146,9 +148,19 @@ class TypeCheckWrapperDoFn(AbstractDoFnWrapper):
       return transform_results
 
     def type_check_output(o):
-      # TODO(robertwb): Multi-output.
-      x = o.value if isinstance(o, (TaggedOutput, WindowedValue)) else o
-      self.type_check(self._output_type_hint, x, is_input=False)
+      if isinstance(o, TimestampedValue) and hasattr(o, "__orig_class__"):
+        # when a typed TimestampedValue is set, check the value type
+        x = o.value
+        # per https://stackoverflow.com/questions/57706180/,
+        # __orig_class__ is the safe way to obtain the actual type
+        # from Generic[T], supported since Python 3.5.3
+        beam_type = normalize(o.__orig_class__.__args__[0])
+        self.type_check(beam_type, x, is_input=False)
+      else:
+        # TODO(robertwb): Multi-output.
+        x = o.value if isinstance(o, (TaggedOutput, WindowedValue)) else o
+
+        self.type_check(self._output_type_hint, x, is_input=False)
 
     # If the return type is a generator, then we will need to interleave our
     # type-checking with its normal iteration so we don't deplete the