You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@dubbo.apache.org by li...@apache.org on 2024/04/26 01:42:19 UTC

(dubbo-go) branch main updated: fix: Server return with Attachment (#2648)

This is an automated email from the ASF dual-hosted git repository.

liujun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git


The following commit(s) were added to refs/heads/main by this push:
     new 72f80ae88 fix: Server return with Attachment (#2648)
72f80ae88 is described below

commit 72f80ae88cb409a63103d0ec7b6d250a5c9f5863
Author: YarBor <11...@users.noreply.github.com>
AuthorDate: Fri Apr 26 09:42:14 2024 +0800

    fix: Server return with Attachment (#2648)
---
 client/client.go                                   |  1 +
 protocol/triple/triple_invoker.go                  | 59 +++++++++--------
 protocol/triple/triple_invoker_test.go             | 22 ++++---
 .../triple/triple_protocol/duplex_http_call.go     |  4 ++
 protocol/triple/triple_protocol/handler.go         | 28 ++++++--
 protocol/triple/triple_protocol/header.go          | 74 ++++++++++++++++------
 6 files changed, 128 insertions(+), 60 deletions(-)

diff --git a/client/client.go b/client/client.go
index 98ab66733..a2bcaef1e 100644
--- a/client/client.go
+++ b/client/client.go
@@ -122,6 +122,7 @@ func (cli *Client) dial(interfaceName string, info *ClientInfo, opts ...Referenc
 
 	return &Connection{refOpts: newRefOpts}, nil
 }
+
 func generateInvocation(methodName string, reqs []interface{}, resp interface{}, callType string, opts *CallOptions) (protocol.Invocation, error) {
 	var paramsRawVals []interface{}
 	for _, req := range reqs {
diff --git a/protocol/triple/triple_invoker.go b/protocol/triple/triple_invoker.go
index e08778429..4c1eb6868 100644
--- a/protocol/triple/triple_invoker.go
+++ b/protocol/triple/triple_invoker.go
@@ -81,11 +81,18 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocat
 		return &result
 	}
 
-	ctx, callType, inRaw, method, err := parseInvocation(ctx, ti.GetURL(), invocation)
+	callType, inRaw, method, err := parseInvocation(ctx, ti.GetURL(), invocation)
 	if err != nil {
 		result.SetError(err)
 		return &result
 	}
+
+	ctx, err = mergeAttachmentToOutgoing(ctx, invocation)
+	if err != nil {
+		result.SetError(err)
+		return &result
+	}
+
 	inRawLen := len(inRaw)
 
 	if !ti.clientManager.isIDL {
@@ -136,16 +143,33 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocat
 	return &result
 }
 
+func mergeAttachmentToOutgoing(ctx context.Context, inv protocol.Invocation) (context.Context, error) {
+	for key, valRaw := range inv.Attachments() {
+		if str, ok := valRaw.(string); ok {
+			ctx = tri.AppendToOutgoingContext(ctx, key, str)
+			continue
+		}
+		if strs, ok := valRaw.([]string); ok {
+			for _, str := range strs {
+				ctx = tri.AppendToOutgoingContext(ctx, key, str)
+			}
+			continue
+		}
+		return ctx, fmt.Errorf("triple attachments value with key = %s is invalid, which should be string or []string", key)
+	}
+	return ctx, nil
+}
+
 // parseInvocation retrieves information from invocation.
 // it returns ctx, callType, inRaw, method, error
-func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.Invocation) (context.Context, string, []interface{}, string, error) {
+func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.Invocation) (string, []interface{}, string, error) {
 	callTypeRaw, ok := invocation.GetAttribute(constant.CallTypeKey)
 	if !ok {
-		return nil, "", nil, "", errors.New("miss CallType in invocation to invoke TripleInvoker")
+		return "", nil, "", errors.New("miss CallType in invocation to invoke TripleInvoker")
 	}
 	callType, ok := callTypeRaw.(string)
 	if !ok {
-		return nil, "", nil, "", fmt.Errorf("CallType should be string, but got %v", callTypeRaw)
+		return "", nil, "", fmt.Errorf("CallType should be string, but got %v", callTypeRaw)
 	}
 	// please refer to methods of client.Client or code generated by new triple for the usage of inRaw and inRawLen
 	// e.g. Client.CallUnary(... req, resp []interface, ...)
@@ -153,19 +177,16 @@ func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.I
 	inRaw := invocation.ParameterRawValues()
 	method := invocation.MethodName()
 	if method == "" {
-		return nil, "", nil, "", errors.New("miss MethodName in invocation to invoke TripleInvoker")
+		return "", nil, "", errors.New("miss MethodName in invocation to invoke TripleInvoker")
 	}
 
-	ctx, err := parseAttachments(ctx, url, invocation)
-	if err != nil {
-		return nil, "", nil, "", err
-	}
+	parseAttachments(ctx, url, invocation)
 
-	return ctx, callType, inRaw, method, nil
+	return callType, inRaw, method, nil
 }
 
 // parseAttachments retrieves attachments from users passed-in and URL, then injects them into ctx
-func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.Invocation) (context.Context, error) {
+func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.Invocation) {
 	// retrieve users passed-in attachment
 	attaRaw := ctx.Value(constant.AttachmentKey)
 	if attaRaw != nil {
@@ -181,22 +202,6 @@ func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.
 			invocation.SetAttachment(key, val)
 		}
 	}
-	// inject attachments
-	for key, valRaw := range invocation.Attachments() {
-		if str, ok := valRaw.(string); ok {
-			ctx = tri.AppendToOutgoingContext(ctx, key, str)
-			continue
-		}
-		if strs, ok := valRaw.([]string); ok {
-			for _, str := range strs {
-				ctx = tri.AppendToOutgoingContext(ctx, key, str)
-			}
-			continue
-		}
-		return nil, fmt.Errorf("triple attachments value with key = %s is invalid, which should be string or []string", key)
-	}
-
-	return ctx, nil
 }
 
 // IsAvailable get available status
diff --git a/protocol/triple/triple_invoker_test.go b/protocol/triple/triple_invoker_test.go
index 7d14d4239..e9dcc968c 100644
--- a/protocol/triple/triple_invoker_test.go
+++ b/protocol/triple/triple_invoker_test.go
@@ -19,6 +19,7 @@ package triple
 
 import (
 	"context"
+	"net/http"
 	"testing"
 
 	"dubbo.apache.org/dubbo-go/v3/common"
@@ -35,7 +36,7 @@ func Test_parseInvocation(t *testing.T) {
 		ctx    func() context.Context
 		url    *common.URL
 		invo   func() protocol.Invocation
-		expect func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error)
+		expect func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error)
 	}{
 		{
 			desc: "miss callType",
@@ -46,7 +47,7 @@ func Test_parseInvocation(t *testing.T) {
 			invo: func() protocol.Invocation {
 				return invocation.NewRPCInvocationWithOptions()
 			},
-			expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) {
+			expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) {
 				assert.NotNil(t, err)
 			},
 		},
@@ -61,7 +62,7 @@ func Test_parseInvocation(t *testing.T) {
 				iv.SetAttribute(constant.CallTypeKey, 1)
 				return iv
 			},
-			expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) {
+			expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) {
 				assert.NotNil(t, err)
 			},
 		},
@@ -76,7 +77,7 @@ func Test_parseInvocation(t *testing.T) {
 				iv.SetAttribute(constant.CallTypeKey, constant.CallUnary)
 				return iv
 			},
-			expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) {
+			expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) {
 				assert.NotNil(t, err)
 			},
 		},
@@ -84,8 +85,8 @@ func Test_parseInvocation(t *testing.T) {
 
 	for _, test := range tests {
 		t.Run(test.desc, func(t *testing.T) {
-			ctx, callType, inRaw, methodName, err := parseInvocation(test.ctx(), test.url, test.invo())
-			test.expect(t, ctx, callType, inRaw, methodName, err)
+			callType, inRaw, methodName, err := parseInvocation(test.ctx(), test.url, test.invo())
+			test.expect(t, callType, inRaw, methodName, err)
 		})
 	}
 }
@@ -112,7 +113,7 @@ func Test_parseAttachments(t *testing.T) {
 			},
 			expect: func(t *testing.T, ctx context.Context, err error) {
 				assert.Nil(t, err)
-				header := tri.ExtractFromOutgoingContext(ctx)
+				header := http.Header(tri.ExtractFromOutgoingContext(ctx))
 				assert.NotNil(t, header)
 				assert.Equal(t, "interface", header.Get(constant.InterfaceKey))
 				assert.Equal(t, "token", header.Get(constant.TokenKey))
@@ -132,7 +133,7 @@ func Test_parseAttachments(t *testing.T) {
 			},
 			expect: func(t *testing.T, ctx context.Context, err error) {
 				assert.Nil(t, err)
-				header := tri.ExtractFromOutgoingContext(ctx)
+				header := http.Header(tri.ExtractFromOutgoingContext(ctx))
 				assert.NotNil(t, header)
 				assert.Equal(t, "val1", header.Get("key1"))
 				assert.Equal(t, []string{"key2_1", "key2_2"}, header.Values("key2"))
@@ -157,7 +158,10 @@ func Test_parseAttachments(t *testing.T) {
 
 	for _, test := range tests {
 		t.Run(test.desc, func(t *testing.T) {
-			ctx, err := parseAttachments(test.ctx(), test.url, test.invo())
+			ctx := test.ctx()
+			inv := test.invo()
+			parseAttachments(ctx, test.url, inv)
+			ctx, err := mergeAttachmentToOutgoing(ctx, inv)
 			test.expect(t, ctx, err)
 		})
 	}
diff --git a/protocol/triple/triple_protocol/duplex_http_call.go b/protocol/triple/triple_protocol/duplex_http_call.go
index 3865a56ab..6c4c02176 100644
--- a/protocol/triple/triple_protocol/duplex_http_call.go
+++ b/protocol/triple/triple_protocol/duplex_http_call.go
@@ -184,6 +184,10 @@ func (d *duplexHTTPCall) CloseRead() error {
 	if err := discard(d.response.Body); err != nil {
 		return wrapIfRSTError(err)
 	}
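+	// If outgoing data was set, publish the response trailer as incoming data; the returned ctx is discarded, so this relies on newIncomingContext mutating the map stored in d.ctx.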
+	if ExtractFromOutgoingContext(d.ctx) != nil {
+		newIncomingContext(d.ctx, d.ResponseTrailer())
+	}
 	return wrapIfRSTError(d.response.Body.Close())
 }
 
diff --git a/protocol/triple/triple_protocol/handler.go b/protocol/triple/triple_protocol/handler.go
index 56f83b3fe..7a44f8535 100644
--- a/protocol/triple/triple_protocol/handler.go
+++ b/protocol/triple/triple_protocol/handler.go
@@ -112,6 +112,10 @@ func generateUnaryHandlerFunc(
 		// merge headers
 		mergeHeaders(conn.ResponseHeader(), response.Header())
 		mergeHeaders(conn.ResponseTrailer(), response.Trailer())
+		// Write the server-side return attachments into the response trailer so they are sent to the caller.
+		if data := ExtractFromOutgoingContext(ctx); data != nil {
+			mergeHeaders(conn.ResponseTrailer(), data)
+		}
 		return conn.Send(response.Any())
 	}
 
@@ -160,6 +164,9 @@ func generateClientStreamHandlerFunc(
 		}
 		mergeHeaders(conn.ResponseHeader(), res.header)
 		mergeHeaders(conn.ResponseTrailer(), res.trailer)
+		if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil {
+			mergeHeaders(conn.ResponseTrailer(), outgoingData)
+		}
 		return conn.Send(res.Msg)
 	}
 	if interceptor != nil {
@@ -205,7 +212,7 @@ func generateServerStreamHandlerFunc(
 		}
 		// embed header in context so that user logic could process them via FromIncomingContext
 		ctx = newIncomingContext(ctx, conn.RequestHeader())
-		return streamFunc(
+		err := streamFunc(
 			ctx,
 			&Request{
 				Msg:    req,
@@ -215,6 +222,13 @@ func generateServerStreamHandlerFunc(
 			},
 			&ServerStream{conn: conn},
 		)
+		if err != nil {
+			return err
+		}
+		if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil {
+			mergeHeaders(conn.ResponseTrailer(), outgoingData)
+		}
+		return nil
 	}
 	if interceptor != nil {
 		implementation = interceptor.WrapStreamingHandler(implementation)
@@ -253,10 +267,14 @@ func generateBidiStreamHandlerFunc(
 	implementation := func(ctx context.Context, conn StreamingHandlerConn) error {
 		// embed header in context so that user logic could process them via FromIncomingContext
 		ctx = newIncomingContext(ctx, conn.RequestHeader())
-		return streamFunc(
-			ctx,
-			&BidiStream{conn: conn},
-		)
+		err := streamFunc(ctx, &BidiStream{conn: conn})
+		if err != nil {
+			return err
+		}
+		if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil {
+			mergeHeaders(conn.ResponseTrailer(), outgoingData)
+		}
+		return nil
 	}
 	if interceptor != nil {
 		implementation = interceptor.WrapStreamingHandler(implementation)
diff --git a/protocol/triple/triple_protocol/header.go b/protocol/triple/triple_protocol/header.go
index 28618e0dc..791cf4a30 100644
--- a/protocol/triple/triple_protocol/header.go
+++ b/protocol/triple/triple_protocol/header.go
@@ -19,6 +19,7 @@ import (
 	"encoding/base64"
 	"fmt"
 	"net/http"
+	"strings"
 )
 
 // EncodeBinaryHeader base64-encodes the data. It always emits unpadded values.
@@ -88,20 +89,45 @@ func addHeaderCanonical(h http.Header, key, value string) {
 	h[key] = append(h[key], value)
 }
 
-type headerIncomingKey struct{}
-type headerOutgoingKey struct{}
+type extraDataKey struct{}
+
+const headerIncomingKey string = "headerIncomingKey"
+const headerOutgoingKey string = "headerOutgoingKey"
+
 type handlerOutgoingKey struct{}
 
-func newIncomingContext(ctx context.Context, header http.Header) context.Context {
-	return context.WithValue(ctx, headerIncomingKey{}, header)
+func newIncomingContext(ctx context.Context, data http.Header) context.Context {
+	var header = http.Header{}
+	extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
+	if !ok {
+		extraData = map[string]http.Header{}
+	}
+	if data != nil {
+		for key, vals := range data {
+			header[strings.ToLower(key)] = vals
+		}
+	}
+	extraData[headerIncomingKey] = header
+	return context.WithValue(ctx, extraDataKey{}, extraData)
 }
 
 // NewOutgoingContext sets headers entirely. If there are existing headers, they would be replaced.
 // It is used for passing headers to server-side.
 // It is like grpc.NewOutgoingContext.
 // Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#sending-metadata.
-func NewOutgoingContext(ctx context.Context, header http.Header) context.Context {
-	return context.WithValue(ctx, headerOutgoingKey{}, header)
+func NewOutgoingContext(ctx context.Context, data http.Header) context.Context {
+	var header = http.Header{}
+	if data != nil {
+		for key, vals := range data {
+			header[strings.ToLower(key)] = vals
+		}
+	}
+	extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
+	if !ok {
+		extraData = map[string]http.Header{}
+	}
+	extraData[headerOutgoingKey] = header
+	return context.WithValue(ctx, extraDataKey{}, extraData)
 }
 
 // AppendToOutgoingContext merges kv pairs from user and existing headers.
@@ -112,37 +138,47 @@ func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context
 	if len(kv)%2 == 1 {
 		panic(fmt.Sprintf("AppendToOutgoingContext got an odd number of input pairs for header: %d", len(kv)))
 	}
-	var header http.Header
-	headerRaw := ctx.Value(headerOutgoingKey{})
-	if headerRaw == nil {
+	extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
+	if !ok {
+		extraData = map[string]http.Header{}
+		ctx = context.WithValue(ctx, extraDataKey{}, extraData)
+	}
+	header, ok := extraData[headerOutgoingKey]
+	if !ok {
 		header = make(http.Header)
-	} else {
-		header = headerRaw.(http.Header)
+		extraData[headerOutgoingKey] = header
 	}
 	for i := 0; i < len(kv); i += 2 {
 		// todo(DMwangnima): think about lowering
-		header.Add(kv[i], kv[i+1])
+		header.Add(strings.ToLower(kv[i]), kv[i+1])
 	}
-	return context.WithValue(ctx, headerOutgoingKey{}, header)
+	return ctx
 }
 
 func ExtractFromOutgoingContext(ctx context.Context) http.Header {
-	headerRaw := ctx.Value(headerOutgoingKey{})
-	if headerRaw == nil {
+	extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
+	if !ok {
 		return nil
 	}
-	// since headerOutgoingKey is only used in triple_protocol package, we need not verify the type
-	return headerRaw.(http.Header)
+	if outGoingDataHeader, ok := extraData[headerOutgoingKey]; !ok {
+		return nil
+	} else {
+		return outGoingDataHeader
+	}
 }
 
 // FromIncomingContext retrieves headers passed by client-side. It is like grpc.FromIncomingContext.
+// It must be called after AppendToOutgoingContext/NewOutgoingContext to return the current value.
 // Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#receiving-metadata-1.
 func FromIncomingContext(ctx context.Context) (http.Header, bool) {
-	header, ok := ctx.Value(headerIncomingKey{}).(http.Header)
+	data, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
 	if !ok {
 		return nil, false
+	} else if incomingDataHeader, ok := data[headerIncomingKey]; !ok {
+		return nil, false
+	} else {
+		return incomingDataHeader, true
 	}
-	return header, true
 }
 
 // SetHeader is used for setting response header in server-side. It is like grpc.SendHeader(ctx, header) but