You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lo...@apache.org on 2020/12/28 22:24:14 UTC
[beam] branch master updated: [BEAM-9615] Initial Custom Schema
Coder Support (#13611)
This is an automated email from the ASF dual-hosted git repository.
lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 2ceef2d [BEAM-9615] Initial Custom Schema Coder Support (#13611)
2ceef2d is described below
commit 2ceef2d8955017ac4d5b3f0a069811f05f7131f2
Author: Robert Burke <lo...@users.noreply.github.com>
AuthorDate: Mon Dec 28 14:23:30 2020 -0800
[BEAM-9615] Initial Custom Schema Coder Support (#13611)
---
sdks/go/pkg/beam/core/graph/coder/map_test.go | 11 +-
sdks/go/pkg/beam/core/graph/coder/row.go | 337 +++-----------
sdks/go/pkg/beam/core/graph/coder/row_decoder.go | 308 +++++++++++++
sdks/go/pkg/beam/core/graph/coder/row_encoder.go | 271 ++++++++++++
sdks/go/pkg/beam/core/graph/coder/row_test.go | 488 ++++++++++++++++-----
.../pkg/beam/core/graph/coder/testutil/testutil.go | 154 +++++++
.../core/graph/coder/testutil/testutil_test.go | 201 +++++++++
7 files changed, 1379 insertions(+), 391 deletions(-)
diff --git a/sdks/go/pkg/beam/core/graph/coder/map_test.go b/sdks/go/pkg/beam/core/graph/coder/map_test.go
index 0b825c2..b3441af 100644
--- a/sdks/go/pkg/beam/core/graph/coder/map_test.go
+++ b/sdks/go/pkg/beam/core/graph/coder/map_test.go
@@ -22,15 +22,16 @@ import (
"reflect"
"testing"
- "github.com/apache/beam/sdks/go/pkg/beam/core/util/reflectx"
"github.com/google/go-cmp/cmp"
)
func TestEncodeDecodeMap(t *testing.T) {
- byteEnc := containerEncoderForType(reflectx.Uint8)
- byteDec := containerDecoderForType(reflectx.Uint8)
- bytePtrEnc := containerEncoderForType(reflect.PtrTo(reflectx.Uint8))
- bytePtrDec := containerDecoderForType(reflect.PtrTo(reflectx.Uint8))
+ byteEnc := func(v reflect.Value, w io.Writer) error {
+ return EncodeByte(byte(v.Uint()), w)
+ }
+ byteDec := reflectDecodeByte
+ bytePtrEnc := containerNilEncoder(byteEnc)
+ bytePtrDec := containerNilDecoder(byteDec)
ptrByte := byte(42)
diff --git a/sdks/go/pkg/beam/core/graph/coder/row.go b/sdks/go/pkg/beam/core/graph/coder/row.go
index 714e150..6ee1b8b 100644
--- a/sdks/go/pkg/beam/core/graph/coder/row.go
+++ b/sdks/go/pkg/beam/core/graph/coder/row.go
@@ -30,10 +30,7 @@ import (
// Returns an error if the given type is invalid or not encodable to a beam
// schema row.
func RowEncoderForStruct(rt reflect.Type) (func(interface{}, io.Writer) error, error) {
- if err := rowTypeValidation(rt, true); err != nil {
- return nil, err
- }
- return encoderForType(rt), nil
+ return (&RowEncoderBuilder{}).Build(rt)
}
// RowDecoderForStruct returns a decoding function that decodes the beam row encoding
@@ -42,10 +39,7 @@ func RowEncoderForStruct(rt reflect.Type) (func(interface{}, io.Writer) error, e
// Returns an error if the given type is invalid or not decodable from a beam
// schema row.
func RowDecoderForStruct(rt reflect.Type) (func(io.Reader) (interface{}, error), error) {
- if err := rowTypeValidation(rt, true); err != nil {
- return nil, err
- }
- return decoderForType(rt), nil
+ return (&RowDecoderBuilder{}).Build(rt)
}
func rowTypeValidation(rt reflect.Type, strictExportedFields bool) error {
@@ -58,277 +52,6 @@ func rowTypeValidation(rt reflect.Type, strictExportedFields bool) error {
return nil
}
-// decoderForType returns a decoder function for the struct or pointer to struct type.
-func decoderForType(t reflect.Type) func(io.Reader) (interface{}, error) {
- var isPtr bool
- // Pointers become the value type for decomposition.
- if t.Kind() == reflect.Ptr {
- isPtr = true
- t = t.Elem()
- }
- dec := decoderForStructReflect(t)
-
- if isPtr {
- return func(r io.Reader) (interface{}, error) {
- rv := reflect.New(t)
- err := dec(rv.Elem(), r)
- return rv.Interface(), err
- }
- }
- return func(r io.Reader) (interface{}, error) {
- rv := reflect.New(t)
- err := dec(rv.Elem(), r)
- return rv.Elem().Interface(), err
- }
-}
-
-// decoderForSingleTypeReflect returns a reflection based decoder function for the
-// given type.
-func decoderForSingleTypeReflect(t reflect.Type) func(reflect.Value, io.Reader) error {
- switch t.Kind() {
- case reflect.Struct:
- return decoderForStructReflect(t)
- case reflect.Bool:
- return func(rv reflect.Value, r io.Reader) error {
- v, err := DecodeBool(r)
- if err != nil {
- return errors.Wrap(err, "error decoding bool field")
- }
- rv.SetBool(v)
- return nil
- }
- case reflect.Uint8:
- return func(rv reflect.Value, r io.Reader) error {
- b, err := DecodeByte(r)
- if err != nil {
- return errors.Wrap(err, "error decoding single byte field")
- }
- rv.SetUint(uint64(b))
- return nil
- }
- case reflect.String:
- return func(rv reflect.Value, r io.Reader) error {
- v, err := DecodeStringUTF8(r)
- if err != nil {
- return errors.Wrap(err, "error decoding string field")
- }
- rv.SetString(v)
- return nil
- }
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- return func(rv reflect.Value, r io.Reader) error {
- v, err := DecodeVarInt(r)
- if err != nil {
- return errors.Wrap(err, "error decoding varint field")
- }
- rv.SetInt(v)
- return nil
- }
- case reflect.Float32, reflect.Float64:
- return func(rv reflect.Value, r io.Reader) error {
- v, err := DecodeDouble(r)
- if err != nil {
- return errors.Wrap(err, "error decoding double field")
- }
- rv.SetFloat(v)
- return nil
- }
- case reflect.Ptr:
- decf := decoderForSingleTypeReflect(t.Elem())
- return func(rv reflect.Value, r io.Reader) error {
- nv := reflect.New(t.Elem())
- rv.Set(nv)
- return decf(nv.Elem(), r)
- }
- case reflect.Slice:
- // Special case handling for byte slices.
- if t.Elem().Kind() == reflect.Uint8 {
- return func(rv reflect.Value, r io.Reader) error {
- b, err := DecodeBytes(r)
- if err != nil {
- return errors.Wrap(err, "error decoding []byte field")
- }
- rv.SetBytes(b)
- return nil
- }
- }
- decf := containerDecoderForType(t.Elem())
- return iterableDecoderForSlice(t, decf)
- case reflect.Array:
- decf := containerDecoderForType(t.Elem())
- return iterableDecoderForArray(t, decf)
- case reflect.Map:
- decK := containerDecoderForType(t.Key())
- decV := containerDecoderForType(t.Elem())
- return mapDecoder(t, decK, decV)
- }
- panic(fmt.Sprintf("unimplemented type to decode: %v", t))
-}
-
-func containerDecoderForType(t reflect.Type) func(reflect.Value, io.Reader) error {
- if t.Kind() == reflect.Ptr {
- return containerNilDecoder(decoderForSingleTypeReflect(t.Elem()))
- }
- return decoderForSingleTypeReflect(t)
-}
-
-type typeDecoderReflect struct {
- typ reflect.Type
- fields []func(reflect.Value, io.Reader) error
-}
-
-// decoderForStructReflect returns a reflection based decoder function for the
-// given struct type.
-func decoderForStructReflect(t reflect.Type) func(reflect.Value, io.Reader) error {
- var coder typeDecoderReflect
- for i := 0; i < t.NumField(); i++ {
- i := i // avoid alias issues in the closures.
- dec := decoderForSingleTypeReflect(t.Field(i).Type)
- coder.fields = append(coder.fields, func(rv reflect.Value, r io.Reader) error {
- return dec(rv.Field(i), r)
- })
- }
-
- return func(rv reflect.Value, r io.Reader) error {
- nf, nils, err := readRowHeader(rv, r)
- if err != nil {
- return err
- }
- if nf != len(coder.fields) {
- return errors.Errorf("schema[%v] changed: got %d fields, want %d fields", "TODO", nf, len(coder.fields))
- }
- for i, f := range coder.fields {
- if isFieldNil(nils, i) {
- continue
- }
- if err := f(rv, r); err != nil {
- return err
- }
- }
- return nil
- }
-}
-
-// isFieldNil examines the passed in packed bits nils buffer
-// and returns true if the field at that index wasn't encoded
-// and can be skipped in decoding.
-func isFieldNil(nils []byte, f int) bool {
- i, b := f/8, f%8
- return len(nils) != 0 && (nils[i]>>uint8(b))&0x1 == 1
-}
-
-// encoderForType returns an encoder function for the struct or pointer to struct type.
-func encoderForType(t reflect.Type) func(interface{}, io.Writer) error {
- var isPtr bool
- // Pointers become the value type for decomposition.
- if t.Kind() == reflect.Ptr {
- isPtr = true
- t = t.Elem()
- }
- enc := encoderForStructReflect(t)
-
- if isPtr {
- return func(v interface{}, w io.Writer) error {
- return enc(reflect.ValueOf(v).Elem(), w)
- }
- }
- return func(v interface{}, w io.Writer) error {
- return enc(reflect.ValueOf(v), w)
- }
-}
-
-// Generates coder using reflection for
-func encoderForSingleTypeReflect(t reflect.Type) func(reflect.Value, io.Writer) error {
- switch t.Kind() {
- case reflect.Struct:
- return encoderForStructReflect(t)
- case reflect.Bool:
- return func(rv reflect.Value, w io.Writer) error {
- return EncodeBool(rv.Bool(), w)
- }
- case reflect.Uint8:
- return func(rv reflect.Value, w io.Writer) error {
- return EncodeByte(byte(rv.Uint()), w)
- }
- case reflect.String:
- return func(rv reflect.Value, w io.Writer) error {
- return EncodeStringUTF8(rv.String(), w)
- }
- case reflect.Int, reflect.Int64, reflect.Int16, reflect.Int32, reflect.Int8:
- return func(rv reflect.Value, w io.Writer) error {
- return EncodeVarInt(int64(rv.Int()), w)
- }
- case reflect.Float32, reflect.Float64:
- return func(rv reflect.Value, w io.Writer) error {
- return EncodeDouble(float64(rv.Float()), w)
- }
- case reflect.Ptr:
- // Nils are handled at the struct field level.
- encf := encoderForSingleTypeReflect(t.Elem())
- return func(rv reflect.Value, w io.Writer) error {
- return encf(rv.Elem(), w)
- }
- case reflect.Slice:
- // Special case handling for byte slices.
- if t.Elem().Kind() == reflect.Uint8 {
- return func(rv reflect.Value, w io.Writer) error {
- return EncodeBytes(rv.Bytes(), w)
- }
- }
- encf := containerEncoderForType(t.Elem())
- return iterableEncoder(t, encf)
- case reflect.Array:
- encf := containerEncoderForType(t.Elem())
- return iterableEncoder(t, encf)
- case reflect.Map:
- encK := containerEncoderForType(t.Key())
- encV := containerEncoderForType(t.Elem())
- return mapEncoder(t, encK, encV)
- }
- panic(fmt.Sprintf("unimplemented type to encode: %v", t))
-}
-
-func containerEncoderForType(t reflect.Type) func(reflect.Value, io.Writer) error {
- if t.Kind() == reflect.Ptr {
- return containerNilEncoder(encoderForSingleTypeReflect(t.Elem()))
- }
- return encoderForSingleTypeReflect(t)
-}
-
-type typeEncoderReflect struct {
- debug []string
- fields []func(reflect.Value, io.Writer) error
-}
-
-// encoderForStructReflect generates reflection field access closures for structs.
-func encoderForStructReflect(t reflect.Type) func(reflect.Value, io.Writer) error {
- var coder typeEncoderReflect
- for i := 0; i < t.NumField(); i++ {
- coder.debug = append(coder.debug, t.Field(i).Type.Name())
- coder.fields = append(coder.fields, encoderForSingleTypeReflect(t.Field(i).Type))
- }
-
- return func(rv reflect.Value, w io.Writer) error {
- // Row/Structs are prefixed with the number of fields that are encoded in total.
- if err := writeRowHeader(rv, w); err != nil {
- return err
- }
- for i, f := range coder.fields {
- rvf := rv.Field(i)
- switch rvf.Kind() {
- case reflect.Ptr, reflect.Map, reflect.Slice:
- if rvf.IsNil() {
- continue
- }
- }
- if err := f(rvf, w); err != nil {
- return errors.Wrapf(err, "encoding %v, expected: %v", rvf.Type(), coder.debug[i])
- }
- }
- return nil
- }
-}
-
// writeRowHeader handles the field header for row encodings.
func writeRowHeader(rv reflect.Value, w io.Writer) error {
// Row/Structs are prefixed with the number of fields that are encoded in total.
@@ -370,11 +93,15 @@ func writeRowHeader(rv reflect.Value, w io.Writer) error {
return nil
}
-// readRowHeader handles the field header for row decodings.
+// ReadRowHeader handles the field header for row decodings.
+//
+// This returns the number of encoded fields, the raw bitpacked bytes and
+// any error during decoding. Each bit only needs to be
+// examined once during decoding using the IsFieldNil helper function.
//
-// This returns the raw bitpacked byte slice because we only need to
-// examine each bit once, so we may as well do so inline with field checking.
-func readRowHeader(rv reflect.Value, r io.Reader) (int, []byte, error) {
+// If there are no nil fields encoded, the byte slice will be nil, and no
+// encoded fields will be nil.
+func ReadRowHeader(r io.Reader) (int, []byte, error) {
nf, err := DecodeVarInt(r) // is for checksum purposes (old vs new versions of a schemas)
if err != nil {
return 0, nil, err
@@ -394,3 +121,47 @@ func readRowHeader(rv reflect.Value, r io.Reader) (int, []byte, error) {
}
return int(nf), nils, nil
}
+
+// IsFieldNil examines the passed in packed bits nils buffer
+// and returns true if the field at that index wasn't encoded
+// and can be skipped in decoding.
+func IsFieldNil(nils []byte, f int) bool {
+ i, b := f/8, f%8
+ return len(nils) != 0 && (nils[i]>>uint8(b))&0x1 == 1
+}
+
+// WriteSimpleRowHeader is a convenience function to write Beam Schema Row Headers
+// for values that do not have any nil fields. Writes the number of fields total
+// and a 0 len byte slice to indicate no fields are nil.
+func WriteSimpleRowHeader(fields int, w io.Writer) error {
+ if err := EncodeVarInt(int64(fields), w); err != nil {
+ return err
+ }
+ // Never nils, so we write the 0 byte header.
+ if err := EncodeVarInt(0, w); err != nil {
+ return fmt.Errorf("WriteSimpleRowHeader a 0 length nils bit field: %v", err)
+ }
+ return nil
+}
+
+// ReadSimpleRowHeader is a convenience function to read Beam Schema Row Headers
+// for values that do not have any nil fields. Reads and validates the number of
+// fields total (returning an error for mismatches), and checks that there are
+// no nils encoded as a bit field.
+func ReadSimpleRowHeader(fields int, r io.Reader) error {
+ n, err := DecodeVarInt(r)
+ if err != nil {
+ return fmt.Errorf("ReadSimpleRowHeader field count: %v, %v", n, err)
+ }
+ if int(n) != fields {
+ return fmt.Errorf("ReadSimpleRowHeader field count mismatch, got %v, want %v", n, fields)
+ }
+ n, err = DecodeVarInt(r)
+ if err != nil {
+ return fmt.Errorf("ReadSimpleRowHeader reading nils count: %v, %v", n, err)
+ }
+ if n != 0 {
+ return fmt.Errorf("ReadSimpleRowHeader expected no nils encoded count, got %v", n)
+ }
+ return nil
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/row_decoder.go b/sdks/go/pkg/beam/core/graph/coder/row_decoder.go
new file mode 100644
index 0000000..c950013
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/row_decoder.go
@@ -0,0 +1,308 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+
+ "github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
+)
+
+// RowDecoderBuilder allows one to build Beam Schema row decoders for provided types.
+type RowDecoderBuilder struct {
+ allFuncs map[reflect.Type]decoderProvider
+ ifaceFuncs []reflect.Type
+}
+
+type decoderProvider = func(reflect.Type) (func(io.Reader) (interface{}, error), error)
+
+// Register accepts a provider for the given type to decode schema encoded values
+// of that type.
+//
+// When decoding values, decoder functions produced by this builder will
+// first check for exact type matches, then interfaces implemented by
+// the type in recency order of registration, and then finally the
+// default Beam Schema encoding behavior.
+//
+// TODO(BEAM-9615): Add final factory types. This interface is subject to change.
+// Currently f must be a function func(reflect.Type) (func(io.Reader) (interface{}, error), error)
+func (b *RowDecoderBuilder) Register(rt reflect.Type, f interface{}) {
+ fd, ok := f.(decoderProvider)
+ if !ok {
+ panic(fmt.Sprintf("%T isn't a supported decoder function type (passed with %v)", f, rt))
+ }
+
+ if rt.Kind() == reflect.Interface && rt.NumMethod() == 0 {
+ panic(fmt.Sprintf("interface type %v must have methods", rt))
+ }
+
+ if b.allFuncs == nil {
+ b.allFuncs = make(map[reflect.Type]decoderProvider)
+ }
+ b.allFuncs[rt] = fd
+ if rt.Kind() == reflect.Interface {
+ b.ifaceFuncs = append(b.ifaceFuncs, rt)
+ }
+}
+
+// Build constructs a Beam Schema coder for the given type, using any providers registered for
+// itself or its fields.
+func (b *RowDecoderBuilder) Build(rt reflect.Type) (func(io.Reader) (interface{}, error), error) {
+ if err := rowTypeValidation(rt, true); err != nil {
+ return nil, err
+ }
+ return b.decoderForType(rt)
+}
+
+// decoderForType returns a decoder function for the struct or pointer to struct type.
+func (b *RowDecoderBuilder) decoderForType(t reflect.Type) (func(io.Reader) (interface{}, error), error) {
+ // Check if there are any providers registered for this type, or that this type adheres to any interfaces.
+ f, err := b.customFunc(t)
+ if err != nil {
+ return nil, err
+ }
+ if f != nil {
+ return f, nil
+ }
+
+ var isPtr bool
+ // Pointers become the value type for decomposition.
+ if t.Kind() == reflect.Ptr {
+ isPtr = true
+ t = t.Elem()
+ }
+ dec, err := b.decoderForStructReflect(t)
+ if err != nil {
+ return nil, err
+ }
+
+ if isPtr {
+ return func(r io.Reader) (interface{}, error) {
+ rv := reflect.New(t)
+ err := dec(rv.Elem(), r)
+ return rv.Interface(), err
+ }, nil
+ }
+ return func(r io.Reader) (interface{}, error) {
+ rv := reflect.New(t)
+ err := dec(rv.Elem(), r)
+ return rv.Elem().Interface(), err
+ }, nil
+}
+
+// decoderForStructReflect returns a reflection based decoder function for the
+// given struct type.
+func (b *RowDecoderBuilder) decoderForStructReflect(t reflect.Type) (func(reflect.Value, io.Reader) error, error) {
+ var coder typeDecoderReflect
+ for i := 0; i < t.NumField(); i++ {
+ i := i // avoid alias issues in the closures.
+ dec, err := b.decoderForSingleTypeReflect(t.Field(i).Type)
+ if err != nil {
+ return nil, err
+ }
+ coder.fields = append(coder.fields, func(rv reflect.Value, r io.Reader) error {
+ return dec(rv.Field(i), r)
+ })
+ }
+ return func(rv reflect.Value, r io.Reader) error {
+ nf, nils, err := ReadRowHeader(r)
+ if err != nil {
+ return err
+ }
+ if nf != len(coder.fields) {
+ return errors.Errorf("schema[%v] changed: got %d fields, want %d fields", "TODO", nf, len(coder.fields))
+ }
+ for i, f := range coder.fields {
+ if IsFieldNil(nils, i) {
+ continue
+ }
+ if err := f(rv, r); err != nil {
+ return err
+ }
+ }
+ return nil
+ }, nil
+}
+
+func reflectDecodeBool(rv reflect.Value, r io.Reader) error {
+ v, err := DecodeBool(r)
+ if err != nil {
+ return errors.Wrap(err, "error decoding bool field")
+ }
+ rv.SetBool(v)
+ return nil
+}
+
+func reflectDecodeByte(rv reflect.Value, r io.Reader) error {
+ b, err := DecodeByte(r)
+ if err != nil {
+ return errors.Wrap(err, "error decoding single byte field")
+ }
+ rv.SetUint(uint64(b))
+ return nil
+}
+
+func reflectDecodeString(rv reflect.Value, r io.Reader) error {
+ v, err := DecodeStringUTF8(r)
+ if err != nil {
+ return errors.Wrap(err, "error decoding string field")
+ }
+ rv.SetString(v)
+ return nil
+}
+
+func reflectDecodeInt(rv reflect.Value, r io.Reader) error {
+ v, err := DecodeVarInt(r)
+ if err != nil {
+ return errors.Wrap(err, "error decoding varint field")
+ }
+ rv.SetInt(v)
+ return nil
+}
+
+func reflectDecodeFloat(rv reflect.Value, r io.Reader) error {
+ v, err := DecodeDouble(r)
+ if err != nil {
+ return errors.Wrap(err, "error decoding double field")
+ }
+ rv.SetFloat(v)
+ return nil
+}
+
+func reflectDecodeByteSlice(rv reflect.Value, r io.Reader) error {
+ b, err := DecodeBytes(r)
+ if err != nil {
+ return errors.Wrap(err, "error decoding []byte field")
+ }
+ rv.SetBytes(b)
+ return nil
+}
+
+// customFunc returns nil if no custom func exists for this type.
+// If an error is returned, coder construction should be aborted.
+func (b *RowDecoderBuilder) customFunc(t reflect.Type) (func(io.Reader) (interface{}, error), error) {
+ if fact, ok := b.allFuncs[t]; ok {
+ f, err := fact(t)
+
+ if err != nil {
+ return nil, err
+ }
+ return f, nil
+ }
+ // Check satisfaction of interface types in reverse registration order.
+ for i := len(b.ifaceFuncs) - 1; i >= 0; i-- {
+ it := b.ifaceFuncs[i]
+ if ok := t.AssignableTo(it); ok {
+ if fact, ok := b.allFuncs[it]; ok {
+ f, err := fact(t)
+ if err != nil {
+ return nil, err
+ }
+ return f, nil
+ }
+ }
+ }
+ return nil, nil
+}
+
+// decoderForSingleTypeReflect returns a reflection based decoder function for the
+// given type.
+func (b *RowDecoderBuilder) decoderForSingleTypeReflect(t reflect.Type) (func(reflect.Value, io.Reader) error, error) {
+ // Check if there are any providers registered for this type, or that this type adheres to any interfaces.
+ dec, err := b.customFunc(t)
+ if err != nil {
+ return nil, err
+ }
+ if dec != nil {
+ return func(v reflect.Value, r io.Reader) error {
+ elm, err := dec(r)
+ if err != nil {
+ return err
+ }
+ v.Set(reflect.ValueOf(elm))
+ return nil
+ }, nil
+ }
+ switch t.Kind() {
+ case reflect.Struct:
+ return b.decoderForStructReflect(t)
+ case reflect.Bool:
+ return reflectDecodeBool, nil
+ case reflect.Uint8:
+ return reflectDecodeByte, nil
+ case reflect.String:
+ return reflectDecodeString, nil
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return reflectDecodeInt, nil
+ case reflect.Float32, reflect.Float64:
+ return reflectDecodeFloat, nil
+ case reflect.Ptr:
+ decf, err := b.decoderForSingleTypeReflect(t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ return func(rv reflect.Value, r io.Reader) error {
+ nv := reflect.New(t.Elem())
+ rv.Set(nv)
+ return decf(nv.Elem(), r)
+ }, nil
+ case reflect.Slice:
+ // Special case handling for byte slices.
+ if t.Elem().Kind() == reflect.Uint8 {
+ return reflectDecodeByteSlice, nil
+ }
+ decf, err := b.containerDecoderForType(t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ return iterableDecoderForSlice(t, decf), nil
+ case reflect.Array:
+ decf, err := b.containerDecoderForType(t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ return iterableDecoderForArray(t, decf), nil
+ case reflect.Map:
+ decK, err := b.containerDecoderForType(t.Key())
+ if err != nil {
+ return nil, err
+ }
+ decV, err := b.containerDecoderForType(t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ return mapDecoder(t, decK, decV), nil
+ }
+ panic(fmt.Sprintf("unimplemented type to decode: %v", t))
+}
+
+func (b *RowDecoderBuilder) containerDecoderForType(t reflect.Type) (func(reflect.Value, io.Reader) error, error) {
+ if t.Kind() == reflect.Ptr {
+ dec, err := b.decoderForSingleTypeReflect(t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ return containerNilDecoder(dec), nil
+ }
+ return b.decoderForSingleTypeReflect(t)
+}
+
+type typeDecoderReflect struct {
+ typ reflect.Type
+ fields []func(reflect.Value, io.Reader) error
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/row_encoder.go b/sdks/go/pkg/beam/core/graph/coder/row_encoder.go
new file mode 100644
index 0000000..bfb872d
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/row_encoder.go
@@ -0,0 +1,271 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+
+ "github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
+)
+
+// RowEncoderBuilder allows one to build Beam Schema row encoders for provided types.
+type RowEncoderBuilder struct {
+ allFuncs map[reflect.Type]encoderProvider
+ ifaceFuncs []reflect.Type
+}
+
+type encoderProvider = func(reflect.Type) (func(interface{}, io.Writer) error, error)
+
+// Register accepts a provider for the given type to schema encode values of that type.
+//
+// When generating encoding functions, this builder will first check for exact type
+// matches, then against interfaces with registered factories in recency order of
+// registration, and then finally use the default Beam Schema encoding behavior.
+//
+// TODO(BEAM-9615): Add final factory types. This interface is subject to change.
+// Currently f must be a function func(reflect.Type) (func(interface{}, io.Writer) error, error)
+func (b *RowEncoderBuilder) Register(rt reflect.Type, f interface{}) {
+ fe, ok := f.(encoderProvider)
+ if !ok {
+ panic(fmt.Sprintf("%T isn't a supported encoder function type (passed with %v)", f, rt))
+ }
+
+ if rt.Kind() == reflect.Interface && rt.NumMethod() == 0 {
+ panic(fmt.Sprintf("interface type %v must have methods", rt))
+ }
+ if b.allFuncs == nil {
+ b.allFuncs = make(map[reflect.Type]encoderProvider)
+ }
+ b.allFuncs[rt] = fe
+ if rt.Kind() == reflect.Interface {
+ b.ifaceFuncs = append(b.ifaceFuncs, rt)
+ }
+}
+
+// Build constructs a Beam Schema coder for the given type, using any providers registered for
+// itself or its fields.
+func (b *RowEncoderBuilder) Build(rt reflect.Type) (func(interface{}, io.Writer) error, error) {
+ if err := rowTypeValidation(rt, true); err != nil {
+ return nil, err
+ }
+ return b.encoderForType(rt)
+}
+
+// customFunc returns nil if no custom func exists for this type.
+// If an error is returned, coder construction should be aborted.
+func (b *RowEncoderBuilder) customFunc(t reflect.Type) (func(interface{}, io.Writer) error, error) {
+ if fact, ok := b.allFuncs[t]; ok {
+ f, err := fact(t)
+
+ if err != nil {
+ return nil, err
+ }
+ return f, err
+ }
+ // Check satisfaction of interface types in reverse registration order.
+ for i := len(b.ifaceFuncs) - 1; i >= 0; i-- {
+ it := b.ifaceFuncs[i]
+ if ok := t.AssignableTo(it); ok {
+ if fact, ok := b.allFuncs[it]; ok {
+ f, err := fact(t)
+ // A provider error aborts coder construction.
+ if err != nil {
+ return nil, err
+ }
+ return f, nil
+ }
+ }
+ }
+ return nil, nil
+}
+
+// encoderForType returns an encoder function for the struct or pointer to struct type.
+func (b *RowEncoderBuilder) encoderForType(t reflect.Type) (func(interface{}, io.Writer) error, error) {
+ // Check if there are any providers registered for this type, or that this type adheres to any interfaces.
+ var isPtr bool
+ // Pointers become the value type for decomposition.
+ if t.Kind() == reflect.Ptr {
+ // If we have something for the pointer version already, we're done.
+ enc, err := b.customFunc(t)
+ if err != nil {
+ return nil, err
+ }
+ if enc != nil {
+ return enc, nil
+ }
+ isPtr = true
+ t = t.Elem()
+ }
+
+ {
+ enc, err := b.customFunc(t)
+ if err != nil {
+ return nil, err
+ }
+ if enc != nil {
+ if isPtr {
+ // We have the value version, but not a pointer version, so we jump through reflect to
+ // get the right type to pass in.
+ return func(v interface{}, w io.Writer) error {
+ return enc(reflect.ValueOf(v).Elem().Interface(), w)
+ }, nil
+ }
+ return enc, nil
+ }
+ }
+
+ enc, err := b.encoderForStructReflect(t)
+ if err != nil {
+ return nil, err
+ }
+
+ if isPtr {
+ return func(v interface{}, w io.Writer) error {
+ return enc(reflect.ValueOf(v).Elem(), w)
+ }, nil
+ }
+ return func(v interface{}, w io.Writer) error {
+ return enc(reflect.ValueOf(v), w)
+ }, nil
+}
+
+// encoderForSingleTypeReflect returns a reflection based encoder function for the given type.
+func (b *RowEncoderBuilder) encoderForSingleTypeReflect(t reflect.Type) (func(reflect.Value, io.Writer) error, error) {
+ // Check if there are any providers registered for this type, or that this type adheres to any interfaces.
+ enc, err := b.customFunc(t)
+ if err != nil {
+ return nil, err
+ }
+ if enc != nil {
+ return func(v reflect.Value, w io.Writer) error {
+ return enc(v.Interface(), w)
+ }, nil
+ }
+
+ switch t.Kind() {
+ case reflect.Struct:
+ return b.encoderForStructReflect(t)
+ case reflect.Bool:
+ return func(rv reflect.Value, w io.Writer) error {
+ return EncodeBool(rv.Bool(), w)
+ }, nil
+ case reflect.Uint8:
+ return func(rv reflect.Value, w io.Writer) error {
+ return EncodeByte(byte(rv.Uint()), w)
+ }, nil
+ case reflect.String:
+ return func(rv reflect.Value, w io.Writer) error {
+ return EncodeStringUTF8(rv.String(), w)
+ }, nil
+ case reflect.Int, reflect.Int64, reflect.Int16, reflect.Int32, reflect.Int8:
+ return func(rv reflect.Value, w io.Writer) error {
+ return EncodeVarInt(int64(rv.Int()), w)
+ }, nil
+ case reflect.Float32, reflect.Float64:
+ return func(rv reflect.Value, w io.Writer) error {
+ return EncodeDouble(float64(rv.Float()), w)
+ }, nil
+ case reflect.Ptr:
+ // Nils are handled at the struct field level.
+ encf, err := b.encoderForSingleTypeReflect(t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ return func(rv reflect.Value, w io.Writer) error {
+ return encf(rv.Elem(), w)
+ }, nil
+ case reflect.Slice:
+ // Special case handling for byte slices.
+ if t.Elem().Kind() == reflect.Uint8 {
+ return func(rv reflect.Value, w io.Writer) error {
+ return EncodeBytes(rv.Bytes(), w)
+ }, nil
+ }
+ encf, err := b.containerEncoderForType(t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ return iterableEncoder(t, encf), nil
+ case reflect.Array:
+ encf, err := b.containerEncoderForType(t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ return iterableEncoder(t, encf), nil
+ case reflect.Map:
+ encK, err := b.containerEncoderForType(t.Key())
+ if err != nil {
+ return nil, err
+ }
+ encV, err := b.containerEncoderForType(t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ return mapEncoder(t, encK, encV), nil
+ }
+ panic(fmt.Sprintf("unimplemented type to encode: %v", t))
+}
+
+func (b *RowEncoderBuilder) containerEncoderForType(t reflect.Type) (func(reflect.Value, io.Writer) error, error) {
+ if t.Kind() == reflect.Ptr {
+ encf, err := b.encoderForSingleTypeReflect(t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ return containerNilEncoder(encf), nil
+ }
+ return b.encoderForSingleTypeReflect(t)
+}
+
+type typeEncoderReflect struct {
+ debug []string
+ fields []func(reflect.Value, io.Writer) error
+}
+
+// encoderForStructReflect generates reflection field access closures for structs.
+func (b *RowEncoderBuilder) encoderForStructReflect(t reflect.Type) (func(reflect.Value, io.Writer) error, error) {
+ var coder typeEncoderReflect
+ for i := 0; i < t.NumField(); i++ {
+ coder.debug = append(coder.debug, t.Field(i).Type.Name())
+ enc, err := b.encoderForSingleTypeReflect(t.Field(i).Type)
+ if err != nil {
+ return nil, err
+ }
+ coder.fields = append(coder.fields, enc)
+ }
+
+ return func(rv reflect.Value, w io.Writer) error {
+ // Row/Structs are prefixed with the number of fields that are encoded in total.
+ if err := writeRowHeader(rv, w); err != nil {
+ return err
+ }
+ for i, f := range coder.fields {
+ rvf := rv.Field(i)
+ switch rvf.Kind() {
+ case reflect.Ptr, reflect.Map, reflect.Slice:
+ if rvf.IsNil() {
+ continue
+ }
+ }
+ if err := f(rvf, w); err != nil {
+ return errors.Wrapf(err, "encoding %v, expected: %v", rvf.Type(), coder.debug[i])
+ }
+ }
+ return nil
+ }, nil
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/row_test.go b/sdks/go/pkg/beam/core/graph/coder/row_test.go
index 38b7c5d..9a5b900 100644
--- a/sdks/go/pkg/beam/core/graph/coder/row_test.go
+++ b/sdks/go/pkg/beam/core/graph/coder/row_test.go
@@ -18,9 +18,11 @@ package coder
import (
"bytes"
"fmt"
+ "io"
"reflect"
"testing"
+ "github.com/apache/beam/sdks/go/pkg/beam/core/util/jsonx"
"github.com/google/go-cmp/cmp"
)
@@ -28,113 +30,123 @@ func TestReflectionRowCoderGeneration(t *testing.T) {
num := 35
tests := []struct {
want interface{}
- }{{
- // Top level value check
- want: UserType1{
- A: "cats",
- B: 24,
- C: "pjamas",
- },
- }, {
- // Top level pointer check
- want: &UserType1{
- A: "marmalade",
- B: 24,
- C: "jam",
- },
- }, {
- // Inner pointer check.
- want: UserType2{
- A: "dogs",
- B: &UserType1{
+ }{
+ {
+ // Top level value check
+ want: UserType1{
A: "cats",
B: 24,
C: "pjamas",
},
- C: &num,
- },
- }, {
- // nil pointer check.
- want: UserType2{
- A: "dogs",
- B: nil,
- C: nil,
- },
- }, {
- // All zeroes
- want: struct {
- V00 bool
- V01 byte // unsupported by spec (same as uint8)
- V02 uint8 // unsupported by spec
- V03 int16
- // V04 uint16 // unsupported by spec
- V05 int32
- // V06 uint32 // unsupported by spec
- V07 int64
- // V08 uint64 // unsupported by spec
- V09 int
- V10 struct{}
- V11 *struct{}
- V12 [0]int
- V13 [2]int
- V14 []int
- V15 map[string]int
- V16 float32
- V17 float64
- V18 []byte
- V19 [2]*int
- V20 map[*string]*int
- }{},
- }, {
- want: struct {
- V00 bool
- V01 byte // unsupported by spec (same as uint8)
- V02 uint8 // unsupported by spec
- V03 int16
- // V04 uint16 // unsupported by spec
- V05 int32
- // V06 uint32 // unsupported by spec
- V07 int64
- // V08 uint64 // unsupported by spec
- V09 int
- V10 struct{}
- V11 *struct{}
- V12 [0]int
- V13 [2]int
- V14 []int
- V15 map[string]int
- V16 float32
- V17 float64
- V18 []byte
- V19 [2]*int
- V20 map[string]*int
- V21 []*int
- }{
- V00: true,
- V01: 1,
- V02: 2,
- V03: 3,
- V05: 5,
- V07: 7,
- V09: 9,
- V10: struct{}{},
- V11: &struct{}{},
- V12: [0]int{},
- V13: [2]int{72, 908},
- V14: []int{12, 9326, 641346, 6},
- V15: map[string]int{"pants": 42},
- V16: 3.14169,
- V17: 2.6e100,
- V18: []byte{21, 17, 65, 255, 0, 16},
- V19: [2]*int{nil, &num},
- V20: map[string]*int{
- "notnil": &num,
- "nil": nil,
- },
- V21: []*int{nil, &num, nil},
+ }, {
+ // Top level pointer check
+ want: &UserType1{
+ A: "marmalade",
+ B: 24,
+ C: "jam",
+ },
+ }, {
+ // Inner pointer check.
+ want: UserType2{
+ A: "dogs",
+ B: &UserType1{
+ A: "cats",
+ B: 24,
+ C: "pjamas",
+ },
+ C: &num,
+ },
+ }, {
+ // nil pointer check.
+ want: UserType2{
+ A: "dogs",
+ B: nil,
+ C: nil,
+ },
+ }, {
+ // nested struct check
+ want: UserType3{
+ A: UserType1{
+ A: "marmalade",
+ B: 24,
+ C: "jam",
+ },
+ },
+ }, {
+ // All zeroes
+ want: struct {
+ V00 bool
+ V01 byte // unsupported by spec (same as uint8)
+ V02 uint8 // unsupported by spec
+ V03 int16
+ // V04 uint16 // unsupported by spec
+ V05 int32
+ // V06 uint32 // unsupported by spec
+ V07 int64
+ // V08 uint64 // unsupported by spec
+ V09 int
+ V10 struct{}
+ V11 *struct{}
+ V12 [0]int
+ V13 [2]int
+ V14 []int
+ V15 map[string]int
+ V16 float32
+ V17 float64
+ V18 []byte
+ V19 [2]*int
+ V20 map[*string]*int
+ }{},
+ }, {
+ want: struct {
+ V00 bool
+ V01 byte // unsupported by spec (same as uint8)
+ V02 uint8 // unsupported by spec
+ V03 int16
+ // V04 uint16 // unsupported by spec
+ V05 int32
+ // V06 uint32 // unsupported by spec
+ V07 int64
+ // V08 uint64 // unsupported by spec
+ V09 int
+ V10 struct{}
+ V11 *struct{}
+ V12 [0]int
+ V13 [2]int
+ V14 []int
+ V15 map[string]int
+ V16 float32
+ V17 float64
+ V18 []byte
+ V19 [2]*int
+ V20 map[string]*int
+ V21 []*int
+ }{
+ V00: true,
+ V01: 1,
+ V02: 2,
+ V03: 3,
+ V05: 5,
+ V07: 7,
+ V09: 9,
+ V10: struct{}{},
+ V11: &struct{}{},
+ V12: [0]int{},
+ V13: [2]int{72, 908},
+ V14: []int{12, 9326, 641346, 6},
+ V15: map[string]int{"pants": 42},
+ V16: 3.14169,
+ V17: 2.6e100,
+ V18: []byte{21, 17, 65, 255, 0, 16},
+ V19: [2]*int{nil, &num},
+ V20: map[string]*int{
+ "notnil": &num,
+ "nil": nil,
+ },
+ V21: []*int{nil, &num, nil},
+ },
+ // TODO add custom types such as protocol buffers.
},
- // TODO add custom types such as protocol buffers.
- },
}
for _, test := range tests {
t.Run(fmt.Sprintf("%+v", test.want), func(t *testing.T) {
@@ -162,7 +174,6 @@ func TestReflectionRowCoderGeneration(t *testing.T) {
}
})
}
-
}
type UserType1 struct {
@@ -176,3 +187,274 @@ type UserType2 struct {
B *UserType1
C *int
}
+
// UserType3 nests a UserType1 by value. Tests use it to cover nested struct
// encoding, including when a custom coder is registered for UserType1.
type UserType3 struct {
	A UserType1
}
+
+func ut1Enc(val interface{}, w io.Writer) error {
+ if err := WriteSimpleRowHeader(3, w); err != nil {
+ return err
+ }
+ elm := val.(UserType1)
+ if err := EncodeStringUTF8(elm.A, w); err != nil {
+ return err
+ }
+ if err := EncodeVarInt(int64(elm.B), w); err != nil {
+ return err
+ }
+ if err := EncodeStringUTF8(elm.C, w); err != nil {
+ return err
+ }
+ return nil
+}
+
+func ut1Dec(r io.Reader) (interface{}, error) {
+ if err := ReadSimpleRowHeader(3, r); err != nil {
+ return nil, err
+ }
+ a, err := DecodeStringUTF8(r)
+ if err != nil {
+ return nil, fmt.Errorf("decoding string field A: %v", err)
+ }
+ b, err := DecodeVarInt(r)
+ if err != nil {
+ return nil, fmt.Errorf("decoding int field B: %v", err)
+ }
+ c, err := DecodeStringUTF8(r)
+ if err != nil {
+ return nil, fmt.Errorf("decoding string field C: %v, %v", c, err)
+ }
+ return UserType1{
+ A: a,
+ B: int(b),
+ C: c,
+ }, nil
+}
+
+func TestRowCoder_CustomCoder(t *testing.T) {
+ customRT := reflect.TypeOf(UserType1{})
+ customEnc := ut1Enc
+ customDec := ut1Dec
+
+ num := 35
+ tests := []struct {
+ want interface{}
+ }{
+ {
+ // Top level value check
+ want: UserType1{
+ A: "cats",
+ B: 24,
+ C: "pjamas",
+ },
+ }, {
+ // Top level pointer check
+ want: &UserType1{
+ A: "marmalade",
+ B: 24,
+ C: "jam",
+ },
+ }, {
+ // Inner pointer check.
+ want: UserType2{
+ A: "dogs",
+ B: &UserType1{
+ A: "cats",
+ B: 24,
+ C: "pjamas",
+ },
+ C: &num,
+ },
+ }, {
+ // nil pointer check.
+ want: UserType2{
+ A: "dogs",
+ B: nil,
+ C: nil,
+ },
+ }, {
+ // nested struct check
+ want: UserType3{
+ A: UserType1{
+ A: "marmalade",
+ B: 24,
+ C: "jam",
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("%+v", test.want), func(t *testing.T) {
+ rt := reflect.TypeOf(test.want)
+ var encB RowEncoderBuilder
+ encB.Register(customRT, func(reflect.Type) (func(interface{}, io.Writer) error, error) { return customEnc, nil })
+ enc, err := encB.Build(rt)
+ if err != nil {
+ t.Fatalf("RowEncoderBuilder.Build(%v) = %v, want nil error", rt, err)
+ }
+ var decB RowDecoderBuilder
+ decB.Register(customRT, func(reflect.Type) (func(io.Reader) (interface{}, error), error) { return customDec, nil })
+ dec, err := decB.Build(rt)
+ if err != nil {
+ t.Fatalf("RowDecoderBuilder.Build(%v) = %v, want nil error", rt, err)
+ }
+ var buf bytes.Buffer
+ if err := enc(test.want, &buf); err != nil {
+ t.Fatalf("enc(%v) = err, want nil error", err)
+ }
+ _, err = dec(&buf)
+ if err != nil {
+ t.Fatalf("BuildDecoder(%v) = %v, want nil error", rt, err)
+ }
+ })
+ }
+}
+
+func BenchmarkRowCoder_RoundTrip(b *testing.B) {
+
+ num := 35
+ benches := []struct {
+ want interface{}
+ customRT reflect.Type
+ customEnc func(interface{}, io.Writer) error
+ customDec func(io.Reader) (interface{}, error)
+ }{
+ {
+ // Top level value check
+ want: UserType1{
+ A: "cats",
+ B: 24,
+ C: "pjamas",
+ },
+ customRT: reflect.TypeOf(UserType1{}),
+ customEnc: ut1Enc,
+ customDec: ut1Dec,
+ }, {
+ // Top level pointer check
+ want: &UserType1{
+ A: "marmalade",
+ B: 24,
+ C: "jam",
+ },
+ customRT: reflect.TypeOf(UserType1{}),
+ customEnc: ut1Enc,
+ customDec: ut1Dec,
+ }, {
+ // Inner pointer check.
+ want: UserType2{
+ A: "dogs",
+ B: &UserType1{
+ A: "cats",
+ B: 24,
+ C: "pjamas",
+ },
+ C: &num,
+ },
+ customRT: reflect.TypeOf(UserType1{}),
+ customEnc: ut1Enc,
+ customDec: ut1Dec,
+ }, {
+ // nil pointer check.
+ want: UserType2{
+ A: "dogs",
+ B: nil,
+ C: nil,
+ },
+ customRT: reflect.TypeOf(UserType1{}),
+ customEnc: ut1Enc,
+ customDec: ut1Dec,
+ }, {
+ // nested struct check
+ want: UserType3{
+ A: UserType1{
+ A: "marmalade",
+ B: 24,
+ C: "jam",
+ },
+ },
+ customRT: reflect.TypeOf(UserType1{}),
+ customEnc: ut1Enc,
+ customDec: ut1Dec,
+ },
+ }
+ for _, bench := range benches {
+ rt := reflect.TypeOf(bench.want)
+ {
+ enc, err := RowEncoderForStruct(rt)
+ if err != nil {
+ b.Fatalf("BuildEncoder(%v) = %v, want nil error", rt, err)
+ }
+ dec, err := RowDecoderForStruct(rt)
+ if err != nil {
+ b.Fatalf("BuildDecoder(%v) = %v, want nil error", rt, err)
+ }
+ var buf bytes.Buffer
+ b.Run(fmt.Sprintf("SCHEMA %+v", bench.want), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ if err := enc(bench.want, &buf); err != nil {
+ b.Fatalf("enc(%v) = err, want nil error", err)
+ }
+ _, err := dec(&buf)
+ if err != nil {
+ b.Fatalf("BuildDecoder(%v) = %v, want nil error", rt, err)
+ }
+ }
+ })
+ }
+ if bench.customEnc != nil && bench.customDec != nil && rt == bench.customRT {
+ var buf bytes.Buffer
+ b.Run(fmt.Sprintf("CUSTOM %+v", bench.want), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ if err := bench.customEnc(bench.want, &buf); err != nil {
+ b.Fatalf("enc(%v) = err, want nil error", err)
+ }
+ _, err := bench.customDec(&buf)
+ if err != nil {
+ b.Fatalf("BuildDecoder(%v) = %v, want nil error", rt, err)
+ }
+ }
+ })
+ }
+ if bench.customEnc != nil && bench.customDec != nil {
+ var encB RowEncoderBuilder
+ encB.Register(bench.customRT, func(reflect.Type) (func(interface{}, io.Writer) error, error) { return bench.customEnc, nil })
+ enc, err := encB.Build(rt)
+ if err != nil {
+ b.Fatalf("RowEncoderBuilder.Build(%v) = %v, want nil error", rt, err)
+ }
+ var decB RowDecoderBuilder
+ decB.Register(bench.customRT, func(reflect.Type) (func(io.Reader) (interface{}, error), error) { return bench.customDec, nil })
+ dec, err := decB.Build(rt)
+ if err != nil {
+ b.Fatalf("RowDecoderBuilder.Build(%v) = %v, want nil error", rt, err)
+ }
+ var buf bytes.Buffer
+ b.Run(fmt.Sprintf("REGISTERED %+v", bench.want), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ if err := enc(bench.want, &buf); err != nil {
+ b.Fatalf("enc(%v) = err, want nil error", err)
+ }
+ _, err := dec(&buf)
+ if err != nil {
+ b.Fatalf("BuildDecoder(%v) = %v, want nil error", rt, err)
+ }
+ }
+ })
+ }
+ {
+ b.Run(fmt.Sprintf("JSON %+v", bench.want), func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ data, err := jsonx.Marshal(bench.want)
+ if err != nil {
+ b.Fatalf("jsonx.Marshal(%v) = err, want nil error", err)
+ }
+ val := reflect.New(rt)
+ if err := jsonx.Unmarshal(val.Interface(), data); err != nil {
+ b.Fatalf("jsonx.Unmarshal(%v) = %v, want nil error; type: %v", rt, err, val.Type())
+ }
+ }
+ })
+ }
+ }
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/testutil/testutil.go b/sdks/go/pkg/beam/core/graph/coder/testutil/testutil.go
new file mode 100644
index 0000000..3942be5
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/testutil/testutil.go
@@ -0,0 +1,154 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil contains helpers to test and validate custom Beam Schema coders.
+package testutil
+
+import (
+ "bytes"
+ "fmt"
+ "reflect"
+ "testing"
+
+ "github.com/apache/beam/sdks/go/pkg/beam/core/graph/coder"
+ "github.com/google/go-cmp/cmp"
+)
+
// SchemaCoder helps validate custom schema coders.
//
// The zero value is ready to use, e.g. (&SchemaCoder{}).Validate(...).
type SchemaCoder struct {
	// Builders for the coders under test (UT) and for the schema equivalent
	// type (Schema). Validate registers the factories under test only with the
	// UT builders. Register adds factories to both pairs.
	encBldUT, encBldSchema coder.RowEncoderBuilder
	decBldUT, decBldSchema coder.RowDecoderBuilder

	// CmpOptions are passed to cmp.Diff when comparing each test value with
	// the value produced by the round trip through the coder under test.
	CmpOptions cmp.Options
}
+
+// Register adds additional custom types not under test to both the under test
+// and default schema coders.
+func (v *SchemaCoder) Register(rt reflect.Type, encF, decF interface{}) {
+ v.encBldUT.Register(rt, encF)
+ v.encBldSchema.Register(rt, encF)
+ v.decBldUT.Register(rt, decF)
+ v.decBldSchema.Register(rt, decF)
+}
+
// T is the subset of *testing.T's methods that SchemaCoder.Validate uses.
// Accepting this interface instead of *testing.T lets Validate itself be
// tested with a fake implementation. The method signatures must match the
// corresponding *testing.T methods.
type T interface {
	Errorf(string, ...interface{})
	Failed() bool
	FailNow()
	Helper()
	Run(string, func(*testing.T)) bool
}
+
+// Validate is a test utility to validate custom schema coders generate
+// beam schema encoded bytes.
+//
+// Validate accepts the reflect.Type to register, factory functions for
+// encoding and decoding, an anonymous struct type equivalent to the encoded
+// format produced and consumed by the factory produced functions and test
+// values. Test values must be either a struct, pointer to struct, or a slice
+// where each element is a struct or pointer to struct.
+//
+// TODO(lostluck): Improve documentation.
+// TODO(lostluck): Abstract into a configurable struct, to handle
+//
+// Validate will register the under test factories and generate an encoder and
+// decoder function. These functions will be re-used for all test values. This
+// emulates coders being re-used for all elements within a bundle.
+//
+// Validate mutates the SchemaCoderValidator, so the SchemaCoderValidator may not be used more than once.
+func (v *SchemaCoder) Validate(t T, rt reflect.Type, encF, decF, schema interface{}, values interface{}) {
+ t.Helper()
+ testValues := reflect.ValueOf(values)
+ // Check whether we have a slice type or not.
+ if testValues.Type().Kind() != reflect.Slice {
+ vs := reflect.MakeSlice(reflect.SliceOf(testValues.Type()), 0, 1)
+ testValues = reflect.Append(vs, testValues)
+ }
+ if testValues.Len() == 0 {
+ t.Errorf("No test values provided for ValidateSchemaCoder(%v)", rt)
+ }
+ // We now have non empty slice of test values!
+
+ v.encBldUT.Register(rt, encF)
+ v.decBldUT.Register(rt, decF)
+
+ testRt := testValues.Type().Elem()
+ encUT, err := v.encBldUT.Build(testRt)
+ if err != nil {
+ t.Errorf("Unable to build encoder function with given factory: coder.RowEncoderBuilder.Build(%v) = %v, want nil error", rt, err)
+ }
+ decUT, err := v.decBldUT.Build(testRt)
+ if err != nil {
+ t.Errorf("Unable to build decoder function with given factory: coder.RowDecoderBuilder.Build(%v) = %v, want nil error", rt, err)
+ }
+
+ schemaRt := reflect.TypeOf(schema)
+ encSchema, err := v.encBldSchema.Build(schemaRt)
+ if err != nil {
+ t.Errorf("Unable to build encoder function for schema equivalent type: coder.RowEncoderBuilder.Build(%v) = %v, want nil error", rt, err)
+ }
+ decSchema, err := v.decBldSchema.Build(schemaRt)
+ if err != nil {
+ t.Errorf("Unable to build decoder function for schema equivalent type: coder.RowDecoderBuilder.Build(%v) = %v, want nil error", rt, err)
+ }
+ // We use error messages instead of fatals to allow all the cases to be
+ // checked. None of the coder functions are used until the per value runs
+ // so a user can get additional information per run.
+ if t.Failed() {
+ t.FailNow()
+ }
+ for i := 0; i < testValues.Len(); i++ {
+ t.Run(fmt.Sprintf("%v[%d]", rt, i), func(t *testing.T) {
+ var buf bytes.Buffer
+ want := testValues.Index(i).Interface()
+ if err := encUT(want, &buf); err != nil {
+ t.Fatalf("error calling Under Test encoder[%v](%v) = %v", testRt, want, err)
+ }
+ initialBytes := clone(buf.Bytes())
+
+ bufSchema := bytes.NewBuffer(clone(initialBytes))
+
+ schemaV, err := decSchema(bufSchema)
+ if err != nil {
+ t.Fatalf("error calling Equivalent Schema decoder[%v]() = %v", schemaRt, err)
+ }
+ err = encSchema(schemaV, bufSchema)
+ if err != nil {
+ t.Fatalf("error calling Equivalent Schema encoder[%v](%v) = %v, want nil error", schemaRt, schemaV, err)
+ }
+ roundTripBytes := clone(bufSchema.Bytes())
+
+ if d := cmp.Diff(initialBytes, roundTripBytes); d != "" {
+ t.Errorf("round trip through equivalent schema type didn't produce equivalent byte slices (-initial,+roundTrip): \n%v", d)
+ }
+ got, err := decUT(bufSchema)
+ if err != nil {
+ t.Fatalf("Under Test decoder(%v) = %v, want nil error", rt, err)
+ }
+ if d := cmp.Diff(want, got, v.CmpOptions); d != "" {
+ t.Fatalf("round trip through custom coder produced diff: (-want, +got):\n%v", d)
+ }
+ })
+ }
+}
+
// clone returns a copy of b that does not share b's backing array. The result
// has length and capacity len(b) and is never nil, even when b is empty.
func clone(b []byte) []byte {
	return append(make([]byte, 0, len(b)), b...)
}
diff --git a/sdks/go/pkg/beam/core/graph/coder/testutil/testutil_test.go b/sdks/go/pkg/beam/core/graph/coder/testutil/testutil_test.go
new file mode 100644
index 0000000..7305e7a
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/testutil/testutil_test.go
@@ -0,0 +1,201 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/apache/beam/sdks/go/pkg/beam/core/graph/coder"
+)
+
// UserInterface is an interface type used to check that a custom coder can be
// registered for an interface and applied to a concrete implementation.
type UserInterface interface {
	mark()
}

// UserType1 is the concrete test type. Its custom coder (ut1EncDropB and
// ut1DecDropB) deliberately leaves field B out of the encoding.
type UserType1 struct {
	A string
	B int
	C string
}

// mark makes UserType1 implement UserInterface.
func (UserType1) mark() {
	// Intentionally empty: only the method set matters.
}
+
+func ut1EncDropB(val interface{}, w io.Writer) error {
+ if err := coder.WriteSimpleRowHeader(2, w); err != nil {
+ return err
+ }
+ elm := val.(UserType1)
+ if err := coder.EncodeStringUTF8(elm.A, w); err != nil {
+ return err
+ }
+ if err := coder.EncodeStringUTF8(elm.C, w); err != nil {
+ return err
+ }
+ return nil
+}
+
+func ut1DecDropB(r io.Reader) (interface{}, error) {
+ if err := coder.ReadSimpleRowHeader(2, r); err != nil {
+ return nil, err
+ }
+ a, err := coder.DecodeStringUTF8(r)
+ if err != nil {
+ return nil, fmt.Errorf("decoding string field A: %v", err)
+ }
+ c, err := coder.DecodeStringUTF8(r)
+ if err != nil {
+ return nil, fmt.Errorf("decoding string field C: %v, %v", c, err)
+ }
+ return UserType1{
+ A: a,
+ B: 42,
+ C: c,
+ }, nil
+}
+
// UserType2 wraps a UserType1 by value. The FailureCases subtest registers
// failing factories for *UserType2 so that building coders for a schema type
// containing it fails.
type UserType2 struct {
	A UserType1
}
+
+func TestValidateCoder(t *testing.T) {
+ // Validates a custom UserType1 encoding, which drops encoding the "B" field,
+ // always setting it to a constant value.
+ t.Run("SingleValue", func(t *testing.T) {
+ (&SchemaCoder{}).Validate(t, reflect.TypeOf((*UserType1)(nil)).Elem(),
+ func(reflect.Type) (func(interface{}, io.Writer) error, error) { return ut1EncDropB, nil },
+ func(reflect.Type) (func(io.Reader) (interface{}, error), error) { return ut1DecDropB, nil },
+ struct{ A, C string }{},
+ UserType1{
+ A: "cats",
+ B: 42,
+ C: "pjamas",
+ },
+ )
+ })
+ t.Run("SliceOfValues", func(t *testing.T) {
+ (&SchemaCoder{}).Validate(t, reflect.TypeOf((*UserType1)(nil)).Elem(),
+ func(reflect.Type) (func(interface{}, io.Writer) error, error) { return ut1EncDropB, nil },
+ func(reflect.Type) (func(io.Reader) (interface{}, error), error) { return ut1DecDropB, nil },
+ struct{ A, C string }{},
+ []UserType1{
+ {
+ A: "cats",
+ B: 42,
+ C: "pjamas",
+ }, {
+ A: "dogs",
+ B: 42,
+ C: "breakfast",
+ }, {
+ A: "fish",
+ B: 42,
+ C: "plenty of",
+ },
+ },
+ )
+ })
+ t.Run("InterfaceCoder", func(t *testing.T) {
+ (&SchemaCoder{}).Validate(t, reflect.TypeOf((*UserInterface)(nil)).Elem(),
+ func(rt reflect.Type) (func(interface{}, io.Writer) error, error) {
+ return ut1EncDropB, nil
+ },
+ func(rt reflect.Type) (func(io.Reader) (interface{}, error), error) {
+ return ut1DecDropB, nil
+ },
+ struct{ A, C string }{},
+ UserType1{
+ A: "cats",
+ B: 42,
+ C: "pjamas",
+ },
+ )
+ })
+ t.Run("FailureCases", func(t *testing.T) {
+ var c checker
+ err := fmt.Errorf("FactoryError")
+ var v SchemaCoder
+ // Register the pointer type to the default encoder too.
+ v.Register(reflect.TypeOf((*UserType2)(nil)),
+ func(reflect.Type) (func(interface{}, io.Writer) error, error) { return nil, err },
+ func(reflect.Type) (func(io.Reader) (interface{}, error), error) { return nil, err },
+ )
+ v.Validate(&c, reflect.TypeOf((*UserType1)(nil)).Elem(),
+ func(reflect.Type) (func(interface{}, io.Writer) error, error) { return ut1EncDropB, err },
+ func(reflect.Type) (func(io.Reader) (interface{}, error), error) { return ut1DecDropB, err },
+ struct {
+ A, C string
+ B *UserType2 // To trigger the bad factory registered earlier.
+ }{},
+ []UserType1{},
+ )
+ if got, want := len(c.errors), 5; got != want {
+ t.Fatalf("SchemaCoder.Validate did not fail as expected. Got %v errors logged, but want %v", got, want)
+ }
+ if !strings.Contains(c.errors[0].fmt, "No test values") {
+ t.Fatalf("SchemaCoder.Validate with no values did not fail. fmt: %q", c.errors[0].fmt)
+ }
+ if !strings.Contains(c.errors[1].fmt, "Unable to build encoder function with given factory") {
+ t.Fatalf("SchemaCoder.Validate with no values did not fail. fmt: %q", c.errors[1].fmt)
+ }
+ if !strings.Contains(c.errors[2].fmt, "Unable to build decoder function with given factory") {
+ t.Fatalf("SchemaCoder.Validate with no values did not fail. fmt: %q", c.errors[2].fmt)
+ }
+ if !strings.Contains(c.errors[3].fmt, "Unable to build encoder function for schema equivalent type") {
+ t.Fatalf("SchemaCoder.Validate with no values did not fail. fmt: %q", c.errors[3].fmt)
+ }
+ if !strings.Contains(c.errors[4].fmt, "Unable to build decoder function for schema equivalent type") {
+ t.Fatalf("SchemaCoder.Validate with no values did not fail. fmt: %q", c.errors[4].fmt)
+ }
+ })
+}
+
// msg records the arguments of a single Errorf call made on a checker.
type msg struct {
	fmt    string
	params []interface{}
}

// checker is a fake T that records calls instead of failing a real test. This
// lets the failure reporting of SchemaCoder.Validate be inspected.
type checker struct {
	errors []msg

	runCount      int
	failNowCalled bool
}

// Helper does nothing.
func (c *checker) Helper() {}

// Run counts the call but never invokes the subtest function, and always
// reports success.
func (c *checker) Run(_ string, _ func(*testing.T)) bool {
	c.runCount++
	return true
}

// Errorf records the unformatted format string together with its arguments.
func (c *checker) Errorf(format string, params ...interface{}) {
	m := msg{fmt: format, params: params}
	c.errors = append(c.errors, m)
}

// Failed reports whether Errorf has been called at least once.
func (c *checker) Failed() bool {
	return len(c.errors) != 0
}

// FailNow records the call. Unlike (*testing.T).FailNow, it returns normally.
func (c *checker) FailNow() {
	c.failNowCalled = true
}