You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bh...@apache.org on 2022/04/04 17:30:26 UTC
[beam] branch master updated: [BEAM-10529] nullable xlang coder (#16923)
This is an automated email from the ASF dual-hosted git repository.
bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 11cccb4da73 [BEAM-10529] nullable xlang coder (#16923)
11cccb4da73 is described below
commit 11cccb4da73a931e8c1eb7cf93af0e9a5b332914
Author: johnjcasey <95...@users.noreply.github.com>
AuthorDate: Mon Apr 4 13:30:20 2022 -0400
[BEAM-10529] nullable xlang coder (#16923)
* [BEAM-10529] add java and generic components of nullable xlang tests
* [BEAM-10529] fix test case
* [BEAM-10529] add coders and typehints to support nullable xlang coders
* [BEAM-10529] update external builder to support nullable coder
* [BEAM-10529] clean up coders.py
* [BEAM-10529] add coder translation test
* [BEAM-10529] add additional check to typecoder to not accidentally misidentify coders as nullable
* [BEAM-10529] add test to retrieve nullable coder from typehint
* [BEAM-10529] run spotless
* [BEAM-10529] add go nullable coder
* [BEAM-10529] cleanup extra println
* [BEAM-10529] improve comments, clean up python
* [BEAM-10529] remove changes to kafkaIO to simplify pr
* [BEAM-10529] add coders to go exec, add asf license text
* [BEAM-10529] clean up error handling
* [BEAM-10529] update go fromyaml to handle nullable cases
* [BEAM-10529] add unit test, register nullable coder in dataflow.go
* [BEAM-10529] remove mistaken commit
* [BEAM-10529] add argument check to CoderTranslators
* [BEAM-10529] Address python comments & cleanup
* [BEAM-10529] address go comments
* [BEAM-10529] remove extra check that was added in error
* [BEAM-10529] fix typo
* [BEAM-10529] re-order check for nonetype to prevent attribute errors
* [BEAM-10529] change isinstance to ==
---
.../beam/model/fnexecution/v1/standard_coders.yaml | 12 +++
.../pipeline/src/main/proto/beam_runner_api.proto | 16 ++--
.../core/construction/CoderTranslators.java | 17 ++++
.../core/construction/ModelCoderRegistrar.java | 3 +
.../runners/core/construction/ModelCoders.java | 5 +-
.../core/construction/CoderTranslationTest.java | 2 +
.../runners/fnexecution/wire/CommonCoderTest.java | 17 +++-
sdks/go/pkg/beam/core/graph/coder/coder.go | 19 ++++-
sdks/go/pkg/beam/core/graph/coder/coder_test.go | 61 ++++++++++++++
sdks/go/pkg/beam/core/graph/coder/map.go | 32 -------
sdks/go/pkg/beam/core/graph/coder/map_test.go | 4 +-
sdks/go/pkg/beam/core/graph/coder/nil.go | 53 ++++++++++++
sdks/go/pkg/beam/core/graph/coder/nil_test.go | 98 ++++++++++++++++++++++
sdks/go/pkg/beam/core/graph/coder/row_decoder.go | 2 +-
sdks/go/pkg/beam/core/graph/coder/row_encoder.go | 2 +-
sdks/go/pkg/beam/core/runtime/exec/coder.go | 62 ++++++++++++++
sdks/go/pkg/beam/core/runtime/exec/coder_test.go | 6 ++
sdks/go/pkg/beam/core/runtime/graphx/coder.go | 19 ++++-
sdks/go/pkg/beam/core/runtime/graphx/coder_test.go | 4 +
sdks/go/pkg/beam/core/runtime/graphx/dataflow.go | 24 ++++++
sdks/go/pkg/beam/core/typex/fulltype.go | 2 +
sdks/go/pkg/beam/core/typex/special.go | 9 +-
.../go/test/regression/coders/fromyaml/fromyaml.go | 13 +++
sdks/python/apache_beam/coders/coders.py | 20 ++++-
.../apache_beam/coders/standard_coders_test.py | 4 +-
sdks/python/apache_beam/coders/typecoders.py | 3 +-
sdks/python/apache_beam/coders/typecoders_test.py | 7 ++
sdks/python/apache_beam/typehints/typehints.py | 21 ++++-
.../python/apache_beam/typehints/typehints_test.py | 8 ++
29 files changed, 492 insertions(+), 53 deletions(-)
diff --git a/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml b/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml
index 10fcaa5a6ed..7473f30ae6a 100644
--- a/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml
+++ b/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml
@@ -569,3 +569,15 @@ coder:
examples:
"\u0080\u0000\u0001\u0052\u009a\u00a4\u009b\u0067\u0080\u0000\u0001\u0052\u009a\u00a4\u009b\u0068\u0080\u00dd\u00db\u0001" : {window: {end: 1454293425000, span: 3600000}}
"\u007f\u00df\u003b\u0064\u005a\u001c\u00ad\u0075\u007f\u00df\u003b\u0064\u005a\u001c\u00ad\u0076\u00ed\u0002" : {window: {end: -9223372036854410, span: 365}}
+
+
+---
+coder:
+ urn: "beam:coder:nullable:v1"
+ components: [{urn: "beam:coder:bytes:v1"}]
+nested: true
+
+examples:
+ "\u0001\u0003\u0061\u0062\u0063" : "abc"
+ "\u0001\u000a\u006d\u006f\u0072\u0065\u0020\u0062\u0079\u0074\u0065\u0073" : "more bytes"
+ "\u0000" : null
diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto
index 7fdb5aaf5e8..c1e318491f2 100644
--- a/model/pipeline/src/main/proto/beam_runner_api.proto
+++ b/model/pipeline/src/main/proto/beam_runner_api.proto
@@ -1043,11 +1043,8 @@ message StandardCoders {
// - Followed by N interleaved keys and values, encoded with their
// corresponding coder.
//
- // Nullable types in container types (ArrayType, MapType) are encoded by:
- // - A one byte null indicator, 0x00 for null values, or 0x01 for present
- // values.
- // - For present values the null indicator is followed by the value
- // encoded with it's corresponding coder.
+ // Nullable types in container types (ArrayType, MapType) are encoded per the
+ // encoding described for general Nullable types below.
//
// Well known logical types:
// beam:logical_type:micros_instant:v1
@@ -1085,6 +1082,15 @@ message StandardCoders {
// Components: the user key coder.
// Experimental.
SHARDED_KEY = 15 [(beam_urn) = "beam:coder:sharded_key:v1"];
+
+ // Wraps a coder of a potentially null value.
+ // A Nullable Type is encoded by:
+ // - A one byte null indicator, 0x00 for null values, or 0x01 for present
+ // values.
+ // - For present values the null indicator is followed by the value
+ // encoded with its corresponding coder.
+ // Components: single coder for the value.
+ NULLABLE = 17 [(beam_urn) = "beam:coder:nullable:v1"];
}
}
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java
index 1838fa692e1..59d14b60862 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java
@@ -27,6 +27,7 @@ import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.LengthPrefixCoder;
+import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
import org.apache.beam.sdk.schemas.Schema;
@@ -204,6 +205,22 @@ class CoderTranslators {
};
}
+ static CoderTranslator<NullableCoder<?>> nullable() {
+ return new SimpleStructuredCoderTranslator<NullableCoder<?>>() {
+ @Override
+ protected NullableCoder<?> fromComponents(List<Coder<?>> components) {
+ checkArgument(
+ components.size() == 1, "Expected one component, but received: " + components);
+ return NullableCoder.of(components.get(0));
+ }
+
+ @Override
+ public List<? extends Coder<?>> getComponents(NullableCoder<?> from) {
+ return from.getComponents();
+ }
+ };
+ }
+
public abstract static class SimpleStructuredCoderTranslator<T extends Coder<?>>
implements CoderTranslator<T> {
@Override
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java
index e0cc8dc11b9..eb94476d2b7 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java
@@ -29,6 +29,7 @@ import org.apache.beam.sdk.coders.DoubleCoder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.LengthPrefixCoder;
+import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
@@ -73,6 +74,7 @@ public class ModelCoderRegistrar implements CoderTranslatorRegistrar {
.put(RowCoder.class, ModelCoders.ROW_CODER_URN)
.put(ShardedKey.Coder.class, ModelCoders.SHARDED_KEY_CODER_URN)
.put(TimestampPrefixingWindowCoder.class, ModelCoders.CUSTOM_WINDOW_CODER_URN)
+ .put(NullableCoder.class, ModelCoders.NULLABLE_CODER_URN)
.build();
private static final Map<Class<? extends Coder>, CoderTranslator<? extends Coder>>
@@ -96,6 +98,7 @@ public class ModelCoderRegistrar implements CoderTranslatorRegistrar {
.put(RowCoder.class, CoderTranslators.row())
.put(ShardedKey.Coder.class, CoderTranslators.shardedKey())
.put(TimestampPrefixingWindowCoder.class, CoderTranslators.timestampPrefixingWindow())
+ .put(NullableCoder.class, CoderTranslators.nullable())
.build();
static {
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java
index b616ffab462..bc0ec755f4c 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java
@@ -67,6 +67,8 @@ public class ModelCoders {
public static final String SHARDED_KEY_CODER_URN = getUrn(StandardCoders.Enum.SHARDED_KEY);
+ public static final String NULLABLE_CODER_URN = getUrn(StandardCoders.Enum.NULLABLE);
+
static {
checkState(
STATE_BACKED_ITERABLE_CODER_URN.equals(getUrn(StandardCoders.Enum.STATE_BACKED_ITERABLE)));
@@ -90,7 +92,8 @@ public class ModelCoders {
ROW_CODER_URN,
PARAM_WINDOWED_VALUE_CODER_URN,
STATE_BACKED_ITERABLE_CODER_URN,
- SHARDED_KEY_CODER_URN);
+ SHARDED_KEY_CODER_URN,
+ NULLABLE_CODER_URN);
public static Set<String> urns() {
return MODEL_CODER_URNS;
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java
index 544f43d0f0f..f759ebede63 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java
@@ -42,6 +42,7 @@ import org.apache.beam.sdk.coders.DoubleCoder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.LengthPrefixCoder;
+import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
@@ -97,6 +98,7 @@ public class CoderTranslationTest {
Field.of("bar", FieldType.logicalType(FixedBytes.of(123))))))
.add(ShardedKey.Coder.of(StringUtf8Coder.of()))
.add(TimestampPrefixingWindowCoder.of(IntervalWindowCoder.of()))
+ .add(NullableCoder.of(ByteArrayCoder.of()))
.build();
/**
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/wire/CommonCoderTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/wire/CommonCoderTest.java
index c5f2283e41b..ca5274a358b 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/wire/CommonCoderTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/wire/CommonCoderTest.java
@@ -69,6 +69,7 @@ import org.apache.beam.sdk.coders.DoubleCoder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.IterableLikeCoder;
import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder;
@@ -134,6 +135,7 @@ public class CommonCoderTest {
.put(getUrn(StandardCoders.Enum.SHARDED_KEY), ShardedKey.Coder.class)
.put(getUrn(StandardCoders.Enum.CUSTOM_WINDOW), TimestampPrefixingWindowCoder.class)
.put(getUrn(StandardCoders.Enum.STATE_BACKED_ITERABLE), StateBackedIterable.Coder.class)
+ .put(getUrn(StandardCoders.Enum.NULLABLE), NullableCoder.class)
.build();
@AutoValue
@@ -201,7 +203,7 @@ public class CommonCoderTest {
@SuppressWarnings("mutable")
abstract byte[] getSerialized();
- abstract Object getValue();
+ abstract @Nullable Object getValue();
static OneCoderTestSpec create(
CommonCoder coder, boolean nested, byte[] serialized, Object value) {
@@ -382,6 +384,17 @@ public class CommonCoderTest {
Map<String, Object> kvMap = (Map<String, Object>) value;
Coder windowCoder = ((TimestampPrefixingWindowCoder) coder).getWindowCoder();
return convertValue(kvMap.get("window"), coderSpec.getComponents().get(0), windowCoder);
+ } else if (s.equals(getUrn(StandardCoders.Enum.NULLABLE))) {
+ if (coderSpec.getComponents().size() == 1
+ && coderSpec.getComponents().get(0).getUrn().equals(getUrn(StandardCoders.Enum.BYTES))) {
+ if (value == null) {
+ return null;
+ } else {
+ return ((String) value).getBytes(StandardCharsets.ISO_8859_1);
+ }
+ } else {
+ throw new IllegalStateException("Unknown or missing nested coder for nullable coder");
+ }
} else {
throw new IllegalStateException("Unknown coder URN: " + coderSpec.getUrn());
}
@@ -575,6 +588,8 @@ public class CommonCoderTest {
assertEquals(expectedValue, actualValue);
} else if (s.equals(getUrn(StandardCoders.Enum.CUSTOM_WINDOW))) {
assertEquals(expectedValue, actualValue);
+ } else if (s.equals(getUrn(StandardCoders.Enum.NULLABLE))) {
+ assertThat(expectedValue, equalTo(actualValue));
} else {
throw new IllegalStateException("Unknown coder URN: " + coder.getUrn());
}
diff --git a/sdks/go/pkg/beam/core/graph/coder/coder.go b/sdks/go/pkg/beam/core/graph/coder/coder.go
index 6eea66b0d31..8424f7af875 100644
--- a/sdks/go/pkg/beam/core/graph/coder/coder.go
+++ b/sdks/go/pkg/beam/core/graph/coder/coder.go
@@ -169,6 +169,7 @@ const (
VarInt Kind = "varint"
Double Kind = "double"
Row Kind = "R"
+ Nullable Kind = "N"
Timer Kind = "T"
PaneInfo Kind = "PI"
WindowedValue Kind = "W"
@@ -198,7 +199,7 @@ type Coder struct {
Kind Kind
T typex.FullType
- Components []*Coder // WindowedValue, KV, CoGBK
+ Components []*Coder // WindowedValue, KV, CoGBK, Nullable
Custom *CustomCoder // Custom
Window *WindowCoder // WindowedValue
@@ -260,7 +261,7 @@ func (c *Coder) String() string {
switch c.Kind {
case WindowedValue, ParamWindowedValue, Window, Timer:
ret += fmt.Sprintf("!%v", c.Window)
- case KV, CoGBK, Bytes, Bool, VarInt, Double, String, LP: // No additional info.
+ case KV, CoGBK, Bytes, Bool, VarInt, Double, String, LP, Nullable: // No additional info.
default:
ret += fmt.Sprintf("[%v]", c.T)
}
@@ -394,6 +395,20 @@ func NewKV(components []*Coder) *Coder {
}
}
+func NewN(component *Coder) *Coder {
+ coders := []*Coder{component}
+ checkCodersNotNil(coders)
+ return &Coder{
+ Kind: Nullable,
+ T: typex.New(typex.NullableType, component.T),
+ Components: coders,
+ }
+}
+
+func IsNullable(c *Coder) bool {
+ return c.Kind == Nullable
+}
+
// IsCoGBK returns true iff the coder is for a CoGBK type.
func IsCoGBK(c *Coder) bool {
return c.Kind == CoGBK
diff --git a/sdks/go/pkg/beam/core/graph/coder/coder_test.go b/sdks/go/pkg/beam/core/graph/coder/coder_test.go
index 762ed848f58..44606dc1efb 100644
--- a/sdks/go/pkg/beam/core/graph/coder/coder_test.go
+++ b/sdks/go/pkg/beam/core/graph/coder/coder_test.go
@@ -168,6 +168,9 @@ func TestCoder_String(t *testing.T) {
}, {
want: "KV<bytes,varint>",
c: NewKV([]*Coder{bytes, ints}),
+ }, {
+ want: "N<bytes>",
+ c: NewN(bytes),
}, {
want: "CoGBK<bytes,varint,bytes>",
c: NewCoGBK([]*Coder{bytes, ints, bytes}),
@@ -277,6 +280,10 @@ func TestCoder_Equals(t *testing.T) {
want: true,
a: NewKV([]*Coder{custom1, ints}),
b: NewKV([]*Coder{customSame, ints}),
+ }, {
+ want: true,
+ a: NewN(custom1),
+ b: NewN(customSame),
}, {
want: true,
a: NewCoGBK([]*Coder{custom1, ints, customSame}),
@@ -517,6 +524,60 @@ func TestNewKV(t *testing.T) {
}
}
+func TestNewNullable(t *testing.T) {
+ bytes := NewBytes()
+
+ tests := []struct {
+ name string
+ component *Coder
+ shouldpanic bool
+ want *Coder
+ }{
+ {
+ name: "nil",
+ component: nil,
+ shouldpanic: true,
+ },
+ {
+ name: "empty",
+ component: &Coder{},
+ shouldpanic: true,
+ },
+ {
+ name: "bytes",
+ component: bytes,
+ shouldpanic: false,
+ want: &Coder{
+ Kind: Nullable,
+ T: typex.New(typex.NullableType, bytes.T),
+ Components: []*Coder{bytes},
+ },
+ },
+ }
+
+ for _, test := range tests {
+ test := test
+ t.Run(test.name, func(t *testing.T) {
+ if test.shouldpanic {
+ defer func() {
+ if p := recover(); p != nil {
+ t.Log(p)
+ return
+ }
+ t.Fatalf("NewNullable(%v): want panic", test.component)
+ }()
+ }
+ got := NewN(test.component)
+ if !IsNullable(got) {
+ t.Errorf("IsNullable(%v) = false, want true", got)
+ }
+ if test.want != nil && !test.want.Equals(got) {
+ t.Fatalf("NewNullable(%v) = %v, want %v", test.component, got, test.want)
+ }
+ })
+ }
+}
+
func TestNewCoGBK(t *testing.T) {
bytes := NewBytes()
ints := NewVarInt()
diff --git a/sdks/go/pkg/beam/core/graph/coder/map.go b/sdks/go/pkg/beam/core/graph/coder/map.go
index 2d72446bf44..30eee7b008a 100644
--- a/sdks/go/pkg/beam/core/graph/coder/map.go
+++ b/sdks/go/pkg/beam/core/graph/coder/map.go
@@ -62,24 +62,6 @@ func mapDecoder(rt reflect.Type, decodeToKey, decodeToElem typeDecoderFieldRefle
}
}
-// containerNilDecoder handles when a value is nillable for map or iterable components.
-// Nillable types have an extra byte prefixing them indicating nil status.
-func containerNilDecoder(decodeToElem func(reflect.Value, io.Reader) error) func(reflect.Value, io.Reader) error {
- return func(ret reflect.Value, r io.Reader) error {
- hasValue, err := DecodeBool(r)
- if err != nil {
- return err
- }
- if !hasValue {
- return nil
- }
- if err := decodeToElem(ret, r); err != nil {
- return err
- }
- return nil
- }
-}
-
// mapEncoder reflectively encodes a map or array type using the beam map encoding.
func mapEncoder(rt reflect.Type, encodeKey, encodeValue typeEncoderFieldReflect) func(reflect.Value, io.Writer) error {
return func(rv reflect.Value, w io.Writer) error {
@@ -132,17 +114,3 @@ func mapEncoder(rt reflect.Type, encodeKey, encodeValue typeEncoderFieldReflect)
return nil
}
}
-
-// containerNilEncoder handles when a value is nillable for map or iterable components.
-// Nillable types have an extra byte prefixing them indicating nil status.
-func containerNilEncoder(encodeElem func(reflect.Value, io.Writer) error) func(reflect.Value, io.Writer) error {
- return func(rv reflect.Value, w io.Writer) error {
- if rv.IsNil() {
- return EncodeBool(false, w)
- }
- if err := EncodeBool(true, w); err != nil {
- return err
- }
- return encodeElem(rv, w)
- }
-}
diff --git a/sdks/go/pkg/beam/core/graph/coder/map_test.go b/sdks/go/pkg/beam/core/graph/coder/map_test.go
index ee4c35afa60..3291f7fbd5a 100644
--- a/sdks/go/pkg/beam/core/graph/coder/map_test.go
+++ b/sdks/go/pkg/beam/core/graph/coder/map_test.go
@@ -38,8 +38,8 @@ func TestEncodeDecodeMap(t *testing.T) {
v.Set(reflect.New(reflectx.Uint8))
return byteDec(v.Elem(), r)
}
- byteCtnrPtrEnc := containerNilEncoder(bytePtrEnc)
- byteCtnrPtrDec := containerNilDecoder(bytePtrDec)
+ byteCtnrPtrEnc := NullableEncoder(bytePtrEnc)
+ byteCtnrPtrDec := NullableDecoder(bytePtrDec)
ptrByte := byte(42)
diff --git a/sdks/go/pkg/beam/core/graph/coder/nil.go b/sdks/go/pkg/beam/core/graph/coder/nil.go
new file mode 100644
index 00000000000..a7ed27cb6d5
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/nil.go
@@ -0,0 +1,53 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+ "io"
+ "reflect"
+)
+
+// NullableDecoder handles when a value is nillable.
+// Nillable types have an extra byte prefixing them indicating nil status.
+func NullableDecoder(decodeToElem func(reflect.Value, io.Reader) error) func(reflect.Value, io.Reader) error {
+ return func(ret reflect.Value, r io.Reader) error {
+ hasValue, err := DecodeBool(r)
+ if err != nil {
+ return err
+ }
+ if !hasValue {
+ return nil
+ }
+ if err := decodeToElem(ret, r); err != nil {
+ return err
+ }
+ return nil
+ }
+}
+
+// NullableEncoder handles when a value is nillable.
+// Nillable types have an extra byte prefixing them indicating nil status.
+func NullableEncoder(encodeElem func(reflect.Value, io.Writer) error) func(reflect.Value, io.Writer) error {
+ return func(rv reflect.Value, w io.Writer) error {
+ if rv.IsNil() {
+ return EncodeBool(false, w)
+ }
+ if err := EncodeBool(true, w); err != nil {
+ return err
+ }
+ return encodeElem(rv, w)
+ }
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/nil_test.go b/sdks/go/pkg/beam/core/graph/coder/nil_test.go
new file mode 100644
index 00000000000..89410b9c939
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/nil_test.go
@@ -0,0 +1,98 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "reflect"
+ "testing"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestEncodeDecodeNullable(t *testing.T) {
+ byteEnc := func(v reflect.Value, w io.Writer) error {
+ return EncodeByte(byte(v.Uint()), w)
+ }
+ byteDec := func(v reflect.Value, r io.Reader) error {
+ b, err := DecodeByte(r)
+ if err != nil {
+ return errors.Wrap(err, "error decoding single byte field")
+ }
+ v.SetUint(uint64(b))
+ return nil
+ }
+ bytePtrEnc := func(v reflect.Value, w io.Writer) error {
+ return byteEnc(v.Elem(), w)
+ }
+ bytePtrDec := func(v reflect.Value, r io.Reader) error {
+ v.Set(reflect.New(reflectx.Uint8))
+ return byteDec(v.Elem(), r)
+ }
+ byteCtnrPtrEnc := NullableEncoder(bytePtrEnc)
+ byteCtnrPtrDec := NullableDecoder(bytePtrDec)
+
+ tests := []struct {
+ decoded interface{}
+ encoded []byte
+ }{
+ {
+ decoded: (*byte)(nil),
+ encoded: []byte{0},
+ },
+ {
+ decoded: create(10),
+ encoded: []byte{1, 10},
+ },
+ {
+ decoded: create(20),
+ encoded: []byte{1, 20},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("encode %q", test.encoded), func(t *testing.T) {
+ var buf bytes.Buffer
+ encErr := byteCtnrPtrEnc(reflect.ValueOf(test.decoded), &buf)
+ if encErr != nil {
+ t.Fatalf("NullableEncoder(%q) = %v", test.decoded, encErr)
+ }
+ if d := cmp.Diff(test.encoded, buf.Bytes()); d != "" {
+ t.Errorf("NullableEncoder(%q) = %v, want %v diff(-want,+got):\n %v", test.decoded, buf.Bytes(), test.encoded, d)
+ }
+ })
+ t.Run(fmt.Sprintf("decode %q", test.decoded), func(t *testing.T) {
+ buf := bytes.NewBuffer(test.encoded)
+ rv := reflect.New(reflect.TypeOf(test.decoded)).Elem()
+ decErr := byteCtnrPtrDec(rv, buf)
+ if decErr != nil {
+ t.Fatalf("NullableDecoder(%q) = %v", test.encoded, decErr)
+ }
+ if d := cmp.Diff(test.decoded, rv.Interface()); d != "" {
+ t.Errorf("NullableDecoder (%q) = %q, want %v diff(-want,+got):\n %v", test.encoded, rv.Interface(), test.decoded, d)
+ }
+ })
+ }
+
+}
+
+func create(x byte) *byte {
+ return &x
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/row_decoder.go b/sdks/go/pkg/beam/core/graph/coder/row_decoder.go
index 1e1fcad3215..9688ed9876c 100644
--- a/sdks/go/pkg/beam/core/graph/coder/row_decoder.go
+++ b/sdks/go/pkg/beam/core/graph/coder/row_decoder.go
@@ -386,7 +386,7 @@ func (b *RowDecoderBuilder) containerDecoderForType(t reflect.Type) (typeDecoder
return typeDecoderFieldReflect{}, err
}
if t.Kind() == reflect.Ptr {
- return typeDecoderFieldReflect{decode: containerNilDecoder(dec.decode), addr: dec.addr}, nil
+ return typeDecoderFieldReflect{decode: NullableDecoder(dec.decode), addr: dec.addr}, nil
}
return dec, nil
}
diff --git a/sdks/go/pkg/beam/core/graph/coder/row_encoder.go b/sdks/go/pkg/beam/core/graph/coder/row_encoder.go
index e12776459da..cfc1a8e51a3 100644
--- a/sdks/go/pkg/beam/core/graph/coder/row_encoder.go
+++ b/sdks/go/pkg/beam/core/graph/coder/row_encoder.go
@@ -262,7 +262,7 @@ func (b *RowEncoderBuilder) containerEncoderForType(t reflect.Type) (typeEncoder
return typeEncoderFieldReflect{}, err
}
if t.Kind() == reflect.Ptr {
- return typeEncoderFieldReflect{encode: containerNilEncoder(encf.encode), addr: encf.addr}, nil
+ return typeEncoderFieldReflect{encode: NullableEncoder(encf.encode), addr: encf.addr}, nil
}
return encf, nil
}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/coder.go b/sdks/go/pkg/beam/core/runtime/exec/coder.go
index c7a19eae047..145209a492c 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/coder.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/coder.go
@@ -156,6 +156,12 @@ func MakeElementEncoder(c *coder.Coder) ElementEncoder {
enc: enc,
}
+ case coder.Nullable:
+ return &nullableEncoder{
+ inner: MakeElementEncoder(c.Components[0]),
+ be: boolEncoder{},
+ }
+
default:
panic(fmt.Sprintf("Unexpected coder: %v", c))
}
@@ -267,6 +273,12 @@ func MakeElementDecoder(c *coder.Coder) ElementDecoder {
dec: dec,
}
+ case coder.Nullable:
+ return &nullableDecoder{
+ inner: MakeElementDecoder(c.Components[0]),
+ bd: boolDecoder{},
+ }
+
default:
panic(fmt.Sprintf("Unexpected coder: %v", c))
}
@@ -609,6 +621,56 @@ func convertIfNeeded(v interface{}, allocated *FullValue) *FullValue {
return allocated
}
+type nullableEncoder struct {
+ inner ElementEncoder
+ be boolEncoder
+}
+
+func (n *nullableEncoder) Encode(value *FullValue, writer io.Writer) error {
+ if value.Elm == nil {
+ if err := n.be.Encode(&FullValue{Elm: false}, writer); err != nil {
+ return err
+ }
+ return nil
+ }
+ if err := n.be.Encode(&FullValue{Elm: true}, writer); err != nil {
+ return err
+ }
+ if err := n.inner.Encode(value, writer); err != nil {
+ return err
+ }
+ return nil
+}
+
+type nullableDecoder struct {
+ inner ElementDecoder
+ bd boolDecoder
+}
+
+func (n *nullableDecoder) Decode(reader io.Reader) (*FullValue, error) {
+ hasValue, err := n.bd.Decode(reader)
+ if err != nil {
+ return nil, err
+ }
+ if !hasValue.Elm.(bool) {
+ return &FullValue{}, nil
+ }
+ val, err := n.inner.Decode(reader)
+ if err != nil {
+ return nil, err
+ }
+ return val, nil
+}
+
+func (n *nullableDecoder) DecodeTo(reader io.Reader, value *FullValue) error {
+ val, err := n.Decode(reader)
+ if err != nil {
+ return err
+ }
+ value.Elm = val.Elm
+ return nil
+}
+
type iterableEncoder struct {
t reflect.Type
enc ElementEncoder
diff --git a/sdks/go/pkg/beam/core/runtime/exec/coder_test.go b/sdks/go/pkg/beam/core/runtime/exec/coder_test.go
index 02d1f81da7e..7ee13ecaf9e 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/coder_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/coder_test.go
@@ -80,6 +80,12 @@ func TestCoders(t *testing.T) {
}, {
coder: coder.NewPW(coder.NewString(), coder.NewGlobalWindow()),
val: &FullValue{Elm: "myString" /*Windowing info isn't encoded for PW so we can omit it here*/},
+ }, {
+ coder: coder.NewN(coder.NewBytes()),
+ val: &FullValue{},
+ }, {
+ coder: coder.NewN(coder.NewBytes()),
+ val: &FullValue{Elm: []byte("myBytes")},
},
} {
t.Run(fmt.Sprintf("%v", test.coder), func(t *testing.T) {
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/coder.go b/sdks/go/pkg/beam/core/runtime/graphx/coder.go
index a8b89753823..e70567d705b 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/coder.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/coder.go
@@ -46,6 +46,7 @@ const (
urnParamWindowedValueCoder = "beam:coder:param_windowed_value:v1"
urnTimerCoder = "beam:coder:timer:v1"
urnRowCoder = "beam:coder:row:v1"
+ urnNullableCoder = "beam:coder:nullable:v1"
urnGlobalWindow = "beam:coder:global_window:v1"
urnIntervalWindow = "beam:coder:interval_window:v1"
@@ -71,6 +72,7 @@ func knownStandardCoders() []string {
urnGlobalWindow,
urnIntervalWindow,
urnRowCoder,
+ urnNullableCoder,
// TODO(BEAM-10660): Add urnTimerCoder once finalized.
}
}
@@ -368,6 +370,15 @@ func (b *CoderUnmarshaller) makeCoder(id string, c *pipepb.Coder) (*coder.Coder,
return nil, err
}
return coder.NewR(typex.New(t)), nil
+ case urnNullableCoder:
+ if len(components) != 1 {
+ return nil, errors.Errorf("could not unmarshal nullable coder from %v, expected one component but got %d", c, len(components))
+ }
+ elm, err := b.Coder(components[0])
+ if err != nil {
+ return nil, err
+ }
+ return coder.NewN(elm), nil
// Special handling for window coders so they can be treated as
// a general coder. Generally window coders are not used outside of
@@ -386,7 +397,6 @@ func (b *CoderUnmarshaller) makeCoder(id string, c *pipepb.Coder) (*coder.Coder,
return nil, err
}
return &coder.Coder{Kind: coder.Window, T: typex.New(reflect.TypeOf((*struct{})(nil)).Elem()), Window: w}, nil
-
default:
return nil, errors.Errorf("could not unmarshal coder from %v, unknown URN %v", c, urn)
}
@@ -465,6 +475,13 @@ func (b *CoderMarshaller) Add(c *coder.Coder) (string, error) {
}
return b.internBuiltInCoder(urnKVCoder, comp...), nil
+ case coder.Nullable:
+ comp, err := b.AddMulti(c.Components)
+ if err != nil {
+ return "", errors.Wrapf(err, "failed to marshal Nullable coder %v", c)
+ }
+ return b.internBuiltInCoder(urnNullableCoder, comp...), nil
+
case coder.CoGBK:
comp, err := b.AddMulti(c.Components)
if err != nil {
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/coder_test.go b/sdks/go/pkg/beam/core/runtime/graphx/coder_test.go
index 8296c9fd338..aad15df0f23 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/coder_test.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/coder_test.go
@@ -88,6 +88,10 @@ func TestMarshalUnmarshalCoders(t *testing.T) {
"W<bytes>",
coder.NewW(coder.NewBytes(), coder.NewGlobalWindow()),
},
+ {
+ "N<bytes>",
+ coder.NewN(coder.NewBytes()),
+ },
{
"KV<foo,bar>",
coder.NewKV([]*coder.Coder{foo, bar}),
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/dataflow.go b/sdks/go/pkg/beam/core/runtime/graphx/dataflow.go
index 77aa6ca46a5..e2eec3b5bcc 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/dataflow.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/dataflow.go
@@ -48,6 +48,7 @@ const (
doubleType = "kind:double"
streamType = "kind:stream"
pairType = "kind:pair"
+ nullableType = "kind:nullable"
lengthPrefixType = "kind:length_prefix"
rowType = "kind:row"
@@ -117,6 +118,16 @@ func EncodeCoderRef(c *coder.Coder) (*CoderRef, error) {
}
return &CoderRef{Type: pairType, Components: []*CoderRef{key, value}, IsPairLike: true}, nil
+ case coder.Nullable:
+ if len(c.Components) != 1 {
+ return nil, errors.Errorf("bad N: %v", c)
+ }
+ innerref, err := EncodeCoderRef(c.Components[0])
+ if err != nil {
+ return nil, err
+ }
+ return &CoderRef{Type: nullableType, Components: []*CoderRef{innerref}}, nil
+
case coder.CoGBK:
if len(c.Components) < 2 {
return nil, errors.Errorf("bad CoGBK: %v", c)
@@ -264,6 +275,19 @@ func DecodeCoderRef(c *CoderRef) (*coder.Coder, error) {
t := typex.New(root, key.T, value.T)
return &coder.Coder{Kind: kind, T: t, Components: []*coder.Coder{key, value}}, nil
+ case nullableType:
+ if len(c.Components) != 1 {
+ return nil, errors.Errorf("bad nullable: %+v", c)
+ }
+
+ inner, err := DecodeCoderRef(c.Components[0])
+ if err != nil {
+ return nil, err
+ }
+
+ t := typex.New(typex.NullableType, inner.T)
+ return &coder.Coder{Kind: coder.Nullable, T: t, Components: []*coder.Coder{inner}}, nil
+
case lengthPrefixType:
if len(c.Components) != 1 {
return nil, errors.Errorf("bad length prefix: %+v", c)
diff --git a/sdks/go/pkg/beam/core/typex/fulltype.go b/sdks/go/pkg/beam/core/typex/fulltype.go
index cbfb443755b..df5425a4e1a 100644
--- a/sdks/go/pkg/beam/core/typex/fulltype.go
+++ b/sdks/go/pkg/beam/core/typex/fulltype.go
@@ -87,6 +87,8 @@ func printShortComposite(t reflect.Type) string {
return "CoGBK"
case KVType:
return "KV"
+ case NullableType:
+ return "Nullable"
default:
return fmt.Sprintf("invalid(%v)", t)
}
diff --git a/sdks/go/pkg/beam/core/typex/special.go b/sdks/go/pkg/beam/core/typex/special.go
index d13aab562a9..b45cb61081b 100644
--- a/sdks/go/pkg/beam/core/typex/special.go
+++ b/sdks/go/pkg/beam/core/typex/special.go
@@ -38,9 +38,10 @@ var (
WindowType = reflect.TypeOf((*Window)(nil)).Elem()
PaneInfoType = reflect.TypeOf((*PaneInfo)(nil)).Elem()
- KVType = reflect.TypeOf((*KV)(nil)).Elem()
- CoGBKType = reflect.TypeOf((*CoGBK)(nil)).Elem()
- WindowedValueType = reflect.TypeOf((*WindowedValue)(nil)).Elem()
+ KVType = reflect.TypeOf((*KV)(nil)).Elem()
+ NullableType = reflect.TypeOf((*Nullable)(nil)).Elem()
+ CoGBKType = reflect.TypeOf((*CoGBK)(nil)).Elem()
+ WindowedValueType = reflect.TypeOf((*WindowedValue)(nil)).Elem()
BundleFinalizationType = reflect.TypeOf((*BundleFinalization)(nil)).Elem()
)
@@ -92,6 +93,8 @@ type PaneInfo struct {
type KV struct{}
+type Nullable struct{}
+
type CoGBK struct{}
type WindowedValue struct{}
diff --git a/sdks/go/test/regression/coders/fromyaml/fromyaml.go b/sdks/go/test/regression/coders/fromyaml/fromyaml.go
index 82d3e9fdb24..199ff4e2a91 100644
--- a/sdks/go/test/regression/coders/fromyaml/fromyaml.go
+++ b/sdks/go/test/regression/coders/fromyaml/fromyaml.go
@@ -218,6 +218,19 @@ func diff(c Coder, elem *exec.FullValue, eg yaml.MapItem) bool {
}
return pass
+ case "beam:coder:nullable:v1":
+ if elem.Elm == nil || eg.Value == nil {
+ got, want = elem.Elm, eg.Value
+ } else {
+ got = string(elem.Elm.([]byte))
+ switch egv := eg.Value.(type) {
+ case string:
+ want = egv
+ case []byte:
+ want = string(egv)
+ }
+ }
+
case "beam:coder:iterable:v1":
pass := true
gotrv := reflect.ValueOf(elem.Elm)
diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py
index 63add754d0f..fce397df626 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -607,6 +607,22 @@ class NullableCoder(FastCoder):
def to_type_hint(self):
return typehints.Optional[self._value_coder.to_type_hint()]
+ def _get_component_coders(self):
+ # type: () -> List[Coder]
+ return [self._value_coder]
+
+ @classmethod
+ def from_type_hint(cls, typehint, registry):
+ if typehints.is_nullable(typehint):
+ return cls(
+ registry.get_coder(
+ typehints.get_concrete_type_from_nullable(typehint)))
+ else:
+ raise TypeError(
+ 'Typehint is not of nullable type, '
+ 'and cannot be converted to a NullableCoder',
+ typehint)
+
def is_deterministic(self):
# type: () -> bool
return self._value_coder.is_deterministic()
@@ -619,6 +635,9 @@ class NullableCoder(FastCoder):
return hash(type(self)) + hash(self._value_coder)
+Coder.register_structured_urn(common_urns.coders.NULLABLE.urn, NullableCoder)
+
+
class VarIntCoder(FastCoder):
"""Variable-length integer coder."""
def _create_impl(self):
@@ -1524,7 +1543,6 @@ Coder.register_structured_urn(
class StateBackedIterableCoder(FastCoder):
-
DEFAULT_WRITE_THRESHOLD = 1
def __init__(
diff --git a/sdks/python/apache_beam/coders/standard_coders_test.py b/sdks/python/apache_beam/coders/standard_coders_test.py
index aa925a3146f..e25e232597e 100644
--- a/sdks/python/apache_beam/coders/standard_coders_test.py
+++ b/sdks/python/apache_beam/coders/standard_coders_test.py
@@ -193,7 +193,9 @@ class StandardCodersTest(unittest.TestCase):
value_parser: ShardedKey(
key=value_parser(x['key']), shard_id=x['shardId'].encode('utf-8')),
'beam:coder:custom_window:v1': lambda x,
- window_parser: window_parser(x['window'])
+ window_parser: window_parser(x['window']),
+ 'beam:coder:nullable:v1': lambda x,
+ value_parser: x.encode('utf-8') if x else None
}
def test_standard_coders(self):
diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py
index 3bf2d15e787..a66ebe52369 100644
--- a/sdks/python/apache_beam/coders/typecoders.py
+++ b/sdks/python/apache_beam/coders/typecoders.py
@@ -65,7 +65,6 @@ See apache_beam.typehints.decorators module for more details.
"""
# pytype: skip-file
-
from typing import Any
from typing import Dict
from typing import Iterable
@@ -138,6 +137,8 @@ class CoderRegistry(object):
return coders.IterableCoder.from_type_hint(typehint, self)
elif isinstance(typehint, typehints.ListConstraint):
return coders.ListCoder.from_type_hint(typehint, self)
+ elif typehints.is_nullable(typehint):
+ return coders.NullableCoder.from_type_hint(typehint, self)
elif typehint is None:
# In some old code, None is used for Any.
# TODO(robertwb): Clean this up.
diff --git a/sdks/python/apache_beam/coders/typecoders_test.py b/sdks/python/apache_beam/coders/typecoders_test.py
index 02f4565c5e2..f74483ad48d 100644
--- a/sdks/python/apache_beam/coders/typecoders_test.py
+++ b/sdks/python/apache_beam/coders/typecoders_test.py
@@ -140,6 +140,13 @@ class TypeCodersTest(unittest.TestCase):
self.assertIs(
list, type(expected_coder.decode(expected_coder.encode(values))))
+ def test_nullable_coder(self):
+ expected_coder = coders.NullableCoder(coders.BytesCoder())
+ real_coder = typecoders.registry.get_coder(typehints.Optional(bytes))
+ self.assertEqual(expected_coder, real_coder)
+ self.assertEqual(expected_coder.encode(None), real_coder.encode(None))
+ self.assertEqual(expected_coder.encode(b'abc'), real_coder.encode(b'abc'))
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py
index 9b4fc95a7d9..45c2366dd8b 100644
--- a/sdks/python/apache_beam/typehints/typehints.py
+++ b/sdks/python/apache_beam/typehints/typehints.py
@@ -503,10 +503,13 @@ class UnionHint(CompositeTypeHint):
return 'Union[%s]' % (
', '.join(sorted(_unified_repr(t) for t in self.union_types)))
- def _inner_types(self):
+ def inner_types(self):
for t in self.union_types:
yield t
+ def contains_type(self, maybe_type):
+ return maybe_type in self.union_types
+
def _consistent_with_check_(self, sub):
if isinstance(sub, UnionConstraint):
# A union type is compatible if every possible type is compatible.
@@ -601,6 +604,22 @@ class OptionalHint(UnionHint):
return Union[py_type, type(None)]
+def is_nullable(typehint):
+ return (
+ isinstance(typehint, UnionConstraint) and
+ typehint.contains_type(type(None)) and
+ len(list(typehint.inner_types())) == 2)
+
+
+def get_concrete_type_from_nullable(typehint):
+ if is_nullable(typehint):
+ for inner_type in typehint.inner_types():
+ if not type(None) == inner_type:
+ return inner_type
+ else:
+ raise TypeError('Typehint is not of nullable type', typehint)
+
+
class TupleHint(CompositeTypeHint):
"""A Tuple type-hint.
diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py
index b3a4d636e9b..8818639035f 100644
--- a/sdks/python/apache_beam/typehints/typehints_test.py
+++ b/sdks/python/apache_beam/typehints/typehints_test.py
@@ -320,6 +320,14 @@ class OptionalHintTestCase(TypeHintTestCase):
hint = typehints.Optional[int]
self.assertTrue(isinstance(hint, typehints.UnionHint.UnionConstraint))
+ def test_is_optional(self):
+ hint1 = typehints.Optional[int]
+ self.assertTrue(typehints.is_nullable(hint1))
+ hint2 = typehints.UnionConstraint({int, bytes})
+ self.assertFalse(typehints.is_nullable(hint2))
+ hint3 = typehints.UnionConstraint({int, bytes, type(None)})
+ self.assertFalse(typehints.is_nullable(hint3))
+
class TupleHintTestCase(TypeHintTestCase):
def test_getitem_invalid_ellipsis_type_param(self):