You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@qpid.apache.org by ac...@apache.org on 2018/04/10 21:17:49 UTC
[55/55] [abbrv] qpid-proton git commit: Merge tag '0.22.0' into go1
Merge tag '0.22.0' into go1
Release 0.22.0
Project: http://git-wip-us.apache.org/repos/asf/qpid-proton/repo
Commit: http://git-wip-us.apache.org/repos/asf/qpid-proton/commit/6f799990
Tree: http://git-wip-us.apache.org/repos/asf/qpid-proton/tree/6f799990
Diff: http://git-wip-us.apache.org/repos/asf/qpid-proton/diff/6f799990
Branch: refs/heads/go1
Commit: 6f799990cdf739b3caacf66eb2a9a29b14c9abeb
Parents: 6e5b4d5 e3797ce
Author: Alan Conway <ac...@redhat.com>
Authored: Tue Apr 10 17:15:21 2018 -0400
Committer: Alan Conway <ac...@redhat.com>
Committed: Tue Apr 10 17:15:21 2018 -0400
----------------------------------------------------------------------
amqp/marshal.go | 3 +-
amqp/unmarshal.go | 4 +-
electron/auth_test.go | 92 ++++++++++++++++++++++++---------------------
electron/connection.go | 65 +++++++++++++++++++++-----------
4 files changed, 94 insertions(+), 70 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/6f799990/amqp/marshal.go
----------------------------------------------------------------------
diff --cc amqp/marshal.go
index 33b30a8,0000000..99584a2
mode 100644,000000..100644
--- a/amqp/marshal.go
+++ b/amqp/marshal.go
@@@ -1,360 -1,0 +1,359 @@@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+*/
+
+package amqp
+
+// #include <proton/codec.h>
+import "C"
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+ "time"
+ "unsafe"
+)
+
+// MarshalError is returned when a Go value cannot be marshaled as an AMQP type.
+type MarshalError struct {
+	GoType reflect.Type // The Go type that could not be marshaled.
+	s      string       // The complete error message.
+}
+
+// Error implements the error interface by returning the stored message.
+func (e MarshalError) Error() string {
+	return e.s
+}
+
+// newMarshalError builds a *MarshalError for the value v with detail message s.
+func newMarshalError(v interface{}, s string) *MarshalError {
+	goType := reflect.TypeOf(v)
+	msg := fmt.Sprintf("cannot marshal %s: %s", goType, s)
+	return &MarshalError{GoType: goType, s: msg}
+}
+
+// dataMarshalError converts a pending proton error on data into a *MarshalError
+// for v. It returns nil when data has no error recorded.
+func dataMarshalError(v interface{}, data *C.pn_data_t) error {
+	pe := PnError(C.pn_data_error(data))
+	if pe == nil {
+		return nil
+	}
+	return newMarshalError(v, pe.Error())
+}
+
+/*
+Marshal encodes a Go value as AMQP data in buffer.
+If buffer is nil, or is not large enough, a new buffer is created.
+
+Returns the buffer used for encoding with len() adjusted to the actual size of data.
+If v cannot be marshaled, buffer is returned unchanged together with a *MarshalError.
+
+Go types are encoded as follows:
+
+ +-------------------------------------+--------------------------------------------+
+ |Go type |AMQP type |
+ +-------------------------------------+--------------------------------------------+
+ |bool |bool |
+ +-------------------------------------+--------------------------------------------+
+ |int8, int16, int32, int64 (int) |byte, short, int, long (int or long) |
+ +-------------------------------------+--------------------------------------------+
+ |uint8, uint16, uint32, uint64 (uint) |ubyte, ushort, uint, ulong (uint or ulong) |
+ +-------------------------------------+--------------------------------------------+
+ |float32, float64 |float, double |
+ +-------------------------------------+--------------------------------------------+
+ |string |string |
+ +-------------------------------------+--------------------------------------------+
+ |[]byte, Binary |binary |
+ +-------------------------------------+--------------------------------------------+
+ |Symbol |symbol |
+ +-------------------------------------+--------------------------------------------+
+ |Char |char |
+ +-------------------------------------+--------------------------------------------+
+ |interface{} |the contained type |
+ +-------------------------------------+--------------------------------------------+
+ |nil |null |
+ +-------------------------------------+--------------------------------------------+
+ |map[K]T |map with K and T converted as above |
+ +-------------------------------------+--------------------------------------------+
+ |Map |map, may have mixed types for keys, values |
+ +-------------------------------------+--------------------------------------------+
+ |AnyMap |map (See AnyMap) |
+ +-------------------------------------+--------------------------------------------+
+ |List, []interface{} |list, may have mixed-type values |
+ +-------------------------------------+--------------------------------------------+
+ |[]T, [N]T |array, T is mapped as per this table |
+ +-------------------------------------+--------------------------------------------+
+ |Described |described type |
+ +-------------------------------------+--------------------------------------------+
+ |time.Time |timestamp |
+ +-------------------------------------+--------------------------------------------+
+ |UUID |uuid |
+ +-------------------------------------+--------------------------------------------+
+
+The following Go types cannot be marshaled: uintptr, function, channel, struct, complex64/128
+
- AMQP types not yet supported:
- - decimal32/64/128,
++AMQP types not yet supported: decimal32/64/128
+*/
+func Marshal(v interface{}, buffer []byte) (outbuf []byte, err error) {
+	// No blank line may separate the doc comment above from this declaration,
+	// or godoc will not attach the comment to Marshal.
+	data := C.pn_data(0)
+	defer C.pn_data_free(data)
+	if err = recoverMarshal(v, data); err != nil {
+		return buffer, err
+	}
+	// encode writes the marshaled value into buf. It reports overflow when buf
+	// is too small, so that encodeGrow can retry with a larger buffer.
+	encode := func(buf []byte) ([]byte, error) {
+		n := int(C.pn_data_encode(data, cPtr(buf), cLen(buf)))
+		switch {
+		case n == int(C.PN_OVERFLOW):
+			return buf, overflow
+		case n < 0:
+			return buf, dataMarshalError(v, data)
+		default:
+			return buf[:n], nil
+		}
+	}
+	return encodeGrow(buffer, encode)
+}
+
+// MarshalUnsafe marshals v into the proton data object that pnData points to
+// (pnData is expected to be a *C.pn_data_t). Internal use only.
+func MarshalUnsafe(v interface{}, pnData unsafe.Pointer) (err error) {
+	data := (*C.pn_data_t)(pnData)
+	return recoverMarshal(v, data)
+}
+
+// recoverMarshal marshals v into data. A *MarshalError panic raised by marshal
+// is converted into the returned error; any other panic keeps propagating.
+func recoverMarshal(v interface{}, data *C.pn_data_t) (err error) {
+	defer func() {
+		r := recover()
+		if r == nil {
+			return
+		}
+		merr, ok := r.(*MarshalError)
+		if !ok {
+			panic(r) // Not a marshaling failure: re-raise unchanged
+		}
+		err = merr
+	}()
+	marshal(v, data) // Reports failure by panicking
+	return nil
+}
+
+// minEncode is the size of the buffer allocated for encoding when the caller
+// supplies none.
+const minEncode = 256
+
+// overflow is the sentinel error an encodeFn returns when the encoded data does
+// not fit in the buffer it was given.
+var overflow = fmt.Errorf("buffer too small")
+
+// encodeFn writes encoded data into buffer[0:len(buffer)] and returns buffer
+// re-sliced to the number of bytes written. If the data does not fit, it
+// returns the overflow error instead.
+type encodeFn func(buffer []byte) ([]byte, error)
+
+// encodeGrow calls encode() into buffer, if it returns overflow grows the buffer.
+// Returns the final buffer.
+func encodeGrow(buffer []byte, encode encodeFn) ([]byte, error) {
+ if buffer == nil || len(buffer) == 0 {
+ buffer = make([]byte, minEncode)
+ }
+ var err error
+ for buffer, err = encode(buffer); err == overflow; buffer, err = encode(buffer) {
+ buffer = make([]byte, 2*len(buffer))
+ }
+ return buffer, err
+}
+
+// marshal writes the Go value i into data at its current position.
+//
+// Errors are reported by panicking with *MarshalError; recoverMarshal converts
+// such a panic back into an error return. Compound AMQP values (described, maps,
+// lists, arrays) are written by entering the new node, marshaling the members
+// recursively, then exiting.
+func marshal(i interface{}, data *C.pn_data_t) {
+ switch v := i.(type) {
+ case nil:
+ C.pn_data_put_null(data)
+ case bool:
+ C.pn_data_put_bool(data, C.bool(v))
+
+ // Signed integers
+ case int8:
+ C.pn_data_put_byte(data, C.int8_t(v))
+ case int16:
+ C.pn_data_put_short(data, C.int16_t(v))
+ case int32:
+ C.pn_data_put_int(data, C.int32_t(v))
+ case int64:
+ C.pn_data_put_long(data, C.int64_t(v))
+ case int:
+ // Go int is 32 or 64 bits depending on the platform (intIs64).
+ if intIs64 {
+ C.pn_data_put_long(data, C.int64_t(v))
+ } else {
+ C.pn_data_put_int(data, C.int32_t(v))
+ }
+
+ // Unsigned integers
+ case uint8:
+ C.pn_data_put_ubyte(data, C.uint8_t(v))
+ case uint16:
+ C.pn_data_put_ushort(data, C.uint16_t(v))
+ case uint32:
+ C.pn_data_put_uint(data, C.uint32_t(v))
+ case uint64:
+ C.pn_data_put_ulong(data, C.uint64_t(v))
+ case uint:
+ // Go uint is 32 or 64 bits depending on the platform (intIs64).
+ if intIs64 {
+ C.pn_data_put_ulong(data, C.uint64_t(v))
+ } else {
+ C.pn_data_put_uint(data, C.uint32_t(v))
+ }
+
+ // Floating point
+ case float32:
+ C.pn_data_put_float(data, C.float(v))
+ case float64:
+ C.pn_data_put_double(data, C.double(v))
+
+ // String-like (string, binary, symbol)
+ case string:
+ C.pn_data_put_string(data, pnBytes([]byte(v)))
+ case []byte:
+ C.pn_data_put_binary(data, pnBytes(v))
+ case Binary:
+ C.pn_data_put_binary(data, pnBytes([]byte(v)))
+ case Symbol:
+ C.pn_data_put_symbol(data, pnBytes([]byte(v)))
+
+ // Other simple types
+ case time.Time:
+ // NOTE(review): AMQP 1.0 defines timestamp as milliseconds since the Unix
+ // epoch, but UnixNano()/1000 is microseconds. The unmarshal side in this
+ // package multiplies by 1000 when reading, so Go-to-Go round trips agree,
+ // but other AMQP peers will see values 1000x too large. Any fix must change
+ // both encoding and decoding together.
+ C.pn_data_put_timestamp(data, C.pn_timestamp_t(v.UnixNano()/1000))
+ case UUID:
+ // Reinterpret the UUID's bytes as a pn_uuid_t (presumably both are 16 bytes; confirm).
+ C.pn_data_put_uuid(data, *(*C.pn_uuid_t)(unsafe.Pointer(&v[0])))
+ case Char:
+ C.pn_data_put_char(data, (C.pn_char_t)(v))
+
+ // Described types
+ case Described:
+ C.pn_data_put_described(data)
+ C.pn_data_enter(data)
+ marshal(v.Descriptor, data)
+ marshal(v.Value, data)
+ C.pn_data_exit(data)
+
+ // Restricted type annotation-key, marshals as contained value
+ case AnnotationKey:
+ marshal(v.Get(), data)
+
+ // Special type to represent AMQP maps with keys that are illegal in Go
+ case AnyMap:
+ C.pn_data_put_map(data)
+ C.pn_data_enter(data)
+ // Deferred: the exit runs after the pn_data error check at the end of this function.
+ defer C.pn_data_exit(data)
+ for _, kv := range v {
+ marshal(kv.Key, data)
+ marshal(kv.Value, data)
+ }
+
+ default:
+ // Examine complex types (Go map, slice, array) by reflected structure
+ switch reflect.TypeOf(i).Kind() {
+
+ case reflect.Map:
+ m := reflect.ValueOf(v)
+ C.pn_data_put_map(data)
+ if C.pn_data_enter(data) {
+ defer C.pn_data_exit(data)
+ } else {
+ // NOTE(review): if proton recorded no error, dataMarshalError returns nil and
+ // this panics with a nil error value, which recoverMarshal will not convert
+ // into a *MarshalError.
+ panic(dataMarshalError(i, data))
+ }
+ // Go map iteration order is random, so entry order in the AMQP map varies.
+ for _, key := range m.MapKeys() {
+ marshal(key.Interface(), data)
+ marshal(m.MapIndex(key).Interface(), data)
+ }
+
+ case reflect.Slice, reflect.Array:
+ // Note: Go array and slice are mapped the same way:
+ // if the element type has an entry in arrayTypeMap, map to an AMQP array (single type);
+ // otherwise (including interface{} elements) map to an AMQP list (mixed type).
+ s := reflect.ValueOf(v)
+ if pnType, ok := arrayTypeMap[s.Type().Elem()]; ok {
+ C.pn_data_put_array(data, false, pnType)
+ } else {
+ C.pn_data_put_list(data)
+ }
+ // NOTE(review): unlike the reflect.Map case, the result of pn_data_enter is not checked here.
+ C.pn_data_enter(data)
+ defer C.pn_data_exit(data)
+ for j := 0; j < s.Len(); j++ {
+ marshal(s.Index(j).Interface(), data)
+ }
+
+ default:
+ panic(newMarshalError(v, "no conversion"))
+ }
+ }
+ // Surface any error proton recorded while writing this value.
+ if err := dataMarshalError(i, data); err != nil {
+ panic(err)
+ }
+}
+
+// arrayTypeMap maps a Go element type to the AMQP array element type used when a
+// Go slice or array with that element type is marshaled as an AMQP array.
+// Element types that are not in the map are marshaled as an AMQP list instead.
+//
+// NOTE: this must be kept consistent with marshal() which does the actual marshalling.
+// marshal() writes each element with the pn_data_put_* call for its Go type
+// (e.g. int16 -> pn_data_put_short), so the array type declared here must be the
+// AMQP type that call produces; a mismatch corrupts the encoded array.
+// Entries for int and uint are added by init(), since their size is platform-dependent.
+var arrayTypeMap = map[reflect.Type]C.pn_type_t{
+	nil:                  C.PN_NULL,
+	reflect.TypeOf(true): C.PN_BOOL,
+
+	reflect.TypeOf(int8(0)):  C.PN_BYTE,
+	reflect.TypeOf(int16(0)): C.PN_SHORT,
+	reflect.TypeOf(int32(0)): C.PN_INT,
+	reflect.TypeOf(int64(0)): C.PN_LONG,
+
+	reflect.TypeOf(uint8(0)):  C.PN_UBYTE,
+	reflect.TypeOf(uint16(0)): C.PN_USHORT,
+	reflect.TypeOf(uint32(0)): C.PN_UINT,
+	reflect.TypeOf(uint64(0)): C.PN_ULONG,
+
+	reflect.TypeOf(float32(0)): C.PN_FLOAT,
+	reflect.TypeOf(float64(0)): C.PN_DOUBLE,
+
+	reflect.TypeOf(""):                    C.PN_STRING,
+	reflect.TypeOf((*Symbol)(nil)).Elem(): C.PN_SYMBOL,
+	reflect.TypeOf((*Binary)(nil)).Elem(): C.PN_BINARY,
+	reflect.TypeOf([]byte{}):              C.PN_BINARY,
+
+	reflect.TypeOf((*time.Time)(nil)).Elem(): C.PN_TIMESTAMP,
+	reflect.TypeOf((*UUID)(nil)).Elem():      C.PN_UUID,
+	reflect.TypeOf((*Char)(nil)).Elem():      C.PN_CHAR,
+}
+
+// init registers the AMQP array element types for int and uint. They are
+// decided at runtime because the size of int/uint (intIs64) depends on the
+// platform.
+func init() {
+	intType, uintType := reflect.TypeOf(int(0)), reflect.TypeOf(uint(0))
+	if !intIs64 {
+		arrayTypeMap[intType] = C.PN_INT
+		arrayTypeMap[uintType] = C.PN_UINT
+		return
+	}
+	arrayTypeMap[intType] = C.PN_LONG
+	arrayTypeMap[uintType] = C.PN_ULONG
+}
+
+// clearMarshal empties data and then marshals v into it. Like marshal, it
+// panics with *MarshalError on failure.
+func clearMarshal(v interface{}, data *C.pn_data_t) {
+	C.pn_data_clear(data) // Discard any previous content first
+	marshal(v, data)
+}
+
+// Encoder writes AMQP-encoded values to an io.Writer.
+type Encoder struct {
+	writer io.Writer
+	buffer []byte // Scratch buffer passed to Marshal by each Encode call
+}
+
+// NewEncoder returns a new Encoder that writes to w.
+func NewEncoder(w io.Writer) *Encoder {
+	return &Encoder{writer: w, buffer: make([]byte, minEncode)}
+}
+
+// Encode marshals v as AMQP data (see Marshal for the type mapping) and writes
+// the encoded bytes to the Encoder's writer in a single Write call.
+// It returns the marshaling error, or otherwise the writer's error.
+func (e *Encoder) Encode(v interface{}) (err error) {
+	// Marshal only encodes into len(buffer), and the slice it returns is trimmed
+	// to the size of the encoded value. Pass the full capacity so a buffer that
+	// has already grown is reused, instead of re-allocating whenever a value is
+	// larger than the previous one.
+	e.buffer, err = Marshal(v, e.buffer[:cap(e.buffer)])
+	if err != nil {
+		return err
+	}
+	_, err = e.writer.Write(e.buffer)
+	return err
+}
http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/6f799990/amqp/unmarshal.go
----------------------------------------------------------------------
diff --cc amqp/unmarshal.go
index 97e8437,0000000..2c6e3f1
mode 100644,000000..100644
--- a/amqp/unmarshal.go
+++ b/amqp/unmarshal.go
@@@ -1,733 -1,0 +1,731 @@@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+*/
+
+package amqp
+
+// #include <proton/codec.h>
+import "C"
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "reflect"
+ "time"
+ "unsafe"
+)
+
+// minDecode is the minimum read-size limit Decoder.more uses when fetching more
+// data from the underlying reader.
+const minDecode = 1024
+
+// UnmarshalError is returned when AMQP data cannot be unmarshaled as the
+// requested Go type.
+type UnmarshalError struct {
+	AMQPType string       // The name of the AMQP type (may be empty).
+	GoType   reflect.Type // The Go type (may be nil).
+
+	s string // The complete error message.
+}
+
+// Error implements the error interface by returning the stored message.
+func (e UnmarshalError) Error() string {
+	return e.s
+}
+
+var (
+	// EndOfData is returned if there are not enough bytes to decode a complete AMQP value.
+	EndOfData = &UnmarshalError{s: "Not enough data for AMQP value"}
+
+	// badData is the error checkOp uses when a pn_data navigation call fails.
+	badData = &UnmarshalError{s: "Unexpected error in data"}
+)
+
+// newUnmarshalError describes a failure to unmarshal an AMQP value of type
+// pnType into v. Unmarshal targets must be pointers, so a nil or non-pointer v
+// gets a message saying so.
+func newUnmarshalError(pnType C.pn_type_t, v interface{}) *UnmarshalError {
+	goType := reflect.TypeOf(v)
+	e := &UnmarshalError{AMQPType: pnType.String(), GoType: goType}
+	switch {
+	case goType == nil || goType.Kind() != reflect.Ptr:
+		e.s = fmt.Sprintf("cannot unmarshal to Go type %v, not a pointer", goType)
+	default:
+		e.s = fmt.Sprintf("cannot unmarshal AMQP %v to Go %v", e.AMQPType, goType.Elem())
+	}
+	return e
+}
+
+// doPanic panics with an *UnmarshalError for converting the current AMQP
+// value in data into v.
+func doPanic(data *C.pn_data_t, v interface{}) {
+	panic(newUnmarshalError(C.pn_data_type(data), v))
+}
+
+// doPanicMsg is like doPanic, but appends ": " + msg to the error message.
+func doPanicMsg(data *C.pn_data_t, v interface{}, msg string) {
+	e := newUnmarshalError(C.pn_data_type(data), v)
+	e.s += ": " + msg
+	panic(e)
+}
+
+// panicIfBadData panics with an *UnmarshalError that carries proton's error
+// text when data has a non-zero error code.
+func panicIfBadData(data *C.pn_data_t, v interface{}) {
+	if C.pn_data_errno(data) == 0 {
+		return
+	}
+	doPanicMsg(data, v, PnError(C.pn_data_error(data)).Error())
+}
+
+// panicUnless panics with an *UnmarshalError for v when ok is false.
+func panicUnless(ok bool, data *C.pn_data_t, v interface{}) {
+	if ok {
+		return
+	}
+	doPanic(data, v)
+}
+
+// checkOp panics with badData when a pn_data navigation call (enter, exit,
+// next) reports failure, i.e. ok is false. v is the unmarshal target; it is
+// accepted for context but not used.
+//
+// badData is already a *UnmarshalError, which is the type recoverUnmarshal
+// recovers. Panicking with &badData (a **UnmarshalError) would escape
+// recoverUnmarshal and crash the program instead of returning an error.
+func checkOp(ok bool, v interface{}) {
+	if !ok {
+		panic(badData)
+	}
+}
+
+//
+// Decoding from a pn_data_t
+//
+// NOTE: we use panic() to signal a decoding error, simplifies decoding logic.
+// We recover() at the highest possible level - i.e. in the exported Unmarshal or Decode.
+//
+
+// Decoder decodes AMQP values read from an io.Reader.
+type Decoder struct {
+	reader io.Reader
+	buffer bytes.Buffer // Bytes read from reader and not yet consumed by Decode
+}
+
+// NewDecoder returns a new Decoder that reads from r.
+//
+// The decoder has its own buffer and may read more data than required for the
+// AMQP values requested. Use Buffered to see if there is data left in the
+// buffer.
+func NewDecoder(r io.Reader) *Decoder {
+	return &Decoder{reader: r} // The zero bytes.Buffer is ready to use
+}
+
+// Buffered returns a reader of the data remaining in the Decoder's buffer.
+// The reader is valid until the next call to Decode.
+func (d *Decoder) Buffered() io.Reader {
+	remaining := d.buffer.Bytes()
+	return bytes.NewReader(remaining)
+}
+
+// Decode reads the next AMQP value from the Reader and stores it in the value pointed to by v.
+//
+// See the documentation for Unmarshal for details about the conversion of AMQP into a Go value.
+//
+func (d *Decoder) Decode(v interface{}) (err error) {
+ data := C.pn_data(0)
+ defer C.pn_data_free(data)
+ var n int
+ for n, err = decode(data, d.buffer.Bytes()); err == EndOfData; {
+ err = d.more()
+ if err == nil {
+ n, err = decode(data, d.buffer.Bytes())
+ }
+ }
+ if err == nil {
+ if err = recoverUnmarshal(v, data); err == nil {
+ d.buffer.Next(n)
+ }
+ }
+ return
+}
+
+/*
+
+Unmarshal decodes AMQP-encoded bytes and stores the result in the Go value
+pointed to by v. Legal conversions from the source AMQP type to the target Go
+type as follows:
+
+ +----------------------------+-------------------------------------------------+
+ |Target Go type |Allowed AMQP types |
+ +============================+==================================================+
+ |bool |bool |
+ +----------------------------+--------------------------------------------------+
+ |int, int8, int16, int32, |Equivalent or smaller signed integer type: |
+ |int64 |byte, short, int, long or char |
+ +----------------------------+--------------------------------------------------+
+ |uint, uint8, uint16, uint32,|Equivalent or smaller unsigned integer type: |
+ |uint64 |ubyte, ushort, uint, ulong |
+ +----------------------------+--------------------------------------------------+
+ |float32, float64 |Equivalent or smaller float or double |
+ +----------------------------+--------------------------------------------------+
+ |string, []byte |string, symbol or binary |
+ +----------------------------+--------------------------------------------------+
+ |Symbol |symbol |
+ +----------------------------+--------------------------------------------------+
+ |Char |char |
+ +----------------------------+--------------------------------------------------+
+ |Described |AMQP described type [1] |
+ +----------------------------+--------------------------------------------------+
+ |time.Time |timestamp |
+ +----------------------------+--------------------------------------------------+
+ |UUID |uuid |
+ +----------------------------+--------------------------------------------------+
+ |map[interface{}]interface{} |Any AMQP map |
+ +----------------------------+--------------------------------------------------+
+ |map[K]T |map, provided all keys and values can unmarshal |
+ | |to types K,T |
+ +----------------------------+--------------------------------------------------+
+ |[]interface{} |AMQP list or array |
+ +----------------------------+--------------------------------------------------+
+ |[]T |list or array if elements can unmarshal as T |
+ +----------------------------+--------------------------------------------------+
+ |interface{} |any AMQP type[2] |
+ +----------------------------+--------------------------------------------------+
+
+[1] An AMQP described value can also unmarshal to a plain value, discarding the
+descriptor. Unmarshalling into the special amqp.Described type preserves the
+descriptor.
+
+[2] Any AMQP value can be unmarshalled to an interface{}. The Go type is
+determined by the AMQP type as follows:
+
+ +----------------------------+--------------------------------------------------+
+ |Source AMQP Type |Go Type in target interface{} |
+ +============================+==================================================+
+ |bool |bool |
+ +----------------------------+--------------------------------------------------+
+ |byte,short,int,long |int8,int16,int32,int64 |
+ +----------------------------+--------------------------------------------------+
+ |ubyte,ushort,uint,ulong |uint8,uint16,uint32,uint64 |
+ +----------------------------+--------------------------------------------------+
+ |float, double |float32, float64 |
+ +----------------------------+--------------------------------------------------+
+ |string |string |
+ +----------------------------+--------------------------------------------------+
+ |symbol |Symbol |
+ +----------------------------+--------------------------------------------------+
+ |char |Char |
+ +----------------------------+--------------------------------------------------+
+ |binary |Binary |
+ +----------------------------+--------------------------------------------------+
+ |null |nil |
+ +----------------------------+--------------------------------------------------+
+ |described type |Described |
+ +----------------------------+--------------------------------------------------+
+ |timestamp |time.Time |
+ +----------------------------+--------------------------------------------------+
+ |uuid |UUID |
+ +----------------------------+--------------------------------------------------+
+ |map |Map or AnyMap[4] |
+ +----------------------------+--------------------------------------------------+
+ |list |List |
+ +----------------------------+--------------------------------------------------+
+ |array |[]T for simple types, T is chosen as above [3] |
+ +----------------------------+--------------------------------------------------+
+
+[3] An AMQP array of simple types unmarshals as a slice of the corresponding Go type.
+An AMQP array containing complex types (lists, maps or nested arrays) unmarshals
+to the generic array type amqp.Array
+
+[4] An AMQP map unmarshals as the generic `type Map map[interface{}]interface{}`
+unless it contains key values that are illegal as Go map types, in which case
+it unmarshals as type AnyMap.
+
+The following Go types cannot be unmarshaled: uintptr, function, interface,
+channel, array (use slice), struct
+
- AMQP types not yet supported:
- - decimal32/64/128
- - maps with key values that are not legal Go map keys.
++AMQP types not yet supported: decimal32/64/128
+*/
+func Unmarshal(bytes []byte, v interface{}) (n int, err error) {
+	// Decode into a temporary proton data object, then convert it into v.
+	data := C.pn_data(0)
+	defer C.pn_data_free(data)
+	if n, err = decode(data, bytes); err != nil {
+		return n, err
+	}
+	return n, recoverUnmarshal(v, data)
+}
+
+// UnmarshalUnsafe unmarshals from the proton data object that pnData points to
+// (pnData is expected to be a *C.pn_data_t) into v. Internal use only.
+func UnmarshalUnsafe(pnData unsafe.Pointer, v interface{}) (err error) {
+	data := (*C.pn_data_t)(pnData)
+	return recoverUnmarshal(v, data)
+}
+
+// more reads more data when we can't parse a complete AMQP type
+func (d *Decoder) more() error {
+ var readSize int64 = minDecode
+ if int64(d.buffer.Len()) > readSize { // Grow by doubling
+ readSize = int64(d.buffer.Len())
+ }
+ var n int64
+ n, err := d.buffer.ReadFrom(io.LimitReader(d.reader, readSize))
+ if n == 0 && err == nil { // ReadFrom won't report io.EOF, just returns 0
+ err = io.EOF
+ }
+ return err
+}
+
+// recoverUnmarshal calls unmarshal(v, data). A *UnmarshalError panic is
+// converted into the returned error; any other panic keeps propagating.
+func recoverUnmarshal(v interface{}, data *C.pn_data_t) (err error) {
+	defer func() {
+		r := recover()
+		if r == nil {
+			return
+		}
+		uerr, ok := r.(*UnmarshalError)
+		if !ok {
+			panic(r) // Not an unmarshaling failure: re-raise unchanged
+		}
+		err = uerr
+	}()
+	unmarshal(v, data)
+	return nil
+}
+
+// unmarshal decodes the AMQP value at the current position of data into the Go
+// value pointed at by v. It does not return a value.
+//
+// v must be a non-nil pointer. The conversions accepted for each target type are
+// those listed in the Unmarshal documentation. A described AMQP value is
+// unwrapped by getDescribed unless the target is *interface{}.
+// Failures are reported by panicking with *UnmarshalError, which
+// recoverUnmarshal converts into an error return.
+//
+// NOTE: If you update this you also need to update getInterface()
+func unmarshal(v interface{}, data *C.pn_data_t) {
+	rt := reflect.TypeOf(v)
+	rv := reflect.ValueOf(v)
+	panicUnless(v != nil && rt.Kind() == reflect.Ptr && !rv.IsNil(), data, v)
+
+	// Check for PN_DESCRIBED first, as described types can unmarshal into any of the Go types.
+	// An interface{} target is handled in the switch below, even for described types.
+	if _, isInterface := v.(*interface{}); !isInterface && bool(C.pn_data_is_described(data)) {
+		getDescribed(data, v)
+		return
+	}
+
+	// Unmarshal based on the target type
+	pnType := C.pn_data_type(data)
+	switch v := v.(type) {
+
+	case *bool:
+		panicUnless(pnType == C.PN_BOOL, data, v)
+		*v = bool(C.pn_data_get_bool(data))
+
+	case *int8:
+		panicUnless(pnType == C.PN_BYTE, data, v)
+		*v = int8(C.pn_data_get_byte(data))
+
+	case *uint8:
+		panicUnless(pnType == C.PN_UBYTE, data, v)
+		*v = uint8(C.pn_data_get_ubyte(data))
+
+	case *int16:
+		switch pnType {
+		case C.PN_BYTE:
+			*v = int16(C.pn_data_get_byte(data))
+		case C.PN_SHORT:
+			*v = int16(C.pn_data_get_short(data))
+		default:
+			doPanic(data, v)
+		}
+
+	case *uint16:
+		switch pnType {
+		case C.PN_UBYTE:
+			*v = uint16(C.pn_data_get_ubyte(data))
+		case C.PN_USHORT:
+			*v = uint16(C.pn_data_get_ushort(data))
+		default:
+			doPanic(data, v)
+		}
+
+	case *int32:
+		switch pnType {
+		case C.PN_CHAR:
+			*v = int32(C.pn_data_get_char(data))
+		case C.PN_BYTE:
+			*v = int32(C.pn_data_get_byte(data))
+		case C.PN_SHORT:
+			*v = int32(C.pn_data_get_short(data))
+		case C.PN_INT:
+			*v = int32(C.pn_data_get_int(data))
+		default:
+			doPanic(data, v)
+		}
+
+	case *uint32:
+		switch pnType {
+		case C.PN_CHAR:
+			*v = uint32(C.pn_data_get_char(data))
+		case C.PN_UBYTE:
+			*v = uint32(C.pn_data_get_ubyte(data))
+		case C.PN_USHORT:
+			*v = uint32(C.pn_data_get_ushort(data))
+		case C.PN_UINT:
+			*v = uint32(C.pn_data_get_uint(data))
+		default:
+			doPanic(data, v)
+		}
+
+	case *int64:
+		switch pnType {
+		case C.PN_CHAR:
+			*v = int64(C.pn_data_get_char(data))
+		case C.PN_BYTE:
+			*v = int64(C.pn_data_get_byte(data))
+		case C.PN_SHORT:
+			*v = int64(C.pn_data_get_short(data))
+		case C.PN_INT:
+			*v = int64(C.pn_data_get_int(data))
+		case C.PN_LONG:
+			*v = int64(C.pn_data_get_long(data))
+		default:
+			doPanic(data, v)
+		}
+
+	case *uint64:
+		switch pnType {
+		case C.PN_CHAR:
+			*v = uint64(C.pn_data_get_char(data))
+		case C.PN_UBYTE:
+			*v = uint64(C.pn_data_get_ubyte(data))
+		case C.PN_USHORT:
+			*v = uint64(C.pn_data_get_ushort(data))
+		case C.PN_UINT:
+			// AMQP uint widens to uint64, just as int widens to int64 above.
+			*v = uint64(C.pn_data_get_uint(data))
+		case C.PN_ULONG:
+			*v = uint64(C.pn_data_get_ulong(data))
+		default:
+			doPanic(data, v)
+		}
+
+	case *int:
+		switch pnType {
+		case C.PN_CHAR:
+			*v = int(C.pn_data_get_char(data))
+		case C.PN_BYTE:
+			*v = int(C.pn_data_get_byte(data))
+		case C.PN_SHORT:
+			*v = int(C.pn_data_get_short(data))
+		case C.PN_INT:
+			*v = int(C.pn_data_get_int(data))
+		case C.PN_LONG:
+			// A long only fits in int on platforms where int is 64 bits.
+			if intIs64 {
+				*v = int(C.pn_data_get_long(data))
+			} else {
+				doPanic(data, v)
+			}
+		default:
+			doPanic(data, v)
+		}
+
+	case *uint:
+		switch pnType {
+		case C.PN_CHAR:
+			*v = uint(C.pn_data_get_char(data))
+		case C.PN_UBYTE:
+			*v = uint(C.pn_data_get_ubyte(data))
+		case C.PN_USHORT:
+			*v = uint(C.pn_data_get_ushort(data))
+		case C.PN_UINT:
+			*v = uint(C.pn_data_get_uint(data))
+		case C.PN_ULONG:
+			// A ulong only fits in uint on platforms where uint is 64 bits.
+			if intIs64 {
+				*v = uint(C.pn_data_get_ulong(data))
+			} else {
+				doPanic(data, v)
+			}
+		default:
+			doPanic(data, v)
+		}
+
+	case *float32:
+		panicUnless(pnType == C.PN_FLOAT, data, v)
+		*v = float32(C.pn_data_get_float(data))
+
+	case *float64:
+		switch pnType {
+		case C.PN_FLOAT:
+			*v = float64(C.pn_data_get_float(data))
+		case C.PN_DOUBLE:
+			*v = float64(C.pn_data_get_double(data))
+		default:
+			doPanic(data, v)
+		}
+
+	case *string:
+		switch pnType {
+		case C.PN_STRING:
+			*v = goString(C.pn_data_get_string(data))
+		case C.PN_SYMBOL:
+			*v = goString(C.pn_data_get_symbol(data))
+		case C.PN_BINARY:
+			*v = goString(C.pn_data_get_binary(data))
+		default:
+			doPanic(data, v)
+		}
+
+	case *[]byte:
+		switch pnType {
+		case C.PN_STRING:
+			*v = goBytes(C.pn_data_get_string(data))
+		case C.PN_SYMBOL:
+			*v = goBytes(C.pn_data_get_symbol(data))
+		case C.PN_BINARY:
+			*v = goBytes(C.pn_data_get_binary(data))
+		default:
+			doPanic(data, v)
+		}
+
+	case *Char:
+		panicUnless(pnType == C.PN_CHAR, data, v)
+		*v = Char(C.pn_data_get_char(data))
+
+	case *Binary:
+		panicUnless(pnType == C.PN_BINARY, data, v)
+		*v = Binary(goBytes(C.pn_data_get_binary(data)))
+
+	case *Symbol:
+		panicUnless(pnType == C.PN_SYMBOL, data, v)
+		*v = Symbol(goBytes(C.pn_data_get_symbol(data)))
+
+	case *time.Time:
+		panicUnless(pnType == C.PN_TIMESTAMP, data, v)
+		// NOTE(review): AMQP 1.0 timestamps are milliseconds since the Unix epoch,
+		// but this treats the value as microseconds (x1000 -> nanoseconds), matching
+		// marshal(). Any fix must change encoding and decoding together.
+		*v = time.Unix(0, int64(C.pn_data_get_timestamp(data))*1000)
+
+	case *UUID:
+		panicUnless(pnType == C.PN_UUID, data, v)
+		pn := C.pn_data_get_uuid(data)
+		copy((*v)[:], C.GoBytes(unsafe.Pointer(&pn.bytes), 16))
+
+	case *AnnotationKey:
+		panicUnless(pnType == C.PN_ULONG || pnType == C.PN_SYMBOL || pnType == C.PN_STRING, data, v)
+		unmarshal(&v.value, data)
+
+	case *AnyMap:
+		panicUnless(pnType == C.PN_MAP, data, v)
+		n := int(C.pn_data_get_map(data)) / 2
+		// Reuse the existing backing array when it is large enough.
+		if cap(*v) < n {
+			*v = make(AnyMap, n)
+		}
+		*v = (*v)[:n]
+		data.enter(*v)
+		defer data.exit(*v)
+		for i := 0; i < n; i++ {
+			data.next(*v)
+			unmarshal(&(*v)[i].Key, data)
+			data.next(*v)
+			unmarshal(&(*v)[i].Value, data)
+		}
+
+	case *interface{}:
+		getInterface(data, v)
+
+	default: // This is not one of the fixed well-known types, reflect for map and slice types
+
+		switch rt.Elem().Kind() {
+		case reflect.Map:
+			getMap(data, v)
+		case reflect.Slice:
+			getSequence(data, v)
+		default:
+			doPanic(data, v)
+		}
+	}
+}
+
+// getInterface unmarshals the AMQP value at the current data position into the
+// interface{} pointed at by vp. Since an interface{} target can hold any Go type,
+// the Go type stored is chosen from the AMQP source type (see the second table in
+// the Unmarshal documentation). Panics with *UnmarshalError for AMQP types not
+// handled below.
+func getInterface(data *C.pn_data_t, vp *interface{}) {
+ pnType := C.pn_data_type(data)
+ switch pnType {
+ case C.PN_BOOL:
+ *vp = bool(C.pn_data_get_bool(data))
+ case C.PN_UBYTE:
+ *vp = uint8(C.pn_data_get_ubyte(data))
+ case C.PN_BYTE:
+ *vp = int8(C.pn_data_get_byte(data))
+ case C.PN_USHORT:
+ *vp = uint16(C.pn_data_get_ushort(data))
+ case C.PN_SHORT:
+ *vp = int16(C.pn_data_get_short(data))
+ case C.PN_UINT:
+ *vp = uint32(C.pn_data_get_uint(data))
+ case C.PN_INT:
+ *vp = int32(C.pn_data_get_int(data))
+ case C.PN_CHAR:
+ *vp = Char(C.pn_data_get_char(data))
+ case C.PN_ULONG:
+ *vp = uint64(C.pn_data_get_ulong(data))
+ case C.PN_LONG:
+ *vp = int64(C.pn_data_get_long(data))
+ case C.PN_FLOAT:
+ *vp = float32(C.pn_data_get_float(data))
+ case C.PN_DOUBLE:
+ *vp = float64(C.pn_data_get_double(data))
+ case C.PN_BINARY:
+ *vp = Binary(goBytes(C.pn_data_get_binary(data)))
+ case C.PN_STRING:
+ *vp = goString(C.pn_data_get_string(data))
+ case C.PN_SYMBOL:
+ *vp = Symbol(goString(C.pn_data_get_symbol(data)))
+ case C.PN_TIMESTAMP:
+ // NOTE(review): AMQP 1.0 timestamps are milliseconds since the Unix epoch, but
+ // this treats the value as microseconds, matching marshal(). Change both together.
+ *vp = time.Unix(0, int64(C.pn_data_get_timestamp(data))*1000)
+ case C.PN_UUID:
+ var u UUID
+ unmarshal(&u, data)
+ *vp = u
+ case C.PN_MAP:
+ // We will try to unmarshal as a Map first, if that fails try AnyMap.
+ // getMap reports a key that is not a legal Go map key as *UnmarshalError,
+ // which recoverUnmarshal returns as err here.
+ m := make(Map, int(C.pn_data_get_map(data))/2)
+ if err := recoverUnmarshal(&m, data); err == nil {
+ *vp = m
+ } else {
+ am := make(AnyMap, int(C.pn_data_get_map(data))/2)
+ unmarshal(&am, data)
+ *vp = am
+ }
+ case C.PN_LIST:
+ l := List{}
+ unmarshal(&l, data)
+ *vp = l
+ case C.PN_ARRAY:
+ sp := getArrayStore(data) // interface{} containing T* for suitable T
+ unmarshal(sp, data)
+ // Store the slice itself (*sp), not the pointer to it.
+ *vp = reflect.ValueOf(sp).Elem().Interface()
+ case C.PN_DESCRIBED:
+ // unmarshal() routes described values here only for interface{} targets;
+ // the Described target keeps the descriptor.
+ d := Described{}
+ unmarshal(&d, data)
+ *vp = d
+ case C.PN_NULL:
+ *vp = nil
+ case C.PN_INVALID:
+ // Allow decoding from an empty data object to an interface, treat it like NULL.
+ // This happens when optional values or properties are omitted from a message.
+ *vp = nil
+ default: // Don't know how to handle this
+ panic(newUnmarshalError(pnType, vp))
+ }
+}
+
+// Return an interface{} containing a pointer to an appropriate slice or Array
+func getArrayStore(data *C.pn_data_t) interface{} {
+ // TODO aconway 2017-11-10: described arrays.
+ switch C.pn_data_get_array_type(data) {
+ case C.PN_BOOL:
+ return new([]bool)
+ case C.PN_UBYTE:
+ return new([]uint8)
+ case C.PN_BYTE:
+ return new([]int8)
+ case C.PN_USHORT:
+ return new([]uint16)
+ case C.PN_SHORT:
+ return new([]int16)
+ case C.PN_UINT:
+ return new([]uint32)
+ case C.PN_INT:
+ return new([]int32)
+ case C.PN_CHAR:
+ return new([]Char)
+ case C.PN_ULONG:
+ return new([]uint64)
+ case C.PN_LONG:
+ return new([]int64)
+ case C.PN_FLOAT:
+ return new([]float32)
+ case C.PN_DOUBLE:
+ return new([]float64)
+ case C.PN_BINARY:
+ return new([]Binary)
+ case C.PN_STRING:
+ return new([]string)
+ case C.PN_SYMBOL:
+ return new([]Symbol)
+ case C.PN_TIMESTAMP:
+ return new([]time.Time)
+ case C.PN_UUID:
+ return new([]UUID)
+ }
+ return new(Array) // Not a simple type, use generic Array
+}
+
+// typeOfInterface is the reflect.Type of the empty interface, interface{}.
+// reflect.TypeOf(interface{}(nil)) returns nil (there is no dynamic type), so
+// the type is taken from a pointer to interface{} instead.
+var typeOfInterface = reflect.TypeOf((*interface{})(nil)).Elem()
+
+// getMap unmarshals the AMQP map at the current data position into the Go map
+// pointed at by v. The target is replaced by a new, empty map of the same type
+// before entries are added. Panics with *UnmarshalError if the data is not an
+// AMQP map, if a key or value cannot be converted, or if an interface{} key holds
+// a value whose type is not comparable (not a legal Go map key).
+func getMap(data *C.pn_data_t, v interface{}) {
+	panicUnless(C.pn_data_type(data) == C.PN_MAP, data, v)
+	n := int(C.pn_data_get_map(data)) / 2
+	mapValue := reflect.ValueOf(v).Elem()
+	mapValue.Set(reflect.MakeMap(mapValue.Type())) // Clear the map
+	data.enter(v)
+	defer data.exit(v)
+	// Allocate re-usable key/val storage. Each is reset to its zero value before
+	// use so that nothing from a previous entry leaks into the next one. For
+	// example, unmarshal reuses an AnyMap's backing array when it has enough
+	// capacity, which would otherwise make several map values share and
+	// overwrite the same array.
+	keyType := mapValue.Type().Key()
+	valType := mapValue.Type().Elem()
+	keyPtr, keyZero := reflect.New(keyType), reflect.Zero(keyType)
+	valPtr, valZero := reflect.New(valType), reflect.Zero(valType)
+	for i := 0; i < n; i++ {
+		data.next(v)
+		keyPtr.Elem().Set(keyZero)
+		unmarshal(keyPtr.Interface(), data)
+		if keyType.Kind() == reflect.Interface {
+			// A nil interface{} key (AMQP null) is a legal Go map key and has no
+			// dynamic type; calling Type() on it would raise a reflect panic that
+			// recoverUnmarshal does not catch.
+			if key := keyPtr.Elem().Elem(); key.IsValid() && !key.Type().Comparable() {
+				doPanicMsg(data, v, fmt.Sprintf("key %#v is not comparable", keyPtr.Elem().Interface()))
+			}
+		}
+		data.next(v)
+		valPtr.Elem().Set(valZero)
+		unmarshal(valPtr.Interface(), data)
+		mapValue.SetMapIndex(keyPtr.Elem(), valPtr.Elem())
+	}
+}
+
+// getSequence unmarshals the AMQP list or array at the current data position
+// into the slice pointed at by vp, replacing its contents with a new slice of
+// exactly the element count. Panics with *UnmarshalError for other AMQP types
+// or for elements that cannot be converted to the slice's element type.
+func getSequence(data *C.pn_data_t, vp interface{}) {
+	var count int
+	switch C.pn_data_type(data) {
+	case C.PN_LIST:
+		count = int(C.pn_data_get_list(data))
+	case C.PN_ARRAY:
+		count = int(C.pn_data_get_array(data))
+	default:
+		doPanic(data, vp)
+	}
+	sliceType := reflect.TypeOf(vp).Elem()
+	elemType := sliceType.Elem()
+	result := reflect.MakeSlice(sliceType, count, count)
+	data.enter(vp)
+	defer data.exit(vp)
+	for i := 0; i < count; i++ {
+		data.next(vp)
+		elem := reflect.New(elemType) // Fresh storage for every element
+		unmarshal(elem.Interface(), data)
+		result.Index(i).Set(elem.Elem())
+	}
+	reflect.ValueOf(vp).Elem().Set(result)
+}
+
+// getDescribed unmarshals the described AMQP value at the current data position.
+// A *Described target receives both descriptor and value. Any other target
+// receives only the plain value; the descriptor is skipped.
+func getDescribed(data *C.pn_data_t, vp interface{}) {
+	data.enter(vp)
+	defer data.exit(vp)
+	data.next(vp) // Move to the descriptor
+	d, keepDescriptor := vp.(*Described)
+	if !keepDescriptor {
+		data.next(vp) // Skip descriptor
+		unmarshal(vp, data)
+		return
+	}
+	unmarshal(&d.Descriptor, data)
+	data.next(vp)
+	unmarshal(&d.Value, data)
+}
+
+// decode parses one AMQP value from bytes into data and returns the number of
+// bytes consumed. If bytes holds only part of a value, it returns 0 and
+// EndOfData after clearing the proton error. Any other failure returns 0 and
+// an *UnmarshalError.
+func decode(data *C.pn_data_t, bytes []byte) (int, error) {
+	switch n := C.pn_data_decode(data, cPtr(bytes), cLen(bytes)); {
+	case n == C.PN_UNDERFLOW:
+		C.pn_error_clear(C.pn_data_error(data))
+		return 0, EndOfData
+	case n <= 0:
+		return 0, &UnmarshalError{s: fmt.Sprintf("unmarshal %v", PnErrorCode(n))}
+	default:
+		return int(n), nil
+	}
+}
+
+// Checked versions of pn_data functions
+
+// enter wraps C.pn_data_enter, panicking via checkOp if it reports failure.
+// v is the unmarshal target and is passed through to checkOp.
+// NOTE(review): newer Go toolchains may reject methods declared on cgo C types;
+// confirm against the Go version used to build this package.
+func (data *C.pn_data_t) enter(v interface{}) { checkOp(bool(C.pn_data_enter(data)), v) }
+// exit wraps C.pn_data_exit, panicking via checkOp if it reports failure.
+func (data *C.pn_data_t) exit(v interface{}) { checkOp(bool(C.pn_data_exit(data)), v) }
+// next wraps C.pn_data_next, panicking via checkOp if it reports failure.
+func (data *C.pn_data_t) next(v interface{}) { checkOp(bool(C.pn_data_next(data)), v) }
http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/6f799990/electron/auth_test.go
----------------------------------------------------------------------
diff --cc electron/auth_test.go
index 9fa9fa2,0000000..162b366
mode 100644,000000..100644
--- a/electron/auth_test.go
+++ b/electron/auth_test.go
@@@ -1,137 -1,0 +1,143 @@@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+*/
+
+package electron
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+func testAuthClientServer(t *testing.T, copts []ConnectionOption, sopts []ConnectionOption) (got connectionSettings, err error) {
+ client, server := newClientServerOpts(t, copts, sopts)
+ defer closeClientServer(client, server)
+
+ go func() {
+ for in := range server.Incoming() {
+ switch in := in.(type) {
+ case *IncomingConnection:
+ got = connectionSettings{user: in.User(), virtualHost: in.VirtualHost()}
+ }
+ in.Accept()
+ }
+ }()
+
+ err = client.Sync()
+ return
+}
+
+// TestAuthAnonymous checks that a server restricted to ANONYMOUS reports the
+// user as "anonymous" even though the client asked for "fred". It also checks
+// that the server still sees the virtual host the client requested.
+func TestAuthAnonymous(t *testing.T) {
- configureSASL()
+ got, err := testAuthClientServer(t,
+ []ConnectionOption{User("fred"), VirtualHost("vhost"), SASLAllowInsecure(true)},
+ []ConnectionOption{SASLAllowedMechs("ANONYMOUS"), SASLAllowInsecure(true)})
+ fatalIf(t, err)
+ errorIf(t, checkEqual(connectionSettings{user: "anonymous", virtualHost: "vhost"}, got))
+}
+
+// TestAuthPlain checks PLAIN authentication of fred@proton with the password
+// stored in the test sasldb. extendedSASL.startTest skips it when extended
+// SASL is unavailable or its setup failed.
+func TestAuthPlain(t *testing.T) {
- if !SASLExtended() {
- t.Skip()
- }
- fatalIf(t, configureSASL())
++ extendedSASL.startTest(t)
+ got, err := testAuthClientServer(t,
+ []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN"), User("fred@proton"), Password([]byte("xxx"))},
+ []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN")})
+ fatalIf(t, err)
+ errorIf(t, checkEqual(connectionSettings{user: "fred@proton"}, got))
+}
+
+// TestAuthBadPass expects PLAIN authentication to fail for a known user
+// (fred@proton) presenting the wrong password.
+func TestAuthBadPass(t *testing.T) {
- if !SASLExtended() {
- t.Skip()
- }
- fatalIf(t, configureSASL())
++ extendedSASL.startTest(t)
+ _, err := testAuthClientServer(t,
+ []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN"), User("fred@proton"), Password([]byte("yyy"))},
+ []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN")})
+ if err == nil {
+ t.Error("Expected auth failure for bad pass")
+ }
+}
+
+// TestAuthBadUser expects PLAIN authentication to fail for a user (foo@bar)
+// that is not in the test sasldb.
+func TestAuthBadUser(t *testing.T) {
- if !SASLExtended() {
- t.Skip()
- }
- fatalIf(t, configureSASL())
++ extendedSASL.startTest(t)
+ _, err := testAuthClientServer(t,
+ []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN"), User("foo@bar"), Password([]byte("yyy"))},
+ []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN")})
+ if err == nil {
+ t.Error("Expected auth failure for bad user")
+ }
+}
+
- var confDir string
- var confErr error
++// extendedSASLState holds the result of the one-time extended SASL test setup:
++// the temporary configuration directory and any error hit while creating it.
++type extendedSASLState struct {
++ err error // setup error; startTest skips tests needing extended SASL if set
++ dir string // temporary SASL config directory, removed by teardown
++}
+
- func configureSASL() error {
- if confDir != "" || confErr != nil {
- return confErr
- }
- confDir, confErr = ioutil.TempDir("", "")
- if confErr != nil {
- return confErr
++// setup creates a temporary Cyrus SASL configuration and installs it as the
++// global electron SASL configuration. The configuration consists of a sasldb
++// holding user fred in realm proton with password "xxx", and a test.conf
++// whose mech_list includes PLAIN. It does nothing if extended SASL is not
++// supported.
++//
++// Failures are recorded in s.err, including the output of the saslpasswd tool.
++func (s *extendedSASLState) setup() {
++ if SASLExtended() {
++ if s.dir, s.err = ioutil.TempDir("", ""); s.err == nil {
++ GlobalSASLConfigDir(s.dir)
++ GlobalSASLConfigName("test")
++ conf := filepath.Join(s.dir, "test.conf")
++ db := filepath.Join(s.dir, "proton.sasldb")
++ saslpasswd := os.Getenv("SASLPASSWD")
++ if saslpasswd == "" {
++ saslpasswd = "saslpasswd2"
++ }
++ cmd := exec.Command(saslpasswd, "-c", "-p", "-f", db, "-u", "proton", "fred")
++ cmd.Stdin = strings.NewReader("xxx") // Password
++ if out, err := cmd.CombinedOutput(); err != nil {
++ // Keep the tool's output: it usually explains why it failed.
++ s.err = fmt.Errorf("%s failed: %v\n%s", saslpasswd, err, out)
++ } else {
++ confStr := fmt.Sprintf(`
++sasldb_path: %s
++mech_list: EXTERNAL DIGEST-MD5 SCRAM-SHA-1 CRAM-MD5 PLAIN ANONYMOUS
++`, db)
++ if err := ioutil.WriteFile(conf, []byte(confStr), os.ModePerm); err != nil {
++ s.err = fmt.Errorf("write conf file %s failed: %v", conf, err)
++ }
++ }
++ }
+ }
++ // Note we don't do anything with s.err now: tests that need the extended
++ // SASL config call startTest, which skips them if s.err != nil. If no such
++ // tests are run then it is not an error that we couldn't set it up.
++}
+
- GlobalSASLConfigDir(confDir)
- GlobalSASLConfigName("test")
- conf := filepath.Join(confDir, "test.conf")
-
- db := filepath.Join(confDir, "proton.sasldb")
- saslpasswd := os.Getenv("SASLPASSWD");
- if saslpasswd == "" {
- saslpasswd = "saslpasswd2"
- }
- cmd := exec.Command(saslpasswd, "-c", "-p", "-f", db, "-u", "proton", "fred")
- cmd.Stdin = strings.NewReader("xxx") // Password
- if out, err := cmd.CombinedOutput(); err != nil {
- confErr = fmt.Errorf("saslpasswd2 failed: %s\n%s", err, out)
- return confErr
++// teardown removes the temporary SASL configuration directory, if setup
++// created one. Removal errors are deliberately ignored: this is best-effort
++// cleanup after the tests have run. It uses a pointer receiver like setup.
++func (s *extendedSASLState) teardown() {
++ if s.dir != "" {
++ _ = os.RemoveAll(s.dir)
+ }
- confStr := "sasldb_path: " + db + "\nmech_list: EXTERNAL DIGEST-MD5 SCRAM-SHA-1 CRAM-MD5 PLAIN ANONYMOUS\n"
- if err := ioutil.WriteFile(conf, []byte(confStr), os.ModePerm); err != nil {
- confErr = fmt.Errorf("write conf file %s failed: %s", conf, err)
++}
++
++// startTest is called first by every test that needs the extended SASL
++// configuration. It skips the test if extended SASL is not supported or if
++// setup recorded an error in s.err.
++func (s *extendedSASLState) startTest(t *testing.T) {
++ if !SASLExtended() {
++ t.Skipf("Extended SASL not enabled")
++ } else if s.err != nil {
++ t.Skipf("Extended SASL setup error: %v", s.err)
+ }
- return confErr
+}
+
++// extendedSASL is the process-wide extended SASL test state. TestMain sets it
++// up before the tests run and tears it down afterwards.
++var extendedSASL extendedSASLState
++
+// TestMain wraps the test run with the global extended SASL setup and
+// teardown, then exits with the status returned by m.Run.
+func TestMain(m *testing.M) {
++ // Do global SASL setup/teardown in main.
++ // Doing it on-demand makes the tests fragile to parallel test runs and
++ // changes in test ordering.
++ extendedSASL.setup()
+ status := m.Run()
- if confDir != "" {
- _ = os.RemoveAll(confDir)
- }
++ extendedSASL.teardown()
+ os.Exit(status)
+}
http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/6f799990/electron/connection.go
----------------------------------------------------------------------
diff --cc electron/connection.go
index 731e64d,0000000..9c0ef31
mode 100644,000000..100644
--- a/electron/connection.go
+++ b/electron/connection.go
@@@ -1,421 -1,0 +1,442 @@@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+*/
+
+package electron
+
+// #include <proton/disposition.h>
+import "C"
+
+import (
+ "errors"
+ "net"
+ "sync"
+ "time"
+
+ "qpid.apache.org/proton"
+)
+
+// ConnectionSettings holds the settings associated with a Connection.
+type ConnectionSettings interface {
+ // User is the authenticated user name associated with the connection.
+ User() string
+
+ // VirtualHost is the AMQP virtual host name for the connection.
+ //
+ // Optional, useful when the server has multiple names and provides different
+ // service based on the name the client uses to connect.
+ //
+ // By default it is set to the DNS host name that the client uses to connect,
+ // but it can be set to something different at the client side with the
+ // VirtualHost() option. On an incoming connection it is the host name sent
+ // by the remote peer.
+ VirtualHost() string
+
+ // Heartbeat is the maximum delay between sending frames that the remote peer
+ // has requested of us. If the interval expires an empty "heartbeat" frame
+ // will be sent automatically to keep the connection open.
+ Heartbeat() time.Duration
+}
+
+// Connection is an AMQP connection, created by a Container.
+type Connection interface {
+ Endpoint
+ ConnectionSettings
+
+ // Sender opens a new sender on the DefaultSession.
+ Sender(...LinkOption) (Sender, error)
+
+ // Receiver opens a new Receiver on the DefaultSession().
+ Receiver(...LinkOption) (Receiver, error)
+
+ // DefaultSession() returns a default session for the connection. It is opened
+ // on the first call to DefaultSession and returned on subsequent calls.
+ DefaultSession() (Session, error)
+
+ // Session opens a new session.
+ Session(...SessionOption) (Session, error)
+
+ // Container for the connection.
+ Container() Container
+
+ // Disconnect the connection abruptly with an error.
+ Disconnect(error)
+
+ // Wait waits for the connection to be disconnected.
+ Wait() error
+
+ // WaitTimeout is like Wait but returns Timeout if the timeout expires.
+ WaitTimeout(time.Duration) error
+
+ // Incoming returns a channel for incoming endpoints opened by the remote peer.
+ // See the Incoming interface for more detail. It is only available on
+ // connections created with the Server() or AllowIncoming() option.
+ //
+ // Note: this channel will first return an *IncomingConnection for the
+ // connection itself which allows you to look at security information and
+ // decide whether to Accept() or Reject() the connection. Then it will return
+ // *IncomingSession, *IncomingSender and *IncomingReceiver as they are opened
+ // by the remote end.
+ //
+ // Note 2: you must receive from Incoming() and call Accept/Reject to avoid
+ // blocking the electron event loop. Normally you would run a loop in a
+ // goroutine to handle the incoming types that interest you and Accept() those
+ // that don't.
+ Incoming() <-chan Incoming
+}
+
+// connectionSettings is the plain-value implementation of ConnectionSettings.
+type connectionSettings struct {
+ user string
+ virtualHost string
+ heartbeat time.Duration
+}
+
+// User implements ConnectionSettings.
+func (s connectionSettings) User() string {
+ return s.user
+}
+
+// VirtualHost implements ConnectionSettings.
+func (s connectionSettings) VirtualHost() string {
+ return s.virtualHost
+}
+
+// Heartbeat implements ConnectionSettings.
+func (s connectionSettings) Heartbeat() time.Duration {
+ return s.heartbeat
+}
+
+// ConnectionOption can be passed when creating a connection to configure
+// various options. NewConnection applies the options in the order given.
+type ConnectionOption func(*connection)
+
+// User returns a ConnectionOption that sets the user name for a connection.
+// The name is stored in the connection's settings and passed to proton.
+func User(user string) ConnectionOption {
+ setUser := func(conn *connection) {
+ conn.user = user
+ conn.pConnection.SetUser(user)
+ }
+ return setUser
+}
+
+// VirtualHost returns a ConnectionOption to set the AMQP virtual host for the connection.
+// Only applies to outbound client connection.
+func VirtualHost(virtualHost string) ConnectionOption {
+ setHost := func(conn *connection) {
+ conn.virtualHost = virtualHost
+ conn.pConnection.SetHostname(virtualHost)
+ }
+ return setHost
+}
+
+// Password returns a ConnectionOption to set the password used to establish a
+// connection. Only applies to outbound client connection.
+//
+// The connection will erase its copy of the password from memory as soon as it
+// has been used to authenticate. If you are concerned about passwords staying in
+// memory you should never store them as strings, and should overwrite your
+// copy as soon as you are done with it.
+//
+func Password(password []byte) ConnectionOption {
+ setPassword := func(conn *connection) {
+ conn.pConnection.SetPassword(password)
+ }
+ return setPassword
+}
+
+// Server returns a ConnectionOption to put the connection in server mode for
+// incoming connections.
+//
+// A server connection will do protocol negotiation to accept an incoming AMQP
+// connection. Normally you would call this for a connection created by
+// net.Listener.Accept(). Server mode also enables Incoming(), exactly as
+// AllowIncoming() does.
+//
+func Server() ConnectionOption {
+ return func(conn *connection) {
+ conn.engine.Server()
+ conn.server = true
+ AllowIncoming()(conn)
+ }
+}
+
+// AllowIncoming returns a ConnectionOption to enable incoming endpoints, see
+// Connection.Incoming(). It gives the connection an unbuffered Incoming
+// channel. This is automatically set for Server() connections.
+func AllowIncoming() ConnectionOption {
+ return func(conn *connection) {
+ conn.incoming = make(chan Incoming)
+ }
+}
+
+// Parent returns a ConnectionOption that associates the Connection with its
+// Container. If not set a connection will create its own default container.
+// cont must be backed by this package's *container type; the type assertion
+// panics otherwise.
+func Parent(cont Container) ConnectionOption {
+ return func(conn *connection) {
+ conn.container = cont.(*container)
+ }
+}
+
+// connection implements Connection on top of a proton.Engine that drives a
+// net.Conn.
+type connection struct {
+ endpoint // NOTE(review): presumably supplies err, done and Error() used by methods below — confirm
+ connectionSettings // user/virtual host/heartbeat exposed through ConnectionSettings
+
+ // defaultSessionOnce guards the lazy creation of defaultSession.
+ defaultSessionOnce, closeOnce sync.Once
+
+ container *container // parent container; NewConnection creates one if Parent() was not given
+ conn net.Conn // network connection passed to NewConnection
+ server bool // true once the Server() option has been applied
+ incoming chan Incoming // nil unless Server()/AllowIncoming() was applied; closed by run()
+ handler *handler
+ engine *proton.Engine
+ pConnection proton.Connection // the engine's proton connection
+
+ defaultSession Session // opened by the first DefaultSession() call
+}
+
+// NewConnection creates a connection with the given options.
+//
+// It wraps conn in a proton engine, applies opts in order, creates a default
+// container if Parent() was not given, applies the global SASL configuration
+// and starts the connection's run loop in a new goroutine. It only returns an
+// error if the proton engine cannot be created.
+func NewConnection(conn net.Conn, opts ...ConnectionOption) (*connection, error) {
+ c := &connection{
+ conn: conn,
+ }
+ c.handler = newHandler(c)
+ var err error
+ c.engine, err = proton.NewEngine(c.conn, c.handler.delegator)
+ if err != nil {
+ return nil, err
+ }
+ c.pConnection = c.engine.Connection()
+ // Options run after the engine exists because they configure it.
+ for _, set := range opts {
+ set(c)
+ }
+ if c.container == nil {
+ c.container = NewContainer("").(*container)
+ }
+ c.pConnection.SetContainer(c.container.Id())
- globalSASLInit(c.engine)
-
++ saslConfig.setup(c.engine)
+ c.endpoint.init(c.engine.String())
+ go c.run()
+ return c, nil
+}
+
+// run is the connection's goroutine. It sends the AMQP open for client
+// connections and runs the engine until it stops. It then closes the Incoming
+// channel, if there is one, and marks the endpoint closed.
+func (c *connection) run() {
+ if !c.server {
+ // Clients open immediately; a server connection is opened when the
+ // IncomingConnection is accepted.
+ c.pConnection.Open()
+ }
+ _ = c.engine.Run()
+ if in := c.incoming; in != nil {
+ close(in)
+ }
+ _ = c.closed(Closed)
+}
+
+// Close passes err to the endpoint's error (c.err.Set) and then asks the
+// engine to close the connection with it.
+func (c *connection) Close(err error) {
+ c.err.Set(err)
+ c.engine.Close(err)
+}
+
+// Disconnect passes err to the endpoint's error (c.err.Set) and then asks the
+// engine to disconnect abruptly with it.
+func (c *connection) Disconnect(err error) {
+ c.err.Set(err)
+ c.engine.Disconnect(err)
+}
+
+// Session opens a new session on the connection, configured with opts.
+//
+// The work runs through c.engine.InjectWait (presumably on the engine's
+// goroutine). If the connection already has an error, that error is returned
+// and no session is created.
+func (c *connection) Session(opts ...SessionOption) (Session, error) {
+ var s Session
+ err := c.engine.InjectWait(func() error {
+ if err := c.Error(); err != nil {
+ return err
+ }
+ pSession, err := c.engine.Connection().Session()
+ if err != nil {
+ return err
+ }
+ pSession.Open()
+ s = newSession(c, pSession, opts...)
+ return nil
+ })
+ return s, err
+}
+
+// Container returns the container this connection belongs to.
+func (c *connection) Container() Container { return c.container }
+
+// DefaultSession returns the connection's default session, opening it on the
+// first call and returning the same session afterwards.
+//
+// Only the first call attempts to open the session. If that attempt failed,
+// later calls return an error rather than a nil Session with a nil error.
+func (c *connection) DefaultSession() (s Session, err error) {
+ c.defaultSessionOnce.Do(func() {
+ c.defaultSession, err = c.Session()
+ })
+ if err == nil {
+ err = c.Error()
+ }
+ if err == nil && c.defaultSession == nil {
+ // The single open attempt failed on an earlier call, and only that
+ // caller saw its error. Returning (nil, nil) here would make Sender and
+ // Receiver call a method on a nil Session.
+ err = errors.New("default session could not be opened")
+ }
+ return c.defaultSession, err
+}
+
+// Sender opens a new sender with opts on the connection's DefaultSession.
+func (c *connection) Sender(opts ...LinkOption) (Sender, error) {
+ s, err := c.DefaultSession()
+ if err != nil {
+ return nil, err
+ }
+ return s.Sender(opts...)
+}
+
+// Receiver opens a new receiver with opts on the connection's DefaultSession.
+func (c *connection) Receiver(opts ...LinkOption) (Receiver, error) {
+ s, err := c.DefaultSession()
+ if err != nil {
+ return nil, err
+ }
+ return s.Receiver(opts...)
+}
+
+// Connection returns c itself.
+func (c *connection) Connection() Connection { return c }
+
+// Wait waits with no time limit (WaitTimeout(Forever)).
+func (c *connection) Wait() error { return c.WaitTimeout(Forever) }
+
+// WaitTimeout waits on the connection's done channel. It returns Timeout if
+// timeout expires first, and the connection's error otherwise.
+func (c *connection) WaitTimeout(timeout time.Duration) error {
+ if _, err := timedReceive(c.done, timeout); err == Timeout {
+ return Timeout
+ }
+ return c.Error()
+}
+
+// Incoming returns the channel of endpoints opened by the remote peer.
+//
+// The channel only exists on connections created with the Server() or
+// AllowIncoming() option (Server() implies AllowIncoming()). On any other
+// connection the assert fails.
+func (c *connection) Incoming() <-chan Incoming {
+ assert(c.incoming != nil, "Incoming() is only allowed for a Connection created with the Server() or AllowIncoming() option: %s", c)
+ return c.incoming
+}
+
+// IncomingConnection is delivered first on Connection.Incoming() for a
+// connection opened by the remote peer. Its ConnectionSettings report the
+// authenticated user and the virtual host sent by the peer. Accept it (or use
+// AcceptConnection to apply extra options), or reject it.
+type IncomingConnection struct {
+ incoming
+ connectionSettings
+ c *connection // the underlying connection, opened on accept
+}
+
+// newIncomingConnection records the transport's authenticated user and the
+// remote host name on c. It then wraps c in an IncomingConnection carrying a
+// snapshot of c's settings.
+func newIncomingConnection(c *connection) *IncomingConnection {
+ c.user = c.pConnection.Transport().User()
+ c.virtualHost = c.pConnection.RemoteHostname()
+ ic := new(IncomingConnection)
+ ic.incoming = makeIncoming(c.pConnection)
+ ic.connectionSettings = c.connectionSettings
+ ic.c = c
+ return ic
+}
+
+// AcceptConnection is like Accept() but takes ConnectionOptions, which are
+// applied to the connection before it is opened. For example you can set the
+// Heartbeat() for the accepted connection.
+func (in *IncomingConnection) AcceptConnection(opts ...ConnectionOption) Connection {
+ open := func() Endpoint {
+ for _, opt := range opts {
+ opt(in.c)
+ }
+ in.c.pConnection.Open()
+ return in.c
+ }
+ return in.accept(open).(Connection)
+}
+
+// Accept accepts the incoming connection without extra options.
+func (in *IncomingConnection) Accept() Endpoint {
+ return in.AcceptConnection()
+}
+
+// sasl returns the SASL object of c's transport. SASLEnable relies on this
+// call alone to enable SASL for the connection.
+func sasl(c *connection) proton.SASL {
+ t := c.engine.Transport()
+ return t.SASL()
+}
+
+// SASLEnable returns a ConnectionOption that enables SASL authentication.
+// Only required if you don't set any other SASL options.
+func SASLEnable() ConnectionOption {
+ return func(conn *connection) { _ = sasl(conn) }
+}
+
+// SASLAllowedMechs returns a ConnectionOption to set the list of allowed SASL
+// mechanisms.
+//
+// Can be used on the client or the server to restrict the SASL for a connection.
+// mechs is a space-separated list of mechanism names.
+//
+func SASLAllowedMechs(mechs string) ConnectionOption {
+ return func(conn *connection) {
+ s := sasl(conn)
+ s.AllowedMechs(mechs)
+ }
+}
+
+// SASLAllowInsecure returns a ConnectionOption that allows or disallows clear
+// text SASL authentication mechanisms.
+//
+// By default the SASL layer is configured not to allow mechanisms that disclose
+// the clear text of the password over an unencrypted AMQP connection. This specifically
+// will disallow the use of the PLAIN mechanism without using SSL encryption.
+//
+// This default is to avoid disclosing password information accidentally over an
+// insecure network.
+//
+func SASLAllowInsecure(b bool) ConnectionOption {
+ return func(conn *connection) {
+ s := sasl(conn)
+ s.SetAllowInsecureMechs(b)
+ }
+}
+
+// Heartbeat returns a ConnectionOption that requests the maximum delay
+// between sending frames for the remote peer. If we don't receive any frames
+// within 2*delay we will close the connection.
+//
+func Heartbeat(delay time.Duration) ConnectionOption {
+ // Proton-C divides the idle-timeout by 2 before sending, so compensate.
+ idleTimeout := 2 * delay
+ return func(conn *connection) {
+ conn.engine.Transport().SetIdleTimeout(idleTimeout)
+ }
+}
+
++// saslConfigState holds the process-wide SASL configuration set by
++// GlobalSASLConfigDir and GlobalSASLConfigName.
++type saslConfigState struct {
++ lock sync.Mutex // guards the fields below
++ name string // SASL config name; applied only if non-empty
++ dir string // SASL config directory; applied only if non-empty
++ initialized bool // set once the config has been applied; later set() calls panic
++}
++
++// set stores value into target, a field of s, under s.lock. It panics once
++// the configuration has been applied to a connection.
++func (s *saslConfigState) set(target *string, value string) {
++ s.lock.Lock()
++ defer s.lock.Unlock()
++ if !s.initialized {
++ *target = value
++ return
++ }
++ panic("SASL configuration cannot be changed after a Connection has been created")
++}
++
++// setup applies the global SASL configuration the first time a proton.Engine
++// needs it. It also marks the configuration initialized, so later set() calls
++// panic.
++//
++// TODO aconway 2016-09-15: Current pn_sasl C impl config is broken, so all we
++// can realistically offer is global configuration. Later if/when the pn_sasl C
++// impl is fixed we can offer per connection over-rides.
++//
++// NOTE(review): only the first engine passed here has its transport SASL
++// object requested and configured; later engines are left untouched. This
++// relies on the C-level SASL config being process-global — confirm.
++func (s *saslConfigState) setup(eng *proton.Engine) {
++ s.lock.Lock()
++ defer s.lock.Unlock()
++ if s.initialized {
++ return
++ }
++ s.initialized = true
++ // Named sc, not sasl, to avoid shadowing the package-level sasl function.
++ sc := eng.Transport().SASL()
++ if s.name != "" {
++ sc.ConfigName(s.name)
++ }
++ if s.dir != "" {
++ sc.ConfigPath(s.dir)
++ }
++}
++
++// saslConfig is the single process-wide SASL configuration. NewConnection
++// applies it via saslConfig.setup.
++var saslConfig saslConfigState
++
+// GlobalSASLConfigDir sets the SASL configuration directory for every
+// Connection created in this process. If not called, the default is determined
+// by your SASL installation.
+//
+// You can set SASLAllowInsecure and SASLAllowedMechs on individual connections.
+//
- func GlobalSASLConfigDir(dir string) { globalSASLConfigDir = dir }
++// Must be called at most once, before any connections are created.
++// Calling it after a Connection has been created panics.
++func GlobalSASLConfigDir(dir string) { saslConfig.set(&saslConfig.dir, dir) }
+
+// GlobalSASLConfigName sets the SASL configuration name for every Connection
+// created in this process. If not called the default is "proton-server".
+//
+// The complete configuration file name is
+// <sasl-config-dir>/<sasl-config-name>.conf
+//
+// You can set SASLAllowInsecure and SASLAllowedMechs on individual connections.
+//
- func GlobalSASLConfigName(dir string) { globalSASLConfigName = dir }
++// Must be called at most once, before any connections are created.
++// Calling it after a Connection has been created panics.
++func GlobalSASLConfigName(name string) { saslConfig.set(&saslConfig.name, name) }
+
+// SASLExtended reports whether extended SASL negotiation is supported.
+// All implementations of Proton support ANONYMOUS and EXTERNAL on both
+// client and server sides and PLAIN on the client side.
+//
+// Extended SASL implementations use an external library (Cyrus SASL)
+// to support other mechanisms beyond these basic ones.
+func SASLExtended() bool { return proton.SASLExtended() }
+
- var (
- globalSASLConfigName string
- globalSASLConfigDir string
- )
-
- // TODO aconway 2016-09-15: Current pn_sasl C impl config is broken, so all we
- // can realistically offer is global configuration. Later if/when the pn_sasl C
- // impl is fixed we can offer per connection over-rides.
- func globalSASLInit(eng *proton.Engine) {
- sasl := eng.Transport().SASL()
- if globalSASLConfigName != "" {
- sasl.ConfigName(globalSASLConfigName)
- }
- if globalSASLConfigDir != "" {
- sasl.ConfigPath(globalSASLConfigDir)
- }
- }
-
+// Dial is shorthand for using net.Dial() then NewConnection().
+// See net.Dial() for the meaning of the network, address arguments.
+//
+// On failure it returns a nil Connection and the error. If NewConnection fails,
+// the dialed net.Conn is closed so the socket is not leaked.
+func Dial(network, address string, opts ...ConnectionOption) (c Connection, err error) {
+ conn, err := net.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+ ec, err := NewConnection(conn, opts...)
+ if err != nil {
+ // Close the socket we opened, and return a true nil rather than a
+ // Connection interface wrapping a nil *connection.
+ _ = conn.Close()
+ return nil, err
+ }
+ return ec, nil
+}
+
+// DialWithDialer is shorthand for using dialer.Dial() then NewConnection().
+// See net.Dial() for the meaning of the network, address arguments.
+//
+// On failure it returns a nil Connection and the error. If NewConnection fails,
+// the dialed net.Conn is closed so the socket is not leaked.
+func DialWithDialer(dialer *net.Dialer, network, address string, opts ...ConnectionOption) (c Connection, err error) {
+ conn, err := dialer.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+ ec, err := NewConnection(conn, opts...)
+ if err != nil {
+ // Close the socket we opened, and return a true nil rather than a
+ // Connection interface wrapping a nil *connection.
+ _ = conn.Close()
+ return nil, err
+ }
+ return ec, nil
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@qpid.apache.org
For additional commands, e-mail: commits-help@qpid.apache.org