http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/b25d21e6/amqp/unmarshal.go ---------------------------------------------------------------------- diff --cc amqp/unmarshal.go index 253d66d,0000000..97e8437 mode 100644,000000..100644 --- a/amqp/unmarshal.go +++ b/amqp/unmarshal.go @@@ -1,638 -1,0 +1,733 @@@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package amqp + +// #include <proton/codec.h> +import "C" + +import ( + "bytes" + "fmt" + "io" + "reflect" - "strings" ++ "time" + "unsafe" +) + +const minDecode = 1024 + +// Error returned if AMQP data cannot be unmarshaled as the desired Go type. +type UnmarshalError struct { + // The name of the AMQP type. + AMQPType string + // The Go type. + GoType reflect.Type + + s string +} + +func (e UnmarshalError) Error() string { return e.s } + - func newUnmarshalErrorMsg(pnType C.pn_type_t, v interface{}, msg string) *UnmarshalError { - if len(msg) > 0 && !strings.HasPrefix(msg, ":") { - msg = ": " + msg ++// Error returned if there are not enough bytes to decode a complete AMQP value.
++var EndOfData = &UnmarshalError{s: "Not enough data for AMQP value"} ++ ++var badData = &UnmarshalError{s: "Unexpected error in data"} ++ ++func newUnmarshalError(pnType C.pn_type_t, v interface{}) *UnmarshalError { ++ e := &UnmarshalError{ ++ AMQPType: C.pn_type_t(pnType).String(), ++ GoType: reflect.TypeOf(v), + } - e := &UnmarshalError{AMQPType: C.pn_type_t(pnType).String(), GoType: reflect.TypeOf(v)} - if e.GoType.Kind() != reflect.Ptr { - e.s = fmt.Sprintf("cannot unmarshal to type %s, not a pointer%s", e.GoType, msg) ++ if e.GoType == nil || e.GoType.Kind() != reflect.Ptr { ++ e.s = fmt.Sprintf("cannot unmarshal to Go type %v, not a pointer", e.GoType) + } else { - e.s = fmt.Sprintf("cannot unmarshal AMQP %s to %s%s", e.AMQPType, e.GoType, msg) ++ e.s = fmt.Sprintf("cannot unmarshal AMQP %v to Go %v", e.AMQPType, e.GoType.Elem()) + } + return e +} + - func newUnmarshalError(pnType C.pn_type_t, v interface{}) *UnmarshalError { - return newUnmarshalErrorMsg(pnType, v, "") ++func doPanic(data *C.pn_data_t, v interface{}) { ++ e := newUnmarshalError(C.pn_data_type(data), v) ++ panic(e) +} + - func newUnmarshalErrorData(data *C.pn_data_t, v interface{}) *UnmarshalError { - err := PnError(C.pn_data_error(data)) - if err == nil { - return nil - } ++func doPanicMsg(data *C.pn_data_t, v interface{}, msg string) { + e := newUnmarshalError(C.pn_data_type(data), v) - e.s = e.s + ": " + err.Error() - return e ++ e.s = e.s + ": " + msg ++ panic(e) +} + - func recoverUnmarshal(err *error) { - if r := recover(); r != nil { - if uerr, ok := r.(*UnmarshalError); ok { - *err = uerr - } else { - panic(r) - } ++func panicIfBadData(data *C.pn_data_t, v interface{}) { ++ if C.pn_data_errno(data) != 0 { ++ doPanicMsg(data, v, PnError(C.pn_data_error(data)).Error()) ++ } ++} ++ ++func panicUnless(ok bool, data *C.pn_data_t, v interface{}) { ++ if !ok { ++ doPanic(data, v) ++ } ++} ++ ++func checkOp(ok bool, v interface{}) { ++ if !ok { ++ // badData is already a *UnmarshalError: panic with it directly so that ++ // recoverUnmarshal converts it to an error (&badData would not match). ++ panic(badData) + } +} + +// +// Decoding
 from a pn_data_t +// +// NOTE: we use panic() to signal a decoding error, which simplifies the decoding logic. +// We recover() at the highest possible level - i.e. in the exported Unmarshal or Decode. +// + +// Decoder decodes AMQP values from an io.Reader. +// +type Decoder struct { + reader io.Reader + buffer bytes.Buffer +} + +// NewDecoder returns a new decoder that reads from r. +// +// The decoder has its own buffer and may read more data than required for the +// AMQP values requested. Use Buffered to see if there is data left in the +// buffer. +// +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{r, bytes.Buffer{}} +} + +// Buffered returns a reader of the data remaining in the Decoder's buffer. The +// reader is valid until the next call to Decode. +// +func (d *Decoder) Buffered() io.Reader { + return bytes.NewReader(d.buffer.Bytes()) +} + +// Decode reads the next AMQP value from the Reader and stores it in the value pointed to by v. +// +// See the documentation for Unmarshal for details about the conversion of AMQP into a Go value. +// +func (d *Decoder) Decode(v interface{}) (err error) { - defer recoverUnmarshal(&err) + data := C.pn_data(0) + defer C.pn_data_free(data) + var n int - for n == 0 { - n, err = decode(data, d.buffer.Bytes()) - if err != nil { - return err ++ for n, err = decode(data, d.buffer.Bytes()); err == EndOfData; { ++ err = d.more() ++ if err == nil { ++ n, err = decode(data, d.buffer.Bytes()) + } - if n == 0 { // n == 0 means not enough data, read more - err = d.more() - } else { - unmarshal(v, data) ++ } ++ if err == nil { ++ if err = recoverUnmarshal(v, data); err == nil { ++ d.buffer.Next(n) + } + } - d.buffer.Next(n) + return +} + +/* - Unmarshal decodes AMQP-encoded bytes and stores the result in the Go value pointed to by v.
- Types are converted as follows: - - +------------------------+-------------------------------------------------+ - |To Go types |From AMQP types | - +========================+=================================================+ - |bool |bool | - +------------------------+-------------------------------------------------+ - |int, int8, int16, int32,|Equivalent or smaller signed integer type: byte, | - |int64 |short, int, long. | - +------------------------+-------------------------------------------------+ - |uint, uint8, uint16, |Equivalent or smaller unsigned integer type: | - |uint32, uint64 |ubyte, ushort, uint, ulong | - +------------------------+-------------------------------------------------+ - |float32, float64 |Equivalent or smaller float or double. | - +------------------------+-------------------------------------------------+ - |string, []byte |string, symbol or binary. | - +------------------------+-------------------------------------------------+ - |Symbol |symbol | - +------------------------+-------------------------------------------------+ - |map[K]T |map, provided all keys and values can unmarshal | - | |to types K,T | - +------------------------+-------------------------------------------------+ - |Map |map, any AMQP map | - +------------------------+-------------------------------------------------+ - |Described |described type | - +------------------------+-------------------------------------------------+ - - An AMQP described type can unmarshal into the corresponding plain type, discarding the descriptor. - For example an AMQP described string can unmarshal into a plain go string. - Unmarshal into the Described type preserves the descriptor. 
- - Any AMQP type can unmarshal to an interface{}, the Go type used to unmarshal is chosen from the AMQP type as follows - - +------------------------+-------------------------------------------------+ - |AMQP Type |Go Type in interface{} | - +========================+=================================================+ - |bool |bool | - +------------------------+-------------------------------------------------+ - |byte,short,int,long |int8,int16,int32,int64 | - +------------------------+-------------------------------------------------+ - |ubyte,ushort,uint,ulong |uint8,uint16,uint32,uint64 | - +------------------------+-------------------------------------------------+ - |float, double |float32, float64 | - +------------------------+-------------------------------------------------+ - |string |string | - +------------------------+-------------------------------------------------+ - |symbol |Symbol | - +------------------------+-------------------------------------------------+ - |binary |Binary | - +------------------------+-------------------------------------------------+ - |null |nil | - +------------------------+-------------------------------------------------+ - |map |Map | - +------------------------+-------------------------------------------------+ - |list |List | - +------------------------+-------------------------------------------------+ - |described type |Described | - +--------------------------------------------------------------------------+ - - The following Go types cannot be unmarshaled: uintptr, function, interface, channel, array (use slice), struct - - TODO: Not yet implemented: - - AMQP types: decimal32/64/128, char (round trip), timestamp, uuid. - - AMQP maps with mixed key types, or key types that are not legal Go map keys. ++ ++Unmarshal decodes AMQP-encoded bytes and stores the result in the Go value ++pointed to by v. 
Legal conversions from the source AMQP type to the target Go ++type are as follows: ++ ++ +----------------------------+--------------------------------------------------+ ++ |Target Go type |Allowed AMQP types | ++ +============================+==================================================+ ++ |bool |bool | ++ +----------------------------+--------------------------------------------------+ ++ |int, int8, int16, int32, |Equivalent or smaller signed integer type: | ++ |int64 |byte, short, int, long; char for int/int32/int64 | ++ +----------------------------+--------------------------------------------------+ ++ |uint, uint8, uint16, uint32,|Equivalent or smaller unsigned integer type: | ++ |uint64 |ubyte, ushort, uint, ulong; char for uint/uint32/uint64 | ++ +----------------------------+--------------------------------------------------+ ++ |float32, float64 |Equivalent or smaller float or double | ++ +----------------------------+--------------------------------------------------+ ++ |string, []byte |string, symbol or binary | ++ +----------------------------+--------------------------------------------------+ ++ |Symbol |symbol | ++ +----------------------------+--------------------------------------------------+ ++ |Char |char | ++ +----------------------------+--------------------------------------------------+ ++ |Described |AMQP described type [1] | ++ +----------------------------+--------------------------------------------------+ ++ |time.Time |timestamp | ++ +----------------------------+--------------------------------------------------+ ++ |UUID |uuid | ++ +----------------------------+--------------------------------------------------+ ++ |map[interface{}]interface{} |Any AMQP map whose keys are legal Go map keys | ++ +----------------------------+--------------------------------------------------+ ++ |map[K]T |map, provided all keys and values can unmarshal | ++ | |to types K,T | ++ +----------------------------+--------------------------------------------------+ ++ |[]interface{} |AMQP list or array | ++
 +----------------------------+--------------------------------------------------+ ++ |[]T |list or array if elements can unmarshal as T | ++ +----------------------------+--------------------------------------------------+ ++ |interface{} |any AMQP type [2] | ++ +----------------------------+--------------------------------------------------+ ++ ++[1] An AMQP described value can also unmarshal to a plain value, discarding the ++descriptor. Unmarshalling into the special amqp.Described type preserves the ++descriptor. ++ ++[2] Any AMQP value can be unmarshalled to an interface{}. The Go type is ++determined by the AMQP type as follows: ++ ++ +----------------------------+--------------------------------------------------+ ++ |Source AMQP Type |Go Type in target interface{} | ++ +============================+==================================================+ ++ |bool |bool | ++ +----------------------------+--------------------------------------------------+ ++ |byte,short,int,long |int8,int16,int32,int64 | ++ +----------------------------+--------------------------------------------------+ ++ |ubyte,ushort,uint,ulong |uint8,uint16,uint32,uint64 | ++ +----------------------------+--------------------------------------------------+ ++ |float, double |float32, float64 | ++ +----------------------------+--------------------------------------------------+ ++ |string |string | ++ +----------------------------+--------------------------------------------------+ ++ |symbol |Symbol | ++ +----------------------------+--------------------------------------------------+ ++ |char |Char | ++ +----------------------------+--------------------------------------------------+ ++ |binary |Binary | ++ +----------------------------+--------------------------------------------------+ ++ |null |nil | ++ +----------------------------+--------------------------------------------------+ ++ |described type |Described | ++ 
 +----------------------------+--------------------------------------------------+ ++ |timestamp |time.Time | ++ +----------------------------+--------------------------------------------------+ ++ |uuid |UUID | ++ +----------------------------+--------------------------------------------------+ ++ |map |Map or AnyMap [4] | ++ +----------------------------+--------------------------------------------------+ ++ |list |List | ++ +----------------------------+--------------------------------------------------+ ++ |array |[]T for simple types, T is chosen as above [3] | ++ +----------------------------+--------------------------------------------------+ ++ ++[3] An AMQP array of simple types unmarshals as a slice of the corresponding Go type. ++An AMQP array containing complex types (lists, maps or nested arrays) unmarshals ++to the generic array type amqp.Array. ++ ++[4] An AMQP map unmarshals as the generic `type Map map[interface{}]interface{}` ++unless it contains key values that are illegal as Go map keys, in which case ++it unmarshals as type AnyMap. ++ ++The following Go types cannot be unmarshaled (except the specific types listed above): uintptr, ++function, interface (other than interface{}), channel, array (use slice), struct. ++ ++AMQP types not yet supported: ++- decimal32/64/128 ++- maps with key values that are not legal Go map keys, except when unmarshaled into AnyMap or interface{} [4].
+*/ +func Unmarshal(bytes []byte, v interface{}) (n int, err error) { - defer recoverUnmarshal(&err) - + data := C.pn_data(0) + defer C.pn_data_free(data) + n, err = decode(data, bytes) - if err != nil { - return 0, err - } - if n == 0 { - return 0, fmt.Errorf("not enough data") - } else { - unmarshal(v, data) ++ if err == nil { ++ err = recoverUnmarshal(v, data) + } - return n, nil ++ return +} + +// Internal - func UnmarshalUnsafe(pn_data unsafe.Pointer, v interface{}) (err error) { - defer recoverUnmarshal(&err) - unmarshal(v, (*C.pn_data_t)(pn_data)) - return ++func UnmarshalUnsafe(pnData unsafe.Pointer, v interface{}) (err error) { ++ return recoverUnmarshal(v, (*C.pn_data_t)(pnData)) +} + +// more reads more data when we can't parse a complete AMQP type +func (d *Decoder) more() error { + var readSize int64 = minDecode + if int64(d.buffer.Len()) > readSize { // Grow by doubling + readSize = int64(d.buffer.Len()) + } + var n int64 + n, err := d.buffer.ReadFrom(io.LimitReader(d.reader, readSize)) + if n == 0 && err == nil { // ReadFrom won't report io.EOF, just returns 0 + err = io.EOF + } + return err +} + - // Unmarshal from data into value pointed at by v. ++// Call unmarshal(), convert panic to error value ++func recoverUnmarshal(v interface{}, data *C.pn_data_t) (err error) { ++ defer func() { ++ if r := recover(); r != nil { ++ if uerr, ok := r.(*UnmarshalError); ok { ++ err = uerr ++ } else { ++ panic(r) ++ } ++ } ++ }() ++ unmarshal(v, data) ++ return nil ++} ++ ++// Unmarshal from data into value pointed at by v. Returns v. ++// NOTE: If you update this you also need to update getInterface() +func unmarshal(v interface{}, data *C.pn_data_t) { - pnType := C.pn_data_type(data) ++ rt := reflect.TypeOf(v) ++ rv := reflect.ValueOf(v) ++ panicUnless(v != nil && rt.Kind() == reflect.Ptr && !rv.IsNil(), data, v) + + // Check for PN_DESCRIBED first, as described types can unmarshal into any of the Go types. 
- // Interfaces are handled in the switch below, even for described types. ++ // An interface{} target is handled in the switch below, even for described types. + if _, isInterface := v.(*interface{}); !isInterface && bool(C.pn_data_is_described(data)) { + getDescribed(data, v) + return + } + + // Unmarshal based on the target type ++ pnType := C.pn_data_type(data) + switch v := v.(type) { ++ + case *bool: - switch pnType { - case C.PN_BOOL: - *v = bool(C.pn_data_get_bool(data)) - default: - panic(newUnmarshalError(pnType, v)) - } ++ panicUnless(pnType == C.PN_BOOL, data, v) ++ *v = bool(C.pn_data_get_bool(data)) ++ + case *int8: - switch pnType { - case C.PN_CHAR: - *v = int8(C.pn_data_get_char(data)) - case C.PN_BYTE: - *v = int8(C.pn_data_get_byte(data)) - default: - panic(newUnmarshalError(pnType, v)) - } ++ panicUnless(pnType == C.PN_BYTE, data, v) ++ *v = int8(C.pn_data_get_byte(data)) ++ + case *uint8: - switch pnType { - case C.PN_CHAR: - *v = uint8(C.pn_data_get_char(data)) - case C.PN_UBYTE: - *v = uint8(C.pn_data_get_ubyte(data)) - default: - panic(newUnmarshalError(pnType, v)) - } ++ panicUnless(pnType == C.PN_UBYTE, data, v) ++ *v = uint8(C.pn_data_get_ubyte(data)) ++ + case *int16: - switch pnType { - case C.PN_CHAR: - *v = int16(C.pn_data_get_char(data)) ++ switch C.pn_data_type(data) { + case C.PN_BYTE: + *v = int16(C.pn_data_get_byte(data)) + case C.PN_SHORT: + *v = int16(C.pn_data_get_short(data)) + default: - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } ++ + case *uint16: + switch pnType { - case C.PN_CHAR: - *v = uint16(C.pn_data_get_char(data)) + case C.PN_UBYTE: + *v = uint16(C.pn_data_get_ubyte(data)) + case C.PN_USHORT: + *v = uint16(C.pn_data_get_ushort(data)) + default: - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } ++ + case *int32: + switch pnType { + case C.PN_CHAR: + *v = int32(C.pn_data_get_char(data)) + case C.PN_BYTE: + *v = int32(C.pn_data_get_byte(data)) + case C.PN_SHORT: + *v = 
int32(C.pn_data_get_short(data)) + case C.PN_INT: + *v = int32(C.pn_data_get_int(data)) + default: - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } ++ + case *uint32: + switch pnType { + case C.PN_CHAR: + *v = uint32(C.pn_data_get_char(data)) + case C.PN_UBYTE: + *v = uint32(C.pn_data_get_ubyte(data)) + case C.PN_USHORT: + *v = uint32(C.pn_data_get_ushort(data)) + case C.PN_UINT: + *v = uint32(C.pn_data_get_uint(data)) + default: - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } + + case *int64: + switch pnType { + case C.PN_CHAR: + *v = int64(C.pn_data_get_char(data)) + case C.PN_BYTE: + *v = int64(C.pn_data_get_byte(data)) + case C.PN_SHORT: + *v = int64(C.pn_data_get_short(data)) + case C.PN_INT: + *v = int64(C.pn_data_get_int(data)) + case C.PN_LONG: + *v = int64(C.pn_data_get_long(data)) + default: - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } + + case *uint64: + switch pnType { + case C.PN_CHAR: + *v = uint64(C.pn_data_get_char(data)) + case C.PN_UBYTE: + *v = uint64(C.pn_data_get_ubyte(data)) + case C.PN_USHORT: + *v = uint64(C.pn_data_get_ushort(data)) + case C.PN_ULONG: + *v = uint64(C.pn_data_get_ulong(data)) + default: - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } + + case *int: + switch pnType { + case C.PN_CHAR: + *v = int(C.pn_data_get_char(data)) + case C.PN_BYTE: + *v = int(C.pn_data_get_byte(data)) + case C.PN_SHORT: + *v = int(C.pn_data_get_short(data)) + case C.PN_INT: + *v = int(C.pn_data_get_int(data)) + case C.PN_LONG: - if unsafe.Sizeof(int(0)) == 8 { ++ if intIs64 { + *v = int(C.pn_data_get_long(data)) + } else { - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } + default: - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } + + case *uint: + switch pnType { + case C.PN_CHAR: + *v = uint(C.pn_data_get_char(data)) + case C.PN_UBYTE: + *v = uint(C.pn_data_get_ubyte(data)) + case C.PN_USHORT: + *v = uint(C.pn_data_get_ushort(data)) + case C.PN_UINT: + *v = 
uint(C.pn_data_get_uint(data)) + case C.PN_ULONG: - if unsafe.Sizeof(int(0)) == 8 { ++ if intIs64 { + *v = uint(C.pn_data_get_ulong(data)) + } else { - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } + default: - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } + + case *float32: - switch pnType { - case C.PN_FLOAT: - *v = float32(C.pn_data_get_float(data)) - default: - panic(newUnmarshalError(pnType, v)) - } ++ panicUnless(pnType == C.PN_FLOAT, data, v) ++ *v = float32(C.pn_data_get_float(data)) + + case *float64: + switch pnType { + case C.PN_FLOAT: + *v = float64(C.pn_data_get_float(data)) + case C.PN_DOUBLE: + *v = float64(C.pn_data_get_double(data)) + default: - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } + + case *string: + switch pnType { + case C.PN_STRING: + *v = goString(C.pn_data_get_string(data)) + case C.PN_SYMBOL: + *v = goString(C.pn_data_get_symbol(data)) + case C.PN_BINARY: + *v = goString(C.pn_data_get_binary(data)) + default: - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } + + case *[]byte: + switch pnType { + case C.PN_STRING: + *v = goBytes(C.pn_data_get_string(data)) + case C.PN_SYMBOL: + *v = goBytes(C.pn_data_get_symbol(data)) + case C.PN_BINARY: + *v = goBytes(C.pn_data_get_binary(data)) + default: - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } ++ return ++ ++ case *Char: ++ panicUnless(pnType == C.PN_CHAR, data, v) ++ *v = Char(C.pn_data_get_char(data)) + + case *Binary: - switch pnType { - case C.PN_BINARY: - *v = Binary(goBytes(C.pn_data_get_binary(data))) - default: - panic(newUnmarshalError(pnType, v)) - } ++ panicUnless(pnType == C.PN_BINARY, data, v) ++ *v = Binary(goBytes(C.pn_data_get_binary(data))) + + case *Symbol: - switch pnType { - case C.PN_SYMBOL: - *v = Symbol(goBytes(C.pn_data_get_symbol(data))) - default: - panic(newUnmarshalError(pnType, v)) - } ++ panicUnless(pnType == C.PN_SYMBOL, data, v) ++ *v = Symbol(goBytes(C.pn_data_get_symbol(data))) + - case 
*interface{}: - getInterface(data, v) ++ case *time.Time: ++ panicUnless(pnType == C.PN_TIMESTAMP, data, v) ++ *v = time.Unix(0, int64(C.pn_data_get_timestamp(data))*1000) ++ ++ case *UUID: ++ panicUnless(pnType == C.PN_UUID, data, v) ++ pn := C.pn_data_get_uuid(data) ++ copy((*v)[:], C.GoBytes(unsafe.Pointer(&pn.bytes), 16)) + + case *AnnotationKey: - if pnType == C.PN_ULONG || pnType == C.PN_SYMBOL || pnType == C.PN_STRING { - unmarshal(&v.value, data) - } else { - panic(newUnmarshalError(pnType, v)) ++ panicUnless(pnType == C.PN_ULONG || pnType == C.PN_SYMBOL || pnType == C.PN_STRING, data, v) ++ unmarshal(&v.value, data) ++ ++ case *AnyMap: ++ panicUnless(C.pn_data_type(data) == C.PN_MAP, data, v) ++ n := int(C.pn_data_get_map(data)) / 2 ++ if cap(*v) < n { ++ *v = make(AnyMap, n) ++ } ++ *v = (*v)[:n] ++ data.enter(*v) ++ defer data.exit(*v) ++ for i := 0; i < n; i++ { ++ data.next(*v) ++ unmarshal(&(*v)[i].Key, data) ++ data.next(*v) ++ unmarshal(&(*v)[i].Value, data) + } + ++ case *interface{}: ++ getInterface(data, v) ++ + default: // This is not one of the fixed well-known types, reflect for map and slice types - if reflect.TypeOf(v).Kind() != reflect.Ptr { - panic(newUnmarshalError(pnType, v)) - } - switch reflect.TypeOf(v).Elem().Kind() { ++ ++ switch rt.Elem().Kind() { + case reflect.Map: + getMap(data, v) + case reflect.Slice: - getList(data, v) ++ getSequence(data, v) + default: - panic(newUnmarshalError(pnType, v)) ++ doPanic(data, v) + } + } - if err := newUnmarshalErrorData(data, v); err != nil { - panic(err) - } - return - } - - func rewindUnmarshal(v interface{}, data *C.pn_data_t) { - C.pn_data_rewind(data) - C.pn_data_next(data) - unmarshal(v, data) +} + - // Getting into an interface is driven completely by the AMQP type, since the interface{} - // target is type-neutral. 
- func getInterface(data *C.pn_data_t, v *interface{}) { ++// Unmarshalling into an interface{} the type is determined by the AMQP source type, ++// since the interface{} target can hold any Go type. ++func getInterface(data *C.pn_data_t, vp *interface{}) { + pnType := C.pn_data_type(data) + switch pnType { + case C.PN_BOOL: - *v = bool(C.pn_data_get_bool(data)) ++ *vp = bool(C.pn_data_get_bool(data)) + case C.PN_UBYTE: - *v = uint8(C.pn_data_get_ubyte(data)) ++ *vp = uint8(C.pn_data_get_ubyte(data)) + case C.PN_BYTE: - *v = int8(C.pn_data_get_byte(data)) ++ *vp = int8(C.pn_data_get_byte(data)) + case C.PN_USHORT: - *v = uint16(C.pn_data_get_ushort(data)) ++ *vp = uint16(C.pn_data_get_ushort(data)) + case C.PN_SHORT: - *v = int16(C.pn_data_get_short(data)) ++ *vp = int16(C.pn_data_get_short(data)) + case C.PN_UINT: - *v = uint32(C.pn_data_get_uint(data)) ++ *vp = uint32(C.pn_data_get_uint(data)) + case C.PN_INT: - *v = int32(C.pn_data_get_int(data)) ++ *vp = int32(C.pn_data_get_int(data)) + case C.PN_CHAR: - *v = uint8(C.pn_data_get_char(data)) ++ *vp = Char(C.pn_data_get_char(data)) + case C.PN_ULONG: - *v = uint64(C.pn_data_get_ulong(data)) ++ *vp = uint64(C.pn_data_get_ulong(data)) + case C.PN_LONG: - *v = int64(C.pn_data_get_long(data)) ++ *vp = int64(C.pn_data_get_long(data)) + case C.PN_FLOAT: - *v = float32(C.pn_data_get_float(data)) ++ *vp = float32(C.pn_data_get_float(data)) + case C.PN_DOUBLE: - *v = float64(C.pn_data_get_double(data)) ++ *vp = float64(C.pn_data_get_double(data)) + case C.PN_BINARY: - *v = Binary(goBytes(C.pn_data_get_binary(data))) ++ *vp = Binary(goBytes(C.pn_data_get_binary(data))) + case C.PN_STRING: - *v = goString(C.pn_data_get_string(data)) ++ *vp = goString(C.pn_data_get_string(data)) + case C.PN_SYMBOL: - *v = Symbol(goString(C.pn_data_get_symbol(data))) ++ *vp = Symbol(goString(C.pn_data_get_symbol(data))) ++ case C.PN_TIMESTAMP: ++ *vp = time.Unix(0, int64(C.pn_data_get_timestamp(data))*1000) ++ case C.PN_UUID: ++ var u UUID ++ 
unmarshal(&u, data) ++ *vp = u + case C.PN_MAP: - m := make(Map) - unmarshal(&m, data) - *v = m ++ // We will try to unmarshal as a Map first, if that fails try AnyMap ++ m := make(Map, int(C.pn_data_get_map(data))/2) ++ if err := recoverUnmarshal(&m, data); err == nil { ++ *vp = m ++ } else { ++ am := make(AnyMap, int(C.pn_data_get_map(data))/2) ++ unmarshal(&am, data) ++ *vp = am ++ } + case C.PN_LIST: - l := make(List, 0) ++ l := List{} + unmarshal(&l, data) - *v = l ++ *vp = l ++ case C.PN_ARRAY: ++ sp := getArrayStore(data) // interface{} containing T* for suitable T ++ unmarshal(sp, data) ++ *vp = reflect.ValueOf(sp).Elem().Interface() + case C.PN_DESCRIBED: + d := Described{} + unmarshal(&d, data) - *v = d ++ *vp = d + case C.PN_NULL: - *v = nil ++ *vp = nil + case C.PN_INVALID: + // Allow decoding from an empty data object to an interface, treat it like NULL. + // This happens when optional values or properties are omitted from a message. - *v = nil ++ *vp = nil + default: // Don't know how to handle this - panic(newUnmarshalError(pnType, v)) ++ panic(newUnmarshalError(pnType, vp)) + } +} + ++// Return an interface{} containing a pointer to an appropriate slice or Array ++func getArrayStore(data *C.pn_data_t) interface{} { ++ // TODO aconway 2017-11-10: described arrays. 
++ switch C.pn_data_get_array_type(data) { ++ case C.PN_BOOL: ++ return new([]bool) ++ case C.PN_UBYTE: ++ return new([]uint8) ++ case C.PN_BYTE: ++ return new([]int8) ++ case C.PN_USHORT: ++ return new([]uint16) ++ case C.PN_SHORT: ++ return new([]int16) ++ case C.PN_UINT: ++ return new([]uint32) ++ case C.PN_INT: ++ return new([]int32) ++ case C.PN_CHAR: ++ return new([]Char) ++ case C.PN_ULONG: ++ return new([]uint64) ++ case C.PN_LONG: ++ return new([]int64) ++ case C.PN_FLOAT: ++ return new([]float32) ++ case C.PN_DOUBLE: ++ return new([]float64) ++ case C.PN_BINARY: ++ return new([]Binary) ++ case C.PN_STRING: ++ return new([]string) ++ case C.PN_SYMBOL: ++ return new([]Symbol) ++ case C.PN_TIMESTAMP: ++ return new([]time.Time) ++ case C.PN_UUID: ++ return new([]UUID) ++ } ++ return new(Array) // Not a simple type, use generic Array ++} ++ ++var typeOfInterface = reflect.TypeOf(interface{}(nil)) ++ +// get into map pointed at by v +func getMap(data *C.pn_data_t, v interface{}) { ++ panicUnless(C.pn_data_type(data) == C.PN_MAP, data, v) ++ n := int(C.pn_data_get_map(data)) / 2 + mapValue := reflect.ValueOf(v).Elem() + mapValue.Set(reflect.MakeMap(mapValue.Type())) // Clear the map - switch pnType := C.pn_data_type(data); pnType { - case C.PN_MAP: - count := int(C.pn_data_get_map(data)) - if bool(C.pn_data_enter(data)) { - defer C.pn_data_exit(data) - for i := 0; i < count/2; i++ { - if bool(C.pn_data_next(data)) { - key := reflect.New(mapValue.Type().Key()) - unmarshal(key.Interface(), data) - if bool(C.pn_data_next(data)) { - val := reflect.New(mapValue.Type().Elem()) - unmarshal(val.Interface(), data) - mapValue.SetMapIndex(key.Elem(), val.Elem()) - } - } - } ++ data.enter(v) ++ defer data.exit(v) ++ // Allocate re-usable key/val values ++ keyType := mapValue.Type().Key() ++ keyPtr := reflect.New(keyType) ++ valPtr := reflect.New(mapValue.Type().Elem()) ++ for i := 0; i < n; i++ { ++ data.next(v) ++ unmarshal(keyPtr.Interface(), data) ++ if keyType.Kind() == 
reflect.Interface && !keyPtr.Elem().Elem().Type().Comparable() { ++ doPanicMsg(data, v, fmt.Sprintf("key %#v is not comparable", keyPtr.Elem().Interface())) + } - default: // Empty/error/unknown, leave map empty ++ data.next(v) ++ unmarshal(valPtr.Interface(), data) ++ mapValue.SetMapIndex(keyPtr.Elem(), valPtr.Elem()) + } +} + - func getList(data *C.pn_data_t, v interface{}) { ++func getSequence(data *C.pn_data_t, vp interface{}) { ++ var count int + pnType := C.pn_data_type(data) - if pnType != C.PN_LIST { - panic(newUnmarshalError(pnType, v)) ++ switch pnType { ++ case C.PN_LIST: ++ count = int(C.pn_data_get_list(data)) ++ case C.PN_ARRAY: ++ count = int(C.pn_data_get_array(data)) ++ default: ++ doPanic(data, vp) + } - count := int(C.pn_data_get_list(data)) - listValue := reflect.MakeSlice(reflect.TypeOf(v).Elem(), count, count) - if bool(C.pn_data_enter(data)) { - for i := 0; i < count; i++ { - if bool(C.pn_data_next(data)) { - val := reflect.New(listValue.Type().Elem()) - unmarshal(val.Interface(), data) - listValue.Index(i).Set(val.Elem()) - } - } - C.pn_data_exit(data) ++ listValue := reflect.MakeSlice(reflect.TypeOf(vp).Elem(), count, count) ++ data.enter(vp) ++ defer data.exit(vp) ++ for i := 0; i < count; i++ { ++ data.next(vp) ++ val := reflect.New(listValue.Type().Elem()) ++ unmarshal(val.Interface(), data) ++ listValue.Index(i).Set(val.Elem()) + } - reflect.ValueOf(v).Elem().Set(listValue) ++ reflect.ValueOf(vp).Elem().Set(listValue) +} + - func getDescribed(data *C.pn_data_t, v interface{}) { - d, _ := v.(*Described) - pnType := C.pn_data_type(data) - if bool(C.pn_data_enter(data)) { - defer C.pn_data_exit(data) - if bool(C.pn_data_next(data)) { - if d != nil { - unmarshal(&d.Descriptor, data) - } - if bool(C.pn_data_next(data)) { - if d != nil { - unmarshal(&d.Value, data) - } else { - unmarshal(v, data) - } - return - } - } ++func getDescribed(data *C.pn_data_t, vp interface{}) { ++ d, isDescribed := vp.(*Described) ++ data.enter(vp) ++ defer 
data.exit(vp) ++ data.next(vp) ++ if isDescribed { ++ unmarshal(&d.Descriptor, data) ++ data.next(vp) ++ unmarshal(&d.Value, data) ++ } else { ++ data.next(vp) // Skip descriptor ++ unmarshal(vp, data) // Unmarshal plain value + } - // The pn_data cursor didn't move as expected - panic(newUnmarshalErrorMsg(pnType, v, "bad described value encoding")) +} + +// decode from bytes. +// Return bytes decoded or 0 if we could not decode a complete object. +// +func decode(data *C.pn_data_t, bytes []byte) (int, error) { - if len(bytes) == 0 { - return 0, nil - } - n := int(C.pn_data_decode(data, cPtr(bytes), cLen(bytes))) - if n == int(C.PN_UNDERFLOW) { ++ n := C.pn_data_decode(data, cPtr(bytes), cLen(bytes)) ++ if n == C.PN_UNDERFLOW { + C.pn_error_clear(C.pn_data_error(data)) - return 0, nil ++ return 0, EndOfData + } else if n <= 0 { - return 0, fmt.Errorf("unmarshal %s", PnErrorCode(n)) ++ return 0, &UnmarshalError{s: fmt.Sprintf("unmarshal %v", PnErrorCode(n))} + } - return n, nil ++ return int(n), nil +} ++ ++// Checked versions of pn_data functions ++ ++func (data *C.pn_data_t) enter(v interface{}) { checkOp(bool(C.pn_data_enter(data)), v) } ++func (data *C.pn_data_t) exit(v interface{}) { checkOp(bool(C.pn_data_exit(data)), v) } ++func (data *C.pn_data_t) next(v interface{}) { checkOp(bool(C.pn_data_next(data)), v) }
http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/b25d21e6/amqp/url_test.go ---------------------------------------------------------------------- diff --cc amqp/url_test.go index f52d4bf,0000000..192e2fb mode 100644,000000..100644 --- a/amqp/url_test.go +++ b/amqp/url_test.go @@@ -1,59 -1,0 +1,59 @@@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. 
+*/ + +package amqp + +import ( + "fmt" +) + +func ExampleParseURL() { + for _, s := range []string{ + "amqp://username:password@host:1234/path", + "host:1234", + "host", + "host/path", + "amqps://host", + "/path", + "", + ":1234", - // Taken out becasue the go 1.4 URL parser isn't the same as later ++ // Taken out because the go 1.4 URL parser isn't the same as later + //"[::1]", + //"[::1", + // Output would be: + // amqp://[::1]:amqp + // parse amqp://[::1: missing ']' in host + } { + u, err := ParseURL(s) + if err != nil { + fmt.Println(err) + } else { + fmt.Println(u) + } + } + // Output: + // amqp://username:password@host:1234/path + // amqp://host:1234 + // amqp://host:amqp + // amqp://host:amqp/path + // amqps://host:amqps + // amqp://localhost:amqp/path + // amqp://localhost:amqp + // parse :1234: missing protocol scheme +} http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/b25d21e6/electron/auth_test.go ---------------------------------------------------------------------- diff --cc electron/auth_test.go index 9eb48c0,0000000..9fa9fa2 mode 100644,000000..100644 --- a/electron/auth_test.go +++ b/electron/auth_test.go @@@ -1,133 -1,0 +1,137 @@@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. 
+*/ + +package electron + +import ( + "fmt" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +func testAuthClientServer(t *testing.T, copts []ConnectionOption, sopts []ConnectionOption) (got connectionSettings, err error) { + client, server := newClientServerOpts(t, copts, sopts) + defer closeClientServer(client, server) + + go func() { + for in := range server.Incoming() { + switch in := in.(type) { + case *IncomingConnection: + got = connectionSettings{user: in.User(), virtualHost: in.VirtualHost()} + } + in.Accept() + } + }() + + err = client.Sync() + return +} + +func TestAuthAnonymous(t *testing.T) { - fatalIf(t, configureSASL()) ++ configureSASL() + got, err := testAuthClientServer(t, + []ConnectionOption{User("fred"), VirtualHost("vhost"), SASLAllowInsecure(true)}, + []ConnectionOption{SASLAllowedMechs("ANONYMOUS"), SASLAllowInsecure(true)}) + fatalIf(t, err) + errorIf(t, checkEqual(connectionSettings{user: "anonymous", virtualHost: "vhost"}, got)) +} + +func TestAuthPlain(t *testing.T) { + if !SASLExtended() { + t.Skip() + } + fatalIf(t, configureSASL()) + got, err := testAuthClientServer(t, + []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN"), User("fred@proton"), Password([]byte("xxx"))}, + []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN")}) + fatalIf(t, err) + errorIf(t, checkEqual(connectionSettings{user: "fred@proton"}, got)) +} + +func TestAuthBadPass(t *testing.T) { + if !SASLExtended() { + t.Skip() + } + fatalIf(t, configureSASL()) + _, err := testAuthClientServer(t, + []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN"), User("fred@proton"), Password([]byte("yyy"))}, + []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN")}) + if err == nil { + t.Error("Expected auth failure for bad pass") + } +} + +func TestAuthBadUser(t *testing.T) { + if !SASLExtended() { + t.Skip() + } + fatalIf(t, configureSASL()) + _, err := testAuthClientServer(t, + 
[]ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN"), User("foo@bar"), Password([]byte("yyy"))}, + []ConnectionOption{SASLAllowInsecure(true), SASLAllowedMechs("PLAIN")}) + if err == nil { + t.Error("Expected auth failure for bad user") + } +} + +var confDir string +var confErr error + +func configureSASL() error { + if confDir != "" || confErr != nil { + return confErr + } + confDir, confErr = ioutil.TempDir("", "") + if confErr != nil { + return confErr + } + + GlobalSASLConfigDir(confDir) + GlobalSASLConfigName("test") + conf := filepath.Join(confDir, "test.conf") + + db := filepath.Join(confDir, "proton.sasldb") - cmd := exec.Command("saslpasswd2", "-c", "-p", "-f", db, "-u", "proton", "fred") ++ saslpasswd := os.Getenv("SASLPASSWD"); ++ if saslpasswd == "" { ++ saslpasswd = "saslpasswd2" ++ } ++ cmd := exec.Command(saslpasswd, "-c", "-p", "-f", db, "-u", "proton", "fred") + cmd.Stdin = strings.NewReader("xxx") // Password + if out, err := cmd.CombinedOutput(); err != nil { + confErr = fmt.Errorf("saslpasswd2 failed: %s\n%s", err, out) + return confErr + } + confStr := "sasldb_path: " + db + "\nmech_list: EXTERNAL DIGEST-MD5 SCRAM-SHA-1 CRAM-MD5 PLAIN ANONYMOUS\n" + if err := ioutil.WriteFile(conf, []byte(confStr), os.ModePerm); err != nil { + confErr = fmt.Errorf("write conf file %s failed: %s", conf, err) + } + return confErr +} + +func TestMain(m *testing.M) { + status := m.Run() + if confDir != "" { + _ = os.RemoveAll(confDir) + } + os.Exit(status) +} http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/b25d21e6/electron/connection.go ---------------------------------------------------------------------- diff --cc electron/connection.go index 2749b2b,0000000..731e64d mode 100644,000000..100644 --- a/electron/connection.go +++ b/electron/connection.go @@@ -1,421 -1,0 +1,421 @@@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. 
See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package electron + +// #include <proton/disposition.h> +import "C" + +import ( + "net" + "qpid.apache.org/proton" + "sync" + "time" +) + +// Settings associated with a Connection. +type ConnectionSettings interface { + // Authenticated user name associated with the connection. + User() string + + // The AMQP virtual host name for the connection. + // + // Optional, useful when the server has multiple names and provides different + // service based on the name the client uses to connect. + // + // By default it is set to the DNS host name that the client uses to connect, + // but it can be set to something different at the client side with the + // VirtualHost() option. + // + // Returns error if the connection fails to authenticate. + VirtualHost() string + + // Heartbeat is the maximum delay between sending frames that the remote peer + // has requested of us. If the interval expires an empty "heartbeat" frame + // will be sent automatically to keep the connection open. + Heartbeat() time.Duration +} + +// Connection is an AMQP connection, created by a Container. +type Connection interface { + Endpoint + ConnectionSettings + + // Sender opens a new sender on the DefaultSession. + Sender(...LinkOption) (Sender, error) + + // Receiver opens a new Receiver on the DefaultSession(). 
+ Receiver(...LinkOption) (Receiver, error) + + // DefaultSession() returns a default session for the connection. It is opened + // on the first call to DefaultSession and returned on subsequent calls. + DefaultSession() (Session, error) + + // Session opens a new session. + Session(...SessionOption) (Session, error) + + // Container for the connection. + Container() Container + + // Disconnect the connection abruptly with an error. + Disconnect(error) + + // Wait waits for the connection to be disconnected. + Wait() error + + // WaitTimeout is like Wait but returns Timeout if the timeout expires. + WaitTimeout(time.Duration) error + + // Incoming returns a channel for incoming endpoints opened by the remote peer. + // See the Incoming interface for more detail. + // + // Note: this channel will first return an *IncomingConnection for the + // connection itself which allows you to look at security information and + // decide whether to Accept() or Reject() the connection. Then it will return + // *IncomingSession, *IncomingSender and *IncomingReceiver as they are opened + // by the remote end. + // + // Note 2: you must receiving from Incoming() and call Accept/Reject to avoid + // blocking electron event loop. Normally you would run a loop in a goroutine + // to handle incoming types that interest and Accept() those that don't. 
+ Incoming() <-chan Incoming +} + +type connectionSettings struct { + user, virtualHost string + heartbeat time.Duration +} + +func (c connectionSettings) User() string { return c.user } +func (c connectionSettings) VirtualHost() string { return c.virtualHost } +func (c connectionSettings) Heartbeat() time.Duration { return c.heartbeat } + +// ConnectionOption can be passed when creating a connection to configure various options +type ConnectionOption func(*connection) + +// User returns a ConnectionOption sets the user name for a connection +func User(user string) ConnectionOption { + return func(c *connection) { + c.user = user + c.pConnection.SetUser(user) + } +} + +// VirtualHost returns a ConnectionOption to set the AMQP virtual host for the connection. +// Only applies to outbound client connection. +func VirtualHost(virtualHost string) ConnectionOption { + return func(c *connection) { + c.virtualHost = virtualHost + c.pConnection.SetHostname(virtualHost) + } +} + +// Password returns a ConnectionOption to set the password used to establish a +// connection. Only applies to outbound client connection. +// +// The connection will erase its copy of the password from memory as soon as it - // has been used to authenticate. If you are concerned about paswords staying in ++// has been used to authenticate. If you are concerned about passwords staying in +// memory you should never store them as strings, and should overwrite your +// copy as soon as you are done with it. +// +func Password(password []byte) ConnectionOption { + return func(c *connection) { c.pConnection.SetPassword(password) } +} + +// Server returns a ConnectionOption to put the connection in server mode for incoming connections. +// +// A server connection will do protocol negotiation to accept a incoming AMQP +// connection. 
Normally you would call this for a connection created by +// net.Listener.Accept() +// +func Server() ConnectionOption { + return func(c *connection) { c.engine.Server(); c.server = true; AllowIncoming()(c) } +} + +// AllowIncoming returns a ConnectionOption to enable incoming endpoints, see +// Connection.Incoming() This is automatically set for Server() connections. +func AllowIncoming() ConnectionOption { + return func(c *connection) { c.incoming = make(chan Incoming) } +} + +// Parent returns a ConnectionOption that associates the Connection with it's Container +// If not set a connection will create its own default container. +func Parent(cont Container) ConnectionOption { + return func(c *connection) { c.container = cont.(*container) } +} + +type connection struct { + endpoint + connectionSettings + + defaultSessionOnce, closeOnce sync.Once + + container *container + conn net.Conn + server bool + incoming chan Incoming + handler *handler + engine *proton.Engine + pConnection proton.Connection + + defaultSession Session +} + +// NewConnection creates a connection with the given options. 
+func NewConnection(conn net.Conn, opts ...ConnectionOption) (*connection, error) { + c := &connection{ + conn: conn, + } + c.handler = newHandler(c) + var err error + c.engine, err = proton.NewEngine(c.conn, c.handler.delegator) + if err != nil { + return nil, err + } + c.pConnection = c.engine.Connection() + for _, set := range opts { + set(c) + } + if c.container == nil { + c.container = NewContainer("").(*container) + } + c.pConnection.SetContainer(c.container.Id()) + globalSASLInit(c.engine) + + c.endpoint.init(c.engine.String()) + go c.run() + return c, nil +} + +func (c *connection) run() { + if !c.server { + c.pConnection.Open() + } + _ = c.engine.Run() + if c.incoming != nil { + close(c.incoming) + } + _ = c.closed(Closed) +} + +func (c *connection) Close(err error) { + c.err.Set(err) + c.engine.Close(err) +} + +func (c *connection) Disconnect(err error) { + c.err.Set(err) + c.engine.Disconnect(err) +} + +func (c *connection) Session(opts ...SessionOption) (Session, error) { + var s Session + err := c.engine.InjectWait(func() error { + if c.Error() != nil { + return c.Error() + } + pSession, err := c.engine.Connection().Session() + if err == nil { + pSession.Open() + if err == nil { + s = newSession(c, pSession, opts...) + } + } + return err + }) + return s, err +} + +func (c *connection) Container() Container { return c.container } + +func (c *connection) DefaultSession() (s Session, err error) { + c.defaultSessionOnce.Do(func() { + c.defaultSession, err = c.Session() + }) + if err == nil { + err = c.Error() + } + return c.defaultSession, err +} + +func (c *connection) Sender(opts ...LinkOption) (Sender, error) { + if s, err := c.DefaultSession(); err == nil { + return s.Sender(opts...) + } else { + return nil, err + } +} + +func (c *connection) Receiver(opts ...LinkOption) (Receiver, error) { + if s, err := c.DefaultSession(); err == nil { + return s.Receiver(opts...) 
+ } else { + return nil, err + } +} + +func (c *connection) Connection() Connection { return c } + +func (c *connection) Wait() error { return c.WaitTimeout(Forever) } +func (c *connection) WaitTimeout(timeout time.Duration) error { + _, err := timedReceive(c.done, timeout) + if err == Timeout { + return Timeout + } + return c.Error() +} + +func (c *connection) Incoming() <-chan Incoming { + assert(c.incoming != nil, "Incoming() is only allowed for a Connection created with the Server() option: %s", c) + return c.incoming +} + +type IncomingConnection struct { + incoming + connectionSettings + c *connection +} + +func newIncomingConnection(c *connection) *IncomingConnection { + c.user = c.pConnection.Transport().User() + c.virtualHost = c.pConnection.RemoteHostname() + return &IncomingConnection{ + incoming: makeIncoming(c.pConnection), + connectionSettings: c.connectionSettings, + c: c} +} + +// AcceptConnection is like Accept() but takes ConnectionOption s +// For example you can set the Heartbeat() for the accepted connection. +func (in *IncomingConnection) AcceptConnection(opts ...ConnectionOption) Connection { + return in.accept(func() Endpoint { + for _, opt := range opts { + opt(in.c) + } + in.c.pConnection.Open() + return in.c + }).(Connection) +} + +func (in *IncomingConnection) Accept() Endpoint { + return in.AcceptConnection() +} + +func sasl(c *connection) proton.SASL { return c.engine.Transport().SASL() } + +// SASLEnable returns a ConnectionOption that enables SASL authentication. +// Only required if you don't set any other SASL options. +func SASLEnable() ConnectionOption { return func(c *connection) { sasl(c) } } + +// SASLAllowedMechs returns a ConnectionOption to set the list of allowed SASL +// mechanisms. +// +// Can be used on the client or the server to restrict the SASL for a connection. +// mechs is a space-separated list of mechanism names. 
+// +func SASLAllowedMechs(mechs string) ConnectionOption { + return func(c *connection) { sasl(c).AllowedMechs(mechs) } +} + +// SASLAllowInsecure returns a ConnectionOption that allows or disallows clear +// text SASL authentication mechanisms +// +// By default the SASL layer is configured not to allow mechanisms that disclose +// the clear text of the password over an unencrypted AMQP connection. This specifically +// will disallow the use of the PLAIN mechanism without using SSL encryption. +// +// This default is to avoid disclosing password information accidentally over an +// insecure network. +// +func SASLAllowInsecure(b bool) ConnectionOption { + return func(c *connection) { sasl(c).SetAllowInsecureMechs(b) } +} + +// Heartbeat returns a ConnectionOption that requests the maximum delay +// between sending frames for the remote peer. If we don't receive any frames +// within 2*delay we will close the connection. +// +func Heartbeat(delay time.Duration) ConnectionOption { + // Proton-C divides the idle-timeout by 2 before sending, so compensate. + return func(c *connection) { c.engine.Transport().SetIdleTimeout(2 * delay) } +} + +// GlobalSASLConfigDir sets the SASL configuration directory for every +// Connection created in this process. If not called, the default is determined +// by your SASL installation. +// +// You can set SASLAllowInsecure and SASLAllowedMechs on individual connections. +// +func GlobalSASLConfigDir(dir string) { globalSASLConfigDir = dir } + +// GlobalSASLConfigName sets the SASL configuration name for every Connection +// created in this process. If not called the default is "proton-server". +// +// The complete configuration file name is +// <sasl-config-dir>/<sasl-config-name>.conf +// +// You can set SASLAllowInsecure and SASLAllowedMechs on individual connections. +// +func GlobalSASLConfigName(dir string) { globalSASLConfigName = dir } + +// Do we support extended SASL negotiation? 
+// All implementations of Proton support ANONYMOUS and EXTERNAL on both +// client and server sides and PLAIN on the client side. +// +// Extended SASL implememtations use an external library (Cyrus SASL) +// to support other mechanisms beyond these basic ones. +func SASLExtended() bool { return proton.SASLExtended() } + +var ( + globalSASLConfigName string + globalSASLConfigDir string +) + +// TODO aconway 2016-09-15: Current pn_sasl C impl config is broken, so all we +// can realistically offer is global configuration. Later if/when the pn_sasl C +// impl is fixed we can offer per connection over-rides. +func globalSASLInit(eng *proton.Engine) { + sasl := eng.Transport().SASL() + if globalSASLConfigName != "" { + sasl.ConfigName(globalSASLConfigName) + } + if globalSASLConfigDir != "" { + sasl.ConfigPath(globalSASLConfigDir) + } +} + +// Dial is shorthand for using net.Dial() then NewConnection() +// See net.Dial() for the meaning of the network, address arguments. +func Dial(network, address string, opts ...ConnectionOption) (c Connection, err error) { + conn, err := net.Dial(network, address) + if err == nil { + c, err = NewConnection(conn, opts...) + } + return +} + +// DialWithDialer is shorthand for using dialer.Dial() then NewConnection() +// See net.Dial() for the meaning of the network, address arguments. +func DialWithDialer(dialer *net.Dialer, network, address string, opts ...ConnectionOption) (c Connection, err error) { + conn, err := dialer.Dial(network, address) + if err == nil { + c, err = NewConnection(conn, opts...) 
+ } + return +} http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/b25d21e6/electron/electron_test.go ---------------------------------------------------------------------- diff --cc electron/electron_test.go index 4cd8453,0000000..74759f5 mode 100644,000000..100644 --- a/electron/electron_test.go +++ b/electron/electron_test.go @@@ -1,546 -1,0 +1,546 @@@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package electron + +import ( + "fmt" + "net" + "path" + "qpid.apache.org/amqp" + "reflect" + "runtime" + "testing" + "time" +) + +func fatalIf(t *testing.T, err error) { + if err != nil { + _, file, line, ok := runtime.Caller(1) // annotate with location of caller. + if ok { + _, file = path.Split(file) + } + t.Fatalf("(from %s:%d) %v", file, line, err) + } +} + +func errorIf(t *testing.T, err error) { + if err != nil { + _, file, line, ok := runtime.Caller(1) // annotate with location of caller. + if ok { + _, file = path.Split(file) + } + t.Errorf("(from %s:%d) %v", file, line, err) + } +} + +func checkEqual(want interface{}, got interface{}) error { + if !reflect.DeepEqual(want, got) { + return fmt.Errorf("%#v != %#v", want, got) + } + return nil +} + +// Start a server, return listening addr and channel for incoming Connections. 
+func newServer(t *testing.T, cont Container, opts ...ConnectionOption) (net.Addr, <-chan Connection) { + listener, err := net.Listen("tcp4", "") // For systems with ipv6 disabled + fatalIf(t, err) + addr := listener.Addr() + ch := make(chan Connection) + go func() { + conn, err := listener.Accept() + c, err := cont.Connection(conn, append([]ConnectionOption{Server()}, opts...)...) + fatalIf(t, err) + ch <- c + }() + return addr, ch +} + +// Open a client connection and session, return the session. +func newClient(t *testing.T, cont Container, addr net.Addr, opts ...ConnectionOption) Session { + conn, err := net.Dial(addr.Network(), addr.String()) + fatalIf(t, err) - c, err := cont.Connection(conn, opts...) - fatalIf(t, err) - sn, err := c.Session() - fatalIf(t, err) ++ // Don't bother checking error here, it's an async error so it's racy to do so anyway. ++ // Let caller use Sync() or catch it on first use. ++ c, _ := cont.Connection(conn, opts...) ++ sn, _ := c.Session() + return sn +} + +// Return client and server ends of the same connection. +func newClientServerOpts(t *testing.T, copts []ConnectionOption, sopts []ConnectionOption) (client Session, server Connection) { + addr, ch := newServer(t, NewContainer("test-server"), sopts...) + client = newClient(t, NewContainer("test-client"), addr, copts...) + return client, <-ch +} + +// Return client and server ends of the same connection. +func newClientServer(t *testing.T) (client Session, server Connection) { + return newClientServerOpts(t, nil, nil) +} + +// Close client and server +func closeClientServer(client Session, server Connection) { + client.Connection().Close(nil) + server.Close(nil) +} + +// Send a message one way with a client sender and server receiver, verify ack. 
+func TestClientSendServerReceive(t *testing.T) { + nLinks := 3 + nMessages := 3 + + rchan := make(chan Receiver, nLinks) + client, server := newClientServer(t) + go func() { + for in := range server.Incoming() { + switch in := in.(type) { + case *IncomingReceiver: + in.SetCapacity(1) + in.SetPrefetch(false) + rchan <- in.Accept().(Receiver) + default: + in.Accept() + } + } + }() + + defer func() { closeClientServer(client, server) }() + + s := make([]Sender, nLinks) + for i := 0; i < nLinks; i++ { + var err error + s[i], err = client.Sender(Target(fmt.Sprintf("foo%d", i))) + if err != nil { + t.Fatal(err) + } + } + r := make([]Receiver, nLinks) + for i := 0; i < nLinks; i++ { + r[i] = <-rchan + } + + for i := 0; i < nLinks; i++ { + for j := 0; j < nMessages; j++ { + // Client send + ack := make(chan Outcome, 1) + sendDone := make(chan struct{}) + go func() { + defer close(sendDone) + m := amqp.NewMessageWith(fmt.Sprintf("foobar%v-%v", i, j)) + var err error + s[i].SendAsync(m, ack, "testing") + if err != nil { + t.Fatal(err) + } + }() + - // Server recieve ++ // Server receive + rm, err := r[i].Receive() + if err != nil { + t.Fatal(err) + } + if want, got := interface{}(fmt.Sprintf("foobar%v-%v", i, j)), rm.Message.Body(); want != got { + t.Errorf("%#v != %#v", want, got) + } + + // Should not be acknowledged on client yet + <-sendDone + select { + case <-ack: + t.Errorf("unexpected ack") + default: + } + + // Server send ack + if err := rm.Reject(); err != nil { + t.Error(err) + } + // Client get ack. 
+ if a := <-ack; a.Value != "testing" || a.Error != nil || a.Status != Rejected { + t.Error("unexpected ack: ", a.Status, a.Error, a.Value) + } + } + } +} + +func TestClientReceiver(t *testing.T) { + nMessages := 3 + client, server := newClientServer(t) + go func() { + for in := range server.Incoming() { + switch in := in.(type) { + case *IncomingSender: + s := in.Accept().(Sender) + go func() { + for i := int32(0); i < int32(nMessages); i++ { + out := s.SendSync(amqp.NewMessageWith(i)) + if out.Error != nil { + t.Error(out.Error) + return + } + } + s.Close(nil) + }() + default: + in.Accept() + } + } + }() + + r, err := client.Receiver(Source("foo")) + if err != nil { + t.Fatal(err) + } + for i := int32(0); i < int32(nMessages); i++ { + rm, err := r.Receive() + if err != nil { + if err != Closed { + t.Error(err) + } + break + } + if err := rm.Accept(); err != nil { + t.Error(err) + } + if b, ok := rm.Message.Body().(int32); !ok || b != i { + t.Errorf("want %v, true got %v, %v", i, b, ok) + } + } + server.Close(nil) + client.Connection().Close(nil) +} + +// Test timeout versions of waiting functions. +func TestTimeouts(t *testing.T) { + var err error + rchan := make(chan Receiver, 1) + client, server := newClientServer(t) + go func() { + for i := range server.Incoming() { + switch i := i.(type) { + case *IncomingReceiver: + i.SetCapacity(1) + i.SetPrefetch(false) + rchan <- i.Accept().(Receiver) // Issue credit only on receive + default: + i.Accept() + } + } + }() + defer func() { closeClientServer(client, server) }() + + // Open client sender + snd, err := client.Sender(Target("test")) + if err != nil { + t.Fatal(err) + } + rcv := <-rchan + + // Test send with timeout + short := time.Millisecond + long := time.Second + m := amqp.NewMessage() + if err := snd.SendSyncTimeout(m, 0).Error; err != Timeout { // No credit, expect timeout. + t.Error("want Timeout got", err) + } + if err := snd.SendSyncTimeout(m, short).Error; err != Timeout { // No credit, expect timeout. 
+ t.Error("want Timeout got", err) + } + // Test receive with timeout + if _, err = rcv.ReceiveTimeout(0); err != Timeout { // No credit, expect timeout. + t.Error("want Timeout got", err) + } + // Test receive with timeout + if _, err = rcv.ReceiveTimeout(short); err != Timeout { // No credit, expect timeout. + t.Error("want Timeout got", err) + } + // There is now a credit on the link due to receive + ack := make(chan Outcome) + snd.SendAsyncTimeout(m, ack, nil, short) + // Disposition should timeout + select { + case <-ack: + t.Errorf("want Timeout got %#v", ack) + case <-time.After(short): + } + + // Receive and accept + rm, err := rcv.ReceiveTimeout(long) + if err != nil { + t.Fatal(err) + } + if err := rm.Accept(); err != nil { + t.Fatal(err) + } + // Sender get ack + if a := <-ack; a.Status != Accepted || a.Error != nil { + t.Errorf("want (accepted, nil) got %#v", a) + } +} + +// A server that returns the opposite end of each client link via channels. +type pairs struct { + t *testing.T + client Session + server Connection + rchan chan Receiver + schan chan Sender + capacity int + prefetch bool +} + +func newPairs(t *testing.T, capacity int, prefetch bool) *pairs { + p := &pairs{t: t, rchan: make(chan Receiver, 1), schan: make(chan Sender, 1)} + p.client, p.server = newClientServer(t) + go func() { + for i := range p.server.Incoming() { + switch i := i.(type) { + case *IncomingReceiver: + i.SetCapacity(capacity) + i.SetPrefetch(prefetch) + p.rchan <- i.Accept().(Receiver) + case *IncomingSender: + p.schan <- i.Accept().(Sender) + default: + i.Accept() + } + } + }() + return p +} + +func (p *pairs) close() { + closeClientServer(p.client, p.server) +} + +// Return a client sender and server receiver +func (p *pairs) senderReceiver() (Sender, Receiver) { + snd, err := p.client.Sender() + fatalIf(p.t, err) + rcv := <-p.rchan + return snd, rcv +} + +// Return a client receiver and server sender +func (p *pairs) receiverSender() (Receiver, Sender) { + rcv, err := 
p.client.Receiver() + fatalIf(p.t, err) + snd := <-p.schan + return rcv, snd +} + +type result struct { + label string + err error + value interface{} +} + +func (r result) String() string { return fmt.Sprintf("%v(%v)", r.err, r.label) } + +func doSend(snd Sender, results chan result) { + err := snd.SendSync(amqp.NewMessage()).Error + results <- result{"send", err, nil} +} + +func doReceive(rcv Receiver, results chan result) { + msg, err := rcv.Receive() + results <- result{"receive", err, msg} +} + +func doDisposition(ack <-chan Outcome, results chan result) { + results <- result{"disposition", (<-ack).Error, nil} +} + +// Senders get credit immediately if receivers have prefetch set +func TestSendReceivePrefetch(t *testing.T) { + pairs := newPairs(t, 1, true) + s, r := pairs.senderReceiver() + s.SendAsyncTimeout(amqp.NewMessage(), nil, nil, time.Second) // Should not block for credit. + if _, err := r.Receive(); err != nil { + t.Error(err) + } +} + +// Senders do not get credit till Receive() if receivers don't have prefetch +func TestSendReceiveNoPrefetch(t *testing.T) { + pairs := newPairs(t, 1, false) + s, r := pairs.senderReceiver() + done := make(chan struct{}, 1) + go func() { + s.SendAsyncTimeout(amqp.NewMessage(), nil, nil, time.Second) // Should block for credit. + close(done) + }() + select { + case <-done: + t.Errorf("send should be blocked on credit") + default: + if _, err := r.Receive(); err != nil { + t.Error(err) + } else { + <-done + } // Should be unblocked now + } +} + +// Test that closing Links interrupts blocked link functions. +func TestLinkCloseInterrupt(t *testing.T) { + want := amqp.Error{Name: "x", Description: "all bad"} + pairs := newPairs(t, 1, false) + results := make(chan result) // Collect expected errors + + // Note closing the link does not interrupt Send() calls, the AMQP spec says + // that deliveries can be settled after the link is closed. 
+ + // Receiver.Close() interrupts Receive() + snd, rcv := pairs.senderReceiver() + go doReceive(rcv, results) + rcv.Close(want) + if r := <-results; want != r.err { + t.Errorf("want %#v got %#v", want, r) + } + + // Remote Sender.Close() interrupts Receive() + snd, rcv = pairs.senderReceiver() + go doReceive(rcv, results) + snd.Close(want) + if r := <-results; want != r.err { + t.Errorf("want %#v got %#v", want, r) + } +} + +// Test closing the server end of a connection. +func TestConnectionCloseInterrupt1(t *testing.T) { + want := amqp.Error{Name: "x", Description: "bad"} + pairs := newPairs(t, 1, true) + results := make(chan result) // Collect expected errors + + // Connection.Close() interrupts Send, Receive, Disposition. + snd, rcv := pairs.senderReceiver() + go doSend(snd, results) + + if _, err := rcv.Receive(); err != nil { + t.Error("receive", err) + } + rcv, snd = pairs.receiverSender() + go doReceive(rcv, results) + + snd, rcv = pairs.senderReceiver() + ack := snd.SendWaitable(amqp.NewMessage()) + if _, err := rcv.Receive(); err != nil { + t.Error("receive", err) + } + go doDisposition(ack, results) + + pairs.server.Close(want) + for i := 0; i < 3; i++ { + if r := <-results; want != r.err { + t.Errorf("want %v got %v", want, r) + } + } +} + +// Test closing the client end of the connection. +func TestConnectionCloseInterrupt2(t *testing.T) { + want := amqp.Error{Name: "x", Description: "bad"} + pairs := newPairs(t, 1, true) + results := make(chan result) // Collect expected errors + + // Connection.Close() interrupts Send, Receive, Disposition. 
+ snd, rcv := pairs.senderReceiver() + go doSend(snd, results) + if _, err := rcv.Receive(); err != nil { + t.Error("receive", err) + } + + rcv, snd = pairs.receiverSender() + go doReceive(rcv, results) + + snd, rcv = pairs.senderReceiver() + ack := snd.SendWaitable(amqp.NewMessage()) + go doDisposition(ack, results) + + pairs.client.Connection().Close(want) + for i := 0; i < 3; i++ { + if r := <-results; want != r.err { + t.Errorf("want %v got %v", want, r.err) + } + } +} + +func heartbeat(c Connection) time.Duration { + return c.(*connection).engine.Transport().RemoteIdleTimeout() +} + +func TestHeartbeat(t *testing.T) { + client, server := newClientServerOpts(t, + []ConnectionOption{Heartbeat(102 * time.Millisecond)}, + nil) + defer closeClientServer(client, server) + + var serverHeartbeat time.Duration + + go func() { + for in := range server.Incoming() { + switch in := in.(type) { + case *IncomingConnection: + serverHeartbeat = in.Heartbeat() + in.AcceptConnection(Heartbeat(101 * time.Millisecond)) + default: + in.Accept() + } + } + }() + + // Freeze the server to stop it sending heartbeats. + unfreeze := make(chan bool) + defer close(unfreeze) + freeze := func() error { return server.(*connection).engine.Inject(func() { <-unfreeze }) } + + fatalIf(t, client.Sync()) + errorIf(t, checkEqual(101*time.Millisecond, heartbeat(client.Connection()))) + errorIf(t, checkEqual(102*time.Millisecond, serverHeartbeat)) + errorIf(t, client.Connection().Error()) + + // Freeze the server for less than a heartbeat + fatalIf(t, freeze()) + time.Sleep(50 * time.Millisecond) + unfreeze <- true + // Make sure server is still responding. 
+ s, err := client.Sender() + errorIf(t, err) + errorIf(t, s.Sync()) + + // Freeze the server till the client times out the connection + fatalIf(t, freeze()) + select { + case <-client.Done(): + if amqp.ResourceLimitExceeded != client.Error().(amqp.Error).Name { + t.Error("bad timeout error:", client.Error()) + } + case <-time.After(400 * time.Millisecond): + t.Error("connection failed to time out") + } + + unfreeze <- true // Unfreeze the server + <-server.Done() + if amqp.ResourceLimitExceeded != server.Error().(amqp.Error).Name { + t.Error("bad timeout error:", server.Error()) + } +} http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/b25d21e6/electron/receiver.go ---------------------------------------------------------------------- diff --cc electron/receiver.go index 781fd7c,0000000..26b46a8 mode 100644,000000..100644 --- a/electron/receiver.go +++ b/electron/receiver.go @@@ -1,236 -1,0 +1,236 @@@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package electron + +import ( + "fmt" + "qpid.apache.org/amqp" + "qpid.apache.org/proton" + "time" +) + +// Receiver is a Link that receives messages. 
+// +type Receiver interface { + Endpoint + LinkSettings + + // Receive blocks until a message is available or until the Receiver is closed + // and has no more buffered messages. + Receive() (ReceivedMessage, error) + + // ReceiveTimeout is like Receive but gives up after timeout, see Timeout. + // + // Note that that if Prefetch is false, after a Timeout the credit issued by + // Receive remains on the link. It will be used by the next call to Receive. + ReceiveTimeout(timeout time.Duration) (ReceivedMessage, error) + + // Prefetch==true means the Receiver will automatically issue credit to the + // remote sender to keep its buffer as full as possible, i.e. it will + // "pre-fetch" messages independently of the application calling + // Receive(). This gives good throughput for applications that handle a + // continuous stream of messages. Larger capacity may improve throughput, the + // optimal value depends on the characteristics of your application. + // + // Prefetch==false means the Receiver will issue only issue credit when you + // call Receive(), and will only issue enough credit to satisfy the calls + // actually made. This gives lower throughput but will not fetch any messages + // in advance. It is good for synchronous applications that need to evaluate + // each message before deciding whether to receive another. The + // request-response pattern is a typical example. If you make concurrent + // calls to Receive with pre-fetch disabled, you can improve performance by + // setting the capacity close to the expected number of concurrent calls. 
+ // + Prefetch() bool + + // Capacity is the size (number of messages) of the local message buffer + // These are messages received but not yet returned to the application by a call to Receive() + Capacity() int +} + +// Receiver implementation +type receiver struct { + link + buffer chan ReceivedMessage + callers int +} + +func (r *receiver) Capacity() int { return cap(r.buffer) } +func (r *receiver) Prefetch() bool { return r.prefetch } + +// Call in proton goroutine +func newReceiver(ls linkSettings) *receiver { + r := &receiver{link: link{linkSettings: ls}} + r.endpoint.init(r.link.pLink.String()) + if r.capacity < 1 { + r.capacity = 1 + } + r.buffer = make(chan ReceivedMessage, r.capacity) + r.handler().addLink(r.pLink, r) + r.link.pLink.Open() + if r.prefetch { + r.flow(r.maxFlow()) + } + return r +} + - // Call in proton gorotine. Max additional credit we can request. ++// Call in proton goroutine. Max additional credit we can request. +func (r *receiver) maxFlow() int { return cap(r.buffer) - len(r.buffer) - r.pLink.Credit() } + +func (r *receiver) flow(credit int) { + if credit > 0 { + r.pLink.Flow(credit) + } +} + +// Inject flow check per-caller call when prefetch is off. 
+// Called with inc=1 at start of call, inc = -1 at end +func (r *receiver) caller(inc int) { + _ = r.engine().Inject(func() { + r.callers += inc + need := r.callers - (len(r.buffer) + r.pLink.Credit()) + max := r.maxFlow() + if need > max { + need = max + } + r.flow(need) + }) +} + +// Inject flow top-up if prefetch is enabled +func (r *receiver) flowTopUp() { + if r.prefetch { + _ = r.engine().Inject(func() { r.flow(r.maxFlow()) }) + } +} + +func (r *receiver) Receive() (rm ReceivedMessage, err error) { + return r.ReceiveTimeout(Forever) +} + +func (r *receiver) ReceiveTimeout(timeout time.Duration) (rm ReceivedMessage, err error) { + assert(r.buffer != nil, "Receiver is not open: %s", r) + if !r.prefetch { // Per-caller flow control + select { // Check for immediate availability, avoid caller() inject + case rm2, ok := <-r.buffer: + if ok { + rm = rm2 + } else { + err = r.Error() + } + return + default: // Not immediately available, inject caller() counts + r.caller(+1) + defer r.caller(-1) + } + } + rmi, err := timedReceive(r.buffer, timeout) + switch err { + case nil: + r.flowTopUp() + rm = rmi.(ReceivedMessage) + case Closed: + err = r.Error() + } + return +} + +// Called in proton goroutine on MMessage event. +func (r *receiver) message(delivery proton.Delivery) { + if r.pLink.State().RemoteClosed() { + localClose(r.pLink, r.pLink.RemoteCondition().Error()) + return + } + if delivery.HasMessage() { + m, err := delivery.Message() + if err != nil { + localClose(r.pLink, err) + return + } + assert(m != nil) + r.pLink.Advance() + if r.pLink.Credit() < 0 { + localClose(r.pLink, fmt.Errorf("received message in excess of credit limit")) + } else { + // We never issue more credit than cap(buffer) so this will not block. 
+ r.buffer <- ReceivedMessage{m, delivery, r} + } + } +} + +func (r *receiver) closed(err error) error { + e := r.link.closed(err) + if r.buffer != nil { + close(r.buffer) + } + return e +} + +// ReceivedMessage contains an amqp.Message and allows the message to be acknowledged. +type ReceivedMessage struct { + // Message is the received message. + Message amqp.Message + + pDelivery proton.Delivery + receiver Receiver +} + +// Acknowledge a ReceivedMessage with the given delivery status. +func (rm *ReceivedMessage) acknowledge(status uint64) error { + return rm.receiver.(*receiver).engine().Inject(func() { + // Deliveries are valid as long as the connection is, unless settled. + rm.pDelivery.SettleAs(uint64(status)) + }) +} + +// Accept tells the sender that we take responsibility for processing the message. +func (rm *ReceivedMessage) Accept() error { return rm.acknowledge(proton.Accepted) } + +// Reject tells the sender we consider the message invalid and unusable. +func (rm *ReceivedMessage) Reject() error { return rm.acknowledge(proton.Rejected) } + +// Release tells the sender we will not process the message but some other +// receiver might. +func (rm *ReceivedMessage) Release() error { return rm.acknowledge(proton.Released) } + +// IncomingReceiver is sent on the Connection.Incoming() channel when there is +// an incoming request to open a receiver link. 
+type IncomingReceiver struct { + incoming + linkSettings +} + +func newIncomingReceiver(sn *session, pLink proton.Link) *IncomingReceiver { + return &IncomingReceiver{ + incoming: makeIncoming(pLink), + linkSettings: makeIncomingLinkSettings(pLink, sn), + } +} + +// SetCapacity sets the capacity of the incoming receiver, call before Accept() +func (in *IncomingReceiver) SetCapacity(capacity int) { in.capacity = capacity } + +// SetPrefetch sets the pre-fetch mode of the incoming receiver, call before Accept() +func (in *IncomingReceiver) SetPrefetch(prefetch bool) { in.prefetch = prefetch } + +// Accept accepts an incoming receiver endpoint +func (in *IncomingReceiver) Accept() Endpoint { + return in.accept(func() Endpoint { return newReceiver(in.linkSettings) }) +} http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/b25d21e6/electron/time.go ---------------------------------------------------------------------- diff --cc electron/time.go index 51bfbc5,0000000..52f0cee mode 100644,000000..100644 --- a/electron/time.go +++ b/electron/time.go @@@ -1,83 -1,0 +1,83 @@@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. 
+*/ + +package electron + +import ( + "fmt" + "math" + "reflect" + "time" +) + +// Timeout is the error returned if an operation does not complete on time. +// +// Methods named *Timeout in this package take time.Duration timeout parameter. +// +// If timeout > 0 and there is no result available before the timeout, they +// return a zero or nil value and Timeout as an error. +// - // If timeout == 0 they will return a result if one is immediatley available or ++// If timeout == 0 they will return a result if one is immediately available or +// nil/zero and Timeout as an error if not. +// +// If timeout == Forever the function will return only when there is a result or +// some non-timeout error occurs. +// +var Timeout = fmt.Errorf("timeout") + +// Forever can be used as a timeout parameter to indicate wait forever. +const Forever time.Duration = math.MaxInt64 + +// timedReceive receives on channel (which can be a chan of any type), waiting +// up to timeout. +// +// timeout==0 means do a non-blocking receive attempt. timeout < 0 means block +// forever. Other values mean block up to the timeout. +// +// Returns error Timeout on timeout, Closed on channel close. +func timedReceive(channel interface{}, timeout time.Duration) (interface{}, error) { + cases := []reflect.SelectCase{ + reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(channel)}, + } + if timeout == 0 { // Non-blocking + cases = append(cases, reflect.SelectCase{Dir: reflect.SelectDefault}) + } else { // Block up to timeout + cases = append(cases, + reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(After(timeout))}) + } + chosen, value, ok := reflect.Select(cases) + switch { + case chosen == 0 && ok: + return value.Interface(), nil + case chosen == 0 && !ok: + return nil, Closed + default: + return nil, Timeout + } +} + +// After is like time.After but returns a nil channel if timeout == Forever +// since selecting on a nil channel will never return. 
+func After(timeout time.Duration) <-chan time.Time { + if timeout == Forever { + return nil + } else { + return time.After(timeout) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
