You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lo...@apache.org on 2020/06/30 17:55:50 UTC

[beam] branch master updated: [BEAM-9615] Add bytes, bool, and iterable coders (#12127)

This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e7a3392  [BEAM-9615] Add bytes, bool, and iterable coders (#12127)
e7a3392 is described below

commit e7a3392595e14cdb377fe95fb679305ea9a32249
Author: Robert Burke <lo...@users.noreply.github.com>
AuthorDate: Tue Jun 30 10:55:31 2020 -0700

    [BEAM-9615] Add bytes, bool, and iterable coders (#12127)
---
 sdks/go/pkg/beam/core/graph/coder/bool.go          |  57 ++++++++++
 sdks/go/pkg/beam/core/graph/coder/bool_test.go     |  79 +++++++++++++
 sdks/go/pkg/beam/core/graph/coder/bytes.go         |  68 ++++++++++++
 sdks/go/pkg/beam/core/graph/coder/bytes_test.go    |  61 +++++++++++
 sdks/go/pkg/beam/core/graph/coder/iterable.go      | 122 +++++++++++++++++++++
 sdks/go/pkg/beam/core/graph/coder/iterable_test.go | 113 +++++++++++++++++++
 6 files changed, 500 insertions(+)

diff --git a/sdks/go/pkg/beam/core/graph/coder/bool.go b/sdks/go/pkg/beam/core/graph/coder/bool.go
new file mode 100644
index 0000000..309ae24
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/bool.go
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+	"io"
+
+	"github.com/apache/beam/sdks/go/pkg/beam/core/util/ioutilx"
+	"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
+)
+
+// EncodeBool encodes a boolean according to the beam protocol.
+func EncodeBool(v bool, w io.Writer) error {
+	// Encoding: false = 0, true = 1
+	var err error
+	if v {
+		_, err = ioutilx.WriteUnsafe(w, []byte{1})
+	} else {
+		_, err = ioutilx.WriteUnsafe(w, []byte{0})
+	}
+	if err != nil {
+		return errors.Wrap(err, "error encoding bool")
+	}
+	return nil
+}
+
+// DecodeBool reads a single byte from r and interprets it following the
+// beam protocol: 0 is false and 1 is true. Any other byte value is reported
+// as an error. An io.EOF from the read is returned as is, while other read
+// errors are wrapped with additional context.
+func DecodeBool(r io.Reader) (bool, error) {
+	var buf [1]byte
+	err := ioutilx.ReadNBufUnsafe(r, buf[:])
+	switch {
+	case err == io.EOF:
+		// Left unwrapped so callers can compare against io.EOF directly.
+		return false, err
+	case err != nil:
+		return false, errors.Wrap(err, "error decoding bool")
+	}
+	if v := buf[0]; v > 1 {
+		return false, errors.Errorf("error decoding bool: received invalid value %v", v)
+	}
+	return buf[0] == 1, nil
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/bool_test.go b/sdks/go/pkg/beam/core/graph/coder/bool_test.go
new file mode 100644
index 0000000..77aaf8b
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/bool_test.go
@@ -0,0 +1,79 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+	"bytes"
+	"fmt"
+	"testing"
+
+	"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
+	"github.com/google/go-cmp/cmp"
+)
+
+// TestEncodeBool checks that both boolean values are encoded to their
+// expected single byte representation.
+func TestEncodeBool(t *testing.T) {
+	cases := []struct {
+		v    bool
+		want []byte
+	}{
+		{v: false, want: []byte{0}},
+		{v: true, want: []byte{1}},
+	}
+	for _, tc := range cases {
+		tc := tc
+		t.Run(fmt.Sprintf("%v", tc.v), func(t *testing.T) {
+			var out bytes.Buffer
+			if err := EncodeBool(tc.v, &out); err != nil {
+				t.Fatalf("EncodeBool(%v) = %v", tc.v, err)
+			}
+			got := out.Bytes()
+			if d := cmp.Diff(tc.want, got); d != "" {
+				t.Errorf("EncodeBool(%v) = %v, want %v diff(-want,+got):\n %v", tc.v, got, tc.want, d)
+			}
+		})
+	}
+}
+
+// TestDecodeBool checks that valid single byte encodings decode to the
+// expected boolean, and that an out of range byte produces the expected error.
+func TestDecodeBool(t *testing.T) {
+	tests := []struct {
+		b    []byte
+		want bool
+		err  error
+	}{
+		{want: false, b: []byte{0}},
+		{want: true, b: []byte{1}},
+		{b: []byte{42}, err: errors.Errorf("error decoding bool: received invalid value %v", 42)},
+	}
+	for _, test := range tests {
+		test := test
+		// Name subtests after the input bytes: the wanted value is not unique
+		// across cases (the error case also has a zero valued want).
+		t.Run(fmt.Sprintf("%v", test.b), func(t *testing.T) {
+			buf := bytes.NewBuffer(test.b)
+			got, err := DecodeBool(buf)
+			if test.err != nil {
+				// An expected error that never materializes must fail the
+				// test instead of falling through to the value comparison.
+				if err == nil {
+					t.Fatalf("DecodeBool(%v) = %v, want error %v", test.b, got, test.err)
+				}
+				if d := cmp.Diff(test.err.Error(), err.Error()); d != "" {
+					t.Errorf("DecodeBool(%v) = %v, want %v diff(-want,+got):\n %v", test.b, err, test.err, d)
+				}
+				return
+			}
+			if err != nil {
+				t.Fatalf("DecodeBool(%v) = %v", test.b, err)
+			}
+			if d := cmp.Diff(test.want, got); d != "" {
+				t.Errorf("DecodeBool(%v) = %v, want %v diff(-want,+got):\n %v", test.b, got, test.want, d)
+			}
+		})
+	}
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/bytes.go b/sdks/go/pkg/beam/core/graph/coder/bytes.go
new file mode 100644
index 0000000..5bb5692
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/bytes.go
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+	"fmt"
+	"io"
+
+	"github.com/apache/beam/sdks/go/pkg/beam/core/util/ioutilx"
+	"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
+)
+
+// EncodeByte writes v to w as a single raw byte. A write failure is
+// returned with additional context.
+func EncodeByte(v byte, w io.Writer) error {
+	_, err := ioutilx.WriteUnsafe(w, []byte{v})
+	if err == nil {
+		return nil
+	}
+	return fmt.Errorf("error encoding byte: %v", err)
+}
+
+// DecodeByte reads a single raw byte from r. An io.EOF from the read is
+// returned as is, while other read errors are wrapped with additional
+// context.
+func DecodeByte(r io.Reader) (byte, error) {
+	var buf [1]byte
+	err := ioutilx.ReadNBufUnsafe(r, buf[:])
+	switch {
+	case err == nil:
+		return buf[0], nil
+	case err == io.EOF:
+		// Left unwrapped so callers can compare against io.EOF directly.
+		return 0, err
+	default:
+		return 0, errors.Wrap(err, "error decoding byte")
+	}
+}
+
+// EncodeBytes writes v to w using the beam length prefixed encoding: the
+// number of bytes as a varint, followed by the raw bytes themselves.
+// Errors from writing either part are returned unmodified.
+func EncodeBytes(v []byte, w io.Writer) error {
+	if err := EncodeVarInt(int64(len(v)), w); err != nil {
+		return err
+	}
+	if _, err := w.Write(v); err != nil {
+		return err
+	}
+	return nil
+}
+
+// DecodeBytes decodes a length prefixed []byte according to the beam protocol:
+// a varint byte count followed by that many raw bytes.
+//
+// Errors from reading the length prefix or the payload are returned
+// unmodified. A negative length prefix is rejected with an error.
+func DecodeBytes(r io.Reader) ([]byte, error) {
+	// Encoding: size (varint) + raw data
+	size, err := DecodeVarInt(r)
+	if err != nil {
+		return nil, err
+	}
+	// A negative size can only come from corrupt or mismatched input; fail
+	// cleanly rather than handing a negative count to the reader helper.
+	if size < 0 {
+		return nil, errors.Errorf("error decoding bytes: received invalid size %d", size)
+	}
+	return ioutilx.ReadN(r, (int)(size))
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/bytes_test.go b/sdks/go/pkg/beam/core/graph/coder/bytes_test.go
new file mode 100644
index 0000000..0d7d1a7
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/bytes_test.go
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+	"bytes"
+	"fmt"
+	"strings"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+// TestEncodeDecodeBytes checks EncodeBytes and DecodeBytes against known
+// length prefixed encodings, in both directions.
+func TestEncodeDecodeBytes(t *testing.T) {
+	// Eight repetitions of a 32 character sentence give a 256 byte payload,
+	// which forces a length prefix longer than a single varint byte.
+	payload := []byte(strings.Repeat(" this sentence is 32 characters.", 8))
+	cases := []struct {
+		v       []byte
+		encoded []byte
+	}{
+		{v: []byte{}, encoded: []byte{0}},
+		{v: []byte{42}, encoded: []byte{1, 42}},
+		{v: []byte{42, 23}, encoded: []byte{2, 42, 23}},
+		{v: payload, encoded: append([]byte{128, 2}, payload...)},
+	}
+	for _, tc := range cases {
+		tc := tc
+		t.Run(fmt.Sprintf("encode %q", tc.v), func(t *testing.T) {
+			var out bytes.Buffer
+			if err := EncodeBytes(tc.v, &out); err != nil {
+				t.Fatalf("EncodeBytes(%q) = %v", tc.v, err)
+			}
+			got := out.Bytes()
+			if d := cmp.Diff(tc.encoded, got); d != "" {
+				t.Errorf("EncodeBytes(%q) = %v, want %v diff(-want,+got):\n %v", tc.v, got, tc.encoded, d)
+			}
+		})
+		t.Run(fmt.Sprintf("decode %v", tc.v), func(t *testing.T) {
+			got, err := DecodeBytes(bytes.NewBuffer(tc.encoded))
+			if err != nil {
+				t.Fatalf("DecodeBytes(%q) = %v", tc.encoded, err)
+			}
+			if d := cmp.Diff(tc.v, got); d != "" {
+				t.Errorf("DecodeBytes(%q) = %q, want %v diff(-want,+got):\n %v", tc.encoded, got, tc.v, d)
+			}
+		})
+	}
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/iterable.go b/sdks/go/pkg/beam/core/graph/coder/iterable.go
new file mode 100644
index 0000000..c6affc6
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/iterable.go
@@ -0,0 +1,122 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+	"io"
+	"reflect"
+
+	"github.com/apache/beam/sdks/go/pkg/beam/internal/errors"
+)
+
+// TODO(lostluck): 2020.06.29 export these for use for others?
+
+// iterableEncoder reflectively encodes a slice or array type using
+// the beam fixed length iterable encoding.
+func iterableEncoder(rt reflect.Type, encode func(reflect.Value, io.Writer) error) func(reflect.Value, io.Writer) error {
+	return func(rv reflect.Value, w io.Writer) error {
+		size := rv.Len()
+		if err := EncodeInt32((int32)(size), w); err != nil {
+			return err
+		}
+		for i := 0; i < size; i++ {
+			if err := encode(rv.Index(i), w); err != nil {
+				return err
+			}
+		}
+		return nil
+	}
+}
+
+// iterableDecoderForSlice can decode from both the fixed sized and
+// multi-chunk variants of the beam iterable protocol.
+// Returns an error for other protocols (such as state backed).
+//
+// The returned function reads a 32 bit count prefix. A non-negative count
+// is followed by that many elements. A count of -1 is followed by a
+// sequence of chunks, each a varint element count followed by its elements,
+// terminated by a chunk count of 0. The decoded slice of type rt is stored
+// into ret. Negative chunk counts are rejected with an error.
+func iterableDecoderForSlice(rt reflect.Type, decodeToElem func(reflect.Value, io.Reader) error) func(reflect.Value, io.Reader) error {
+	return func(ret reflect.Value, r io.Reader) error {
+		// (1) Read count prefixed encoded data
+		size, err := DecodeInt32(r)
+		if err != nil {
+			return err
+		}
+		n := int(size)
+		switch {
+		case n >= 0:
+			rv := reflect.MakeSlice(rt, n, n)
+			if err := decodeToIterable(rv, r, decodeToElem); err != nil {
+				return err
+			}
+			ret.Set(rv)
+			return nil
+		case n == -1:
+			rv := reflect.MakeSlice(rt, 0, 0)
+			for {
+				chunk, err := DecodeVarInt(r)
+				if err != nil {
+					return err
+				}
+				// reflect.MakeSlice panics on a negative length, so
+				// corrupt input must be caught before allocating.
+				if chunk < 0 {
+					return errors.Errorf("unable to decode slice iterable with chunk size: %d", chunk)
+				}
+				if chunk == 0 {
+					// A zero sized chunk terminates the stream.
+					break
+				}
+				rvi := reflect.MakeSlice(rt, int(chunk), int(chunk))
+				if err := decodeToIterable(rvi, r, decodeToElem); err != nil {
+					return err
+				}
+				rv = reflect.AppendSlice(rv, rvi)
+			}
+			ret.Set(rv)
+			return nil
+		default:
+			return errors.Errorf("unable to decode slice iterable with size: %d", n)
+		}
+	}
+}
+
+// iterableDecoderForArray can decode from only the fixed sized variant
+// of the beam iterable protocol.
+// Returns an error for other protocols (such as multi-chunk or state backed),
+// and whenever the encoded element count differs from the array length of rt.
+//
+// The returned function decodes the elements directly into ret.
+func iterableDecoderForArray(rt reflect.Type, decodeToElem func(reflect.Value, io.Reader) error) func(reflect.Value, io.Reader) error {
+	return func(ret reflect.Value, r io.Reader) error {
+		// (1) Read count prefixed encoded data
+		size, err := DecodeInt32(r)
+		if err != nil {
+			return err
+		}
+		// An array length is never negative, so this single comparison also
+		// rejects the negative counts used by the other iterable protocols.
+		if n := int(size); rt.Len() != n {
+			return errors.Errorf("len mismatch decoding a %v: want %d got %d", rt, rt.Len(), n)
+		}
+		return decodeToIterable(ret, r, decodeToElem)
+	}
+}
+
+// decodeToIterable calls decodeTo once for every element of rv, in index
+// order, passing the element and the reader r. It stops at and returns the
+// first error reported by decodeTo.
+func decodeToIterable(rv reflect.Value, r io.Reader, decodeTo func(reflect.Value, io.Reader) error) error {
+	for i, n := 0, rv.Len(); i < n; i++ {
+		if err := decodeTo(rv.Index(i), r); err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/sdks/go/pkg/beam/core/graph/coder/iterable_test.go b/sdks/go/pkg/beam/core/graph/coder/iterable_test.go
new file mode 100644
index 0000000..977b588
--- /dev/null
+++ b/sdks/go/pkg/beam/core/graph/coder/iterable_test.go
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package coder
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"reflect"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+// TestEncodeDecodeIterable checks the reflective iterable encoder and the
+// slice and array decoders against known encodings. Cases marked decodeOnly
+// use the multi-chunk encoding, which the encoder never produces.
+func TestEncodeDecodeIterable(t *testing.T) {
+	stringEnc := func(rv reflect.Value, w io.Writer) error {
+		return EncodeStringUTF8(rv.String(), w)
+	}
+	stringDec := func(rv reflect.Value, r io.Reader) error {
+		v, err := DecodeStringUTF8(r)
+		if err != nil {
+			// Propagate so a broken element decode fails the test.
+			return err
+		}
+		t.Log(v)
+		rv.SetString(v)
+		return nil
+	}
+
+	tests := []struct {
+		v          interface{}
+		encElm     func(reflect.Value, io.Writer) error
+		decElm     func(reflect.Value, io.Reader) error
+		encoded    []byte
+		decodeOnly bool
+	}{
+		{
+			v: [4]byte{1, 2, 3, 4},
+			encElm: func(rv reflect.Value, w io.Writer) error {
+				return EncodeByte(byte(rv.Uint()), w)
+			},
+			decElm: func(rv reflect.Value, r io.Reader) error {
+				b, err := DecodeByte(r)
+				if err != nil {
+					// Propagate so a broken element decode fails the test.
+					return err
+				}
+				rv.SetUint(uint64(b))
+				return nil
+			},
+			encoded: []byte{0, 0, 0, 4, 1, 2, 3, 4},
+		},
+		{
+			v:       []string{"my", "gopher"},
+			encElm:  stringEnc,
+			decElm:  stringDec,
+			encoded: []byte{0, 0, 0, 2, 2, 'm', 'y', 6, 'g', 'o', 'p', 'h', 'e', 'r'},
+		},
+		{
+			v:          []string{"my", "gopher", "rocks"},
+			encElm:     stringEnc,
+			decElm:     stringDec,
+			encoded:    []byte{255, 255, 255, 255, 1, 2, 'm', 'y', 2, 6, 'g', 'o', 'p', 'h', 'e', 'r', 5, 'r', 'o', 'c', 'k', 's', 0},
+			decodeOnly: true,
+		},
+	}
+	for _, test := range tests {
+		test := test
+		if !test.decodeOnly {
+			t.Run(fmt.Sprintf("encode %q", test.v), func(t *testing.T) {
+				var buf bytes.Buffer
+				err := iterableEncoder(reflect.TypeOf(test.v), test.encElm)(reflect.ValueOf(test.v), &buf)
+				if err != nil {
+					t.Fatalf("iterableEncoder(%q) = %v", test.v, err)
+				}
+				if d := cmp.Diff(test.encoded, buf.Bytes()); d != "" {
+					t.Errorf("iterableEncoder(%q) = %v, want %v diff(-want,+got):\n %v", test.v, buf.Bytes(), test.encoded, d)
+				}
+			})
+		}
+		t.Run(fmt.Sprintf("decode %v", test.v), func(t *testing.T) {
+			buf := bytes.NewBuffer(test.encoded)
+			rt := reflect.TypeOf(test.v)
+			var dec func(reflect.Value, io.Reader) error
+			switch rt.Kind() {
+			case reflect.Slice:
+				dec = iterableDecoderForSlice(rt, test.decElm)
+			case reflect.Array:
+				dec = iterableDecoderForArray(rt, test.decElm)
+			default:
+				// Without this, a bad test case would call a nil decoder.
+				t.Fatalf("unsupported kind %v for test value %v", rt.Kind(), test.v)
+			}
+			rv := reflect.New(rt).Elem()
+			err := dec(rv, buf)
+			if err != nil {
+				t.Fatalf("iterable decode(%q) = %v", test.encoded, err)
+			}
+			got := rv.Interface()
+			if d := cmp.Diff(test.v, got); d != "" {
+				t.Errorf("iterable decode(%q) = %q, want %v diff(-want,+got):\n %v", test.encoded, got, test.v, d)
+			}
+		})
+	}
+}