You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@dubbo.apache.org by li...@apache.org on 2024/04/26 01:42:19 UTC
(dubbo-go) branch main updated: fix: Server return with Attachment (#2648)
This is an automated email from the ASF dual-hosted git repository.
liujun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git
The following commit(s) were added to refs/heads/main by this push:
new 72f80ae88 fix: Server return with Attachment (#2648)
72f80ae88 is described below
commit 72f80ae88cb409a63103d0ec7b6d250a5c9f5863
Author: YarBor <11...@users.noreply.github.com>
AuthorDate: Fri Apr 26 09:42:14 2024 +0800
fix: Server return with Attachment (#2648)
---
client/client.go | 1 +
protocol/triple/triple_invoker.go | 59 +++++++++--------
protocol/triple/triple_invoker_test.go | 22 ++++---
.../triple/triple_protocol/duplex_http_call.go | 4 ++
protocol/triple/triple_protocol/handler.go | 28 ++++++--
protocol/triple/triple_protocol/header.go | 74 ++++++++++++++++------
6 files changed, 128 insertions(+), 60 deletions(-)
diff --git a/client/client.go b/client/client.go
index 98ab66733..a2bcaef1e 100644
--- a/client/client.go
+++ b/client/client.go
@@ -122,6 +122,7 @@ func (cli *Client) dial(interfaceName string, info *ClientInfo, opts ...Referenc
return &Connection{refOpts: newRefOpts}, nil
}
+
func generateInvocation(methodName string, reqs []interface{}, resp interface{}, callType string, opts *CallOptions) (protocol.Invocation, error) {
var paramsRawVals []interface{}
for _, req := range reqs {
diff --git a/protocol/triple/triple_invoker.go b/protocol/triple/triple_invoker.go
index e08778429..4c1eb6868 100644
--- a/protocol/triple/triple_invoker.go
+++ b/protocol/triple/triple_invoker.go
@@ -81,11 +81,18 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocat
return &result
}
- ctx, callType, inRaw, method, err := parseInvocation(ctx, ti.GetURL(), invocation)
+ callType, inRaw, method, err := parseInvocation(ctx, ti.GetURL(), invocation)
if err != nil {
result.SetError(err)
return &result
}
+
+ ctx, err = mergeAttachmentToOutgoing(ctx, invocation)
+ if err != nil {
+ result.SetError(err)
+ return &result
+ }
+
inRawLen := len(inRaw)
if !ti.clientManager.isIDL {
@@ -136,16 +143,33 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocat
return &result
}
+func mergeAttachmentToOutgoing(ctx context.Context, inv protocol.Invocation) (context.Context, error) {
+ for key, valRaw := range inv.Attachments() {
+ if str, ok := valRaw.(string); ok {
+ ctx = tri.AppendToOutgoingContext(ctx, key, str)
+ continue
+ }
+ if strs, ok := valRaw.([]string); ok {
+ for _, str := range strs {
+ ctx = tri.AppendToOutgoingContext(ctx, key, str)
+ }
+ continue
+ }
+ return ctx, fmt.Errorf("triple attachments value with key = %s is invalid, which should be string or []string", key)
+ }
+ return ctx, nil
+}
+
// parseInvocation retrieves information from invocation.
// it returns ctx, callType, inRaw, method, error
-func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.Invocation) (context.Context, string, []interface{}, string, error) {
+func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.Invocation) (string, []interface{}, string, error) {
callTypeRaw, ok := invocation.GetAttribute(constant.CallTypeKey)
if !ok {
- return nil, "", nil, "", errors.New("miss CallType in invocation to invoke TripleInvoker")
+ return "", nil, "", errors.New("miss CallType in invocation to invoke TripleInvoker")
}
callType, ok := callTypeRaw.(string)
if !ok {
- return nil, "", nil, "", fmt.Errorf("CallType should be string, but got %v", callTypeRaw)
+ return "", nil, "", fmt.Errorf("CallType should be string, but got %v", callTypeRaw)
}
// please refer to methods of client.Client or code generated by new triple for the usage of inRaw and inRawLen
// e.g. Client.CallUnary(... req, resp []interface, ...)
@@ -153,19 +177,16 @@ func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.I
inRaw := invocation.ParameterRawValues()
method := invocation.MethodName()
if method == "" {
- return nil, "", nil, "", errors.New("miss MethodName in invocation to invoke TripleInvoker")
+ return "", nil, "", errors.New("miss MethodName in invocation to invoke TripleInvoker")
}
- ctx, err := parseAttachments(ctx, url, invocation)
- if err != nil {
- return nil, "", nil, "", err
- }
+ parseAttachments(ctx, url, invocation)
- return ctx, callType, inRaw, method, nil
+ return callType, inRaw, method, nil
}
// parseAttachments retrieves attachments from users passed-in and URL, then injects them into ctx
-func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.Invocation) (context.Context, error) {
+func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.Invocation) {
// retrieve users passed-in attachment
attaRaw := ctx.Value(constant.AttachmentKey)
if attaRaw != nil {
@@ -181,22 +202,6 @@ func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.
invocation.SetAttachment(key, val)
}
}
- // inject attachments
- for key, valRaw := range invocation.Attachments() {
- if str, ok := valRaw.(string); ok {
- ctx = tri.AppendToOutgoingContext(ctx, key, str)
- continue
- }
- if strs, ok := valRaw.([]string); ok {
- for _, str := range strs {
- ctx = tri.AppendToOutgoingContext(ctx, key, str)
- }
- continue
- }
- return nil, fmt.Errorf("triple attachments value with key = %s is invalid, which should be string or []string", key)
- }
-
- return ctx, nil
}
// IsAvailable get available status
diff --git a/protocol/triple/triple_invoker_test.go b/protocol/triple/triple_invoker_test.go
index 7d14d4239..e9dcc968c 100644
--- a/protocol/triple/triple_invoker_test.go
+++ b/protocol/triple/triple_invoker_test.go
@@ -19,6 +19,7 @@ package triple
import (
"context"
+ "net/http"
"testing"
"dubbo.apache.org/dubbo-go/v3/common"
@@ -35,7 +36,7 @@ func Test_parseInvocation(t *testing.T) {
ctx func() context.Context
url *common.URL
invo func() protocol.Invocation
- expect func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error)
+ expect func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error)
}{
{
desc: "miss callType",
@@ -46,7 +47,7 @@ func Test_parseInvocation(t *testing.T) {
invo: func() protocol.Invocation {
return invocation.NewRPCInvocationWithOptions()
},
- expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) {
+ expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) {
assert.NotNil(t, err)
},
},
@@ -61,7 +62,7 @@ func Test_parseInvocation(t *testing.T) {
iv.SetAttribute(constant.CallTypeKey, 1)
return iv
},
- expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) {
+ expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) {
assert.NotNil(t, err)
},
},
@@ -76,7 +77,7 @@ func Test_parseInvocation(t *testing.T) {
iv.SetAttribute(constant.CallTypeKey, constant.CallUnary)
return iv
},
- expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) {
+ expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) {
assert.NotNil(t, err)
},
},
@@ -84,8 +85,8 @@ func Test_parseInvocation(t *testing.T) {
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
- ctx, callType, inRaw, methodName, err := parseInvocation(test.ctx(), test.url, test.invo())
- test.expect(t, ctx, callType, inRaw, methodName, err)
+ callType, inRaw, methodName, err := parseInvocation(test.ctx(), test.url, test.invo())
+ test.expect(t, callType, inRaw, methodName, err)
})
}
}
@@ -112,7 +113,7 @@ func Test_parseAttachments(t *testing.T) {
},
expect: func(t *testing.T, ctx context.Context, err error) {
assert.Nil(t, err)
- header := tri.ExtractFromOutgoingContext(ctx)
+ header := http.Header(tri.ExtractFromOutgoingContext(ctx))
assert.NotNil(t, header)
assert.Equal(t, "interface", header.Get(constant.InterfaceKey))
assert.Equal(t, "token", header.Get(constant.TokenKey))
@@ -132,7 +133,7 @@ func Test_parseAttachments(t *testing.T) {
},
expect: func(t *testing.T, ctx context.Context, err error) {
assert.Nil(t, err)
- header := tri.ExtractFromOutgoingContext(ctx)
+ header := http.Header(tri.ExtractFromOutgoingContext(ctx))
assert.NotNil(t, header)
assert.Equal(t, "val1", header.Get("key1"))
assert.Equal(t, []string{"key2_1", "key2_2"}, header.Values("key2"))
@@ -157,7 +158,10 @@ func Test_parseAttachments(t *testing.T) {
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
- ctx, err := parseAttachments(test.ctx(), test.url, test.invo())
+ ctx := test.ctx()
+ inv := test.invo()
+ parseAttachments(ctx, test.url, inv)
+ ctx, err := mergeAttachmentToOutgoing(ctx, inv)
test.expect(t, ctx, err)
})
}
diff --git a/protocol/triple/triple_protocol/duplex_http_call.go b/protocol/triple/triple_protocol/duplex_http_call.go
index 3865a56ab..6c4c02176 100644
--- a/protocol/triple/triple_protocol/duplex_http_call.go
+++ b/protocol/triple/triple_protocol/duplex_http_call.go
@@ -184,6 +184,10 @@ func (d *duplexHTTPCall) CloseRead() error {
if err := discard(d.response.Body); err != nil {
return wrapIfRSTError(err)
}
+ // Return incoming data via the context if outgoing data was set.
+ if ExtractFromOutgoingContext(d.ctx) != nil {
+ newIncomingContext(d.ctx, d.ResponseTrailer())
+ }
return wrapIfRSTError(d.response.Body.Close())
}
diff --git a/protocol/triple/triple_protocol/handler.go b/protocol/triple/triple_protocol/handler.go
index 56f83b3fe..7a44f8535 100644
--- a/protocol/triple/triple_protocol/handler.go
+++ b/protocol/triple/triple_protocol/handler.go
@@ -112,6 +112,10 @@ func generateUnaryHandlerFunc(
// merge headers
mergeHeaders(conn.ResponseHeader(), response.Header())
mergeHeaders(conn.ResponseTrailer(), response.Trailer())
+ // Write the server-side return attachments into the trailer to send them to the caller
+ if data := ExtractFromOutgoingContext(ctx); data != nil {
+ mergeHeaders(conn.ResponseTrailer(), data)
+ }
return conn.Send(response.Any())
}
@@ -160,6 +164,9 @@ func generateClientStreamHandlerFunc(
}
mergeHeaders(conn.ResponseHeader(), res.header)
mergeHeaders(conn.ResponseTrailer(), res.trailer)
+ if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil {
+ mergeHeaders(conn.ResponseTrailer(), outgoingData)
+ }
return conn.Send(res.Msg)
}
if interceptor != nil {
@@ -205,7 +212,7 @@ func generateServerStreamHandlerFunc(
}
// embed header in context so that user logic could process them via FromIncomingContext
ctx = newIncomingContext(ctx, conn.RequestHeader())
- return streamFunc(
+ err := streamFunc(
ctx,
&Request{
Msg: req,
@@ -215,6 +222,13 @@ func generateServerStreamHandlerFunc(
},
&ServerStream{conn: conn},
)
+ if err != nil {
+ return err
+ }
+ if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil {
+ mergeHeaders(conn.ResponseTrailer(), outgoingData)
+ }
+ return nil
}
if interceptor != nil {
implementation = interceptor.WrapStreamingHandler(implementation)
@@ -253,10 +267,14 @@ func generateBidiStreamHandlerFunc(
implementation := func(ctx context.Context, conn StreamingHandlerConn) error {
// embed header in context so that user logic could process them via FromIncomingContext
ctx = newIncomingContext(ctx, conn.RequestHeader())
- return streamFunc(
- ctx,
- &BidiStream{conn: conn},
- )
+ err := streamFunc(ctx, &BidiStream{conn: conn})
+ if err != nil {
+ return err
+ }
+ if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil {
+ mergeHeaders(conn.ResponseTrailer(), outgoingData)
+ }
+ return nil
}
if interceptor != nil {
implementation = interceptor.WrapStreamingHandler(implementation)
diff --git a/protocol/triple/triple_protocol/header.go b/protocol/triple/triple_protocol/header.go
index 28618e0dc..791cf4a30 100644
--- a/protocol/triple/triple_protocol/header.go
+++ b/protocol/triple/triple_protocol/header.go
@@ -19,6 +19,7 @@ import (
"encoding/base64"
"fmt"
"net/http"
+ "strings"
)
// EncodeBinaryHeader base64-encodes the data. It always emits unpadded values.
@@ -88,20 +89,45 @@ func addHeaderCanonical(h http.Header, key, value string) {
h[key] = append(h[key], value)
}
-type headerIncomingKey struct{}
-type headerOutgoingKey struct{}
+type extraDataKey struct{}
+
+const headerIncomingKey string = "headerIncomingKey"
+const headerOutgoingKey string = "headerOutgoingKey"
+
type handlerOutgoingKey struct{}
-func newIncomingContext(ctx context.Context, header http.Header) context.Context {
- return context.WithValue(ctx, headerIncomingKey{}, header)
+func newIncomingContext(ctx context.Context, data http.Header) context.Context {
+ var header = http.Header{}
+ extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
+ if !ok {
+ extraData = map[string]http.Header{}
+ }
+ if data != nil {
+ for key, vals := range data {
+ header[strings.ToLower(key)] = vals
+ }
+ }
+ extraData[headerIncomingKey] = header
+ return context.WithValue(ctx, extraDataKey{}, extraData)
}
// NewOutgoingContext sets headers entirely. If there are existing headers, they would be replaced.
// It is used for passing headers to server-side.
// It is like grpc.NewOutgoingContext.
// Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#sending-metadata.
-func NewOutgoingContext(ctx context.Context, header http.Header) context.Context {
- return context.WithValue(ctx, headerOutgoingKey{}, header)
+func NewOutgoingContext(ctx context.Context, data http.Header) context.Context {
+ var header = http.Header{}
+ if data != nil {
+ for key, vals := range data {
+ header[strings.ToLower(key)] = vals
+ }
+ }
+ extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
+ if !ok {
+ extraData = map[string]http.Header{}
+ }
+ extraData[headerOutgoingKey] = header
+ return context.WithValue(ctx, extraDataKey{}, extraData)
}
// AppendToOutgoingContext merges kv pairs from user and existing headers.
@@ -112,37 +138,47 @@ func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context
if len(kv)%2 == 1 {
panic(fmt.Sprintf("AppendToOutgoingContext got an odd number of input pairs for header: %d", len(kv)))
}
- var header http.Header
- headerRaw := ctx.Value(headerOutgoingKey{})
- if headerRaw == nil {
+ extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
+ if !ok {
+ extraData = map[string]http.Header{}
+ ctx = context.WithValue(ctx, extraDataKey{}, extraData)
+ }
+ header, ok := extraData[headerOutgoingKey]
+ if !ok {
header = make(http.Header)
- } else {
- header = headerRaw.(http.Header)
+ extraData[headerOutgoingKey] = header
}
for i := 0; i < len(kv); i += 2 {
// todo(DMwangnima): think about lowering
- header.Add(kv[i], kv[i+1])
+ header.Add(strings.ToLower(kv[i]), kv[i+1])
}
- return context.WithValue(ctx, headerOutgoingKey{}, header)
+ return ctx
}
func ExtractFromOutgoingContext(ctx context.Context) http.Header {
- headerRaw := ctx.Value(headerOutgoingKey{})
- if headerRaw == nil {
+ extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
+ if !ok {
return nil
}
- // since headerOutgoingKey is only used in triple_protocol package, we need not verify the type
- return headerRaw.(http.Header)
+ if outGoingDataHeader, ok := extraData[headerOutgoingKey]; !ok {
+ return nil
+ } else {
+ return outGoingDataHeader
+ }
}
// FromIncomingContext retrieves headers passed by client-side. It is like grpc.FromIncomingContext.
+// It must be called after AppendToOutgoingContext/NewOutgoingContext to return the current value.
// Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#receiving-metadata-1.
func FromIncomingContext(ctx context.Context) (http.Header, bool) {
- header, ok := ctx.Value(headerIncomingKey{}).(http.Header)
+ data, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
if !ok {
return nil, false
+ } else if incomingDataHeader, ok := data[headerIncomingKey]; !ok {
+ return nil, false
+ } else {
+ return incomingDataHeader, true
}
- return header, true
}
// SetHeader is used for setting response header in server-side. It is like grpc.SendHeader(ctx, header) but