You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ze...@apache.org on 2023/06/14 19:49:47 UTC

[arrow] branch main updated: GH-36070: [Go][Flight] Add Flight Client Cookie Middleware (#36071)

This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new b4ac585ecb GH-36070: [Go][Flight] Add Flight Client Cookie Middleware (#36071)
b4ac585ecb is described below

commit b4ac585ecb4da610cc64e346e564ca86594aec53
Author: Matt Topol <zo...@gmail.com>
AuthorDate: Wed Jun 14 15:49:38 2023 -0400

    GH-36070: [Go][Flight] Add Flight Client Cookie Middleware (#36071)
    
    
    
    ### Rationale for this change
    See https://github.com/apache/arrow-adbc/issues/716
    
    ### What changes are included in this PR?
    A `NewClientCookieMiddleware` function is added to the Flight package; it returns a `ClientMiddleware` that can be used with Flight and Flight SQL clients.
    
    ### Are these changes tested?
    Yes.
    
    ### Are there any user-facing changes?
    No.
    
    * Closes: #36070
    
    Authored-by: Matt Topol <zo...@gmail.com>
    Signed-off-by: Matt Topol <zo...@gmail.com>
---
 go/arrow/flight/cookie_middleware.go      | 122 +++++++++++++++
 go/arrow/flight/cookie_middleware_test.go | 241 ++++++++++++++++++++++++++++++
 2 files changed, 363 insertions(+)

diff --git a/go/arrow/flight/cookie_middleware.go b/go/arrow/flight/cookie_middleware.go
new file mode 100644
index 0000000000..27754a13b8
--- /dev/null
+++ b/go/arrow/flight/cookie_middleware.go
@@ -0,0 +1,122 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flight
+
+import (
+	"context"
+	"net/http"
+	"strings"
+	"sync"
+	"time"
+
+	"google.golang.org/grpc/metadata"
+)
+
+// endOfTime is the expiry assigned to session (non-persistent) cookies,
+// i.e. cookies that carry neither Max-Age nor Expires. The value is the
+// one used by Go's net/http/cookiejar (jar.go): far enough in the future
+// to never be reached, yet still representable in most date/time formats,
+// not only in Go's time.Time.
+var endOfTime = time.Date(9999, time.December, 31, 23, 59, 59, 0, time.UTC)
+
+// NewClientCookieMiddleware builds a ClientMiddleware for Flight (and
+// Flight SQL) clients that keeps its own cookie jar. Cookies delivered by
+// the server through Set-Cookie response headers are stored in the jar, and
+// every subsequent outgoing call carries the still-valid ones in a Cookie
+// request header. The returned middleware may be shared between goroutines;
+// access to the jar is serialized by a mutex.
+func NewClientCookieMiddleware() ClientMiddleware {
+	mw := &clientCookieMiddleware{jar: map[string]http.Cookie{}}
+	return CreateClientMiddleware(mw)
+}
+
+// clientCookieMiddleware is the state behind NewClientCookieMiddleware.
+type clientCookieMiddleware struct {
+	// mx guards jar; both StartCall and HeadersReceived hold it.
+	mx sync.Mutex
+	// jar holds the received cookies, keyed by an identifier derived from
+	// the cookie's name and path.
+	jar map[string]http.Cookie
+}
+
+// StartCall attaches the jar's cookies to the outgoing call. Per RFC 6265
+// section 5.4, all cookies are serialized into a single Cookie header whose
+// cookie-pairs are separated by "; " (a semicolon followed by a space),
+// rather than being spread over several header values. Cookies whose expiry
+// is not after the current time are evicted from the jar along the way. If
+// no cookie remains, ctx is returned unchanged.
+func (cc *clientCookieMiddleware) StartCall(ctx context.Context) context.Context {
+	cc.mx.Lock()
+	defer cc.mx.Unlock()
+
+	if len(cc.jar) == 0 {
+		return ctx
+	}
+
+	now := time.Now()
+
+	// Deleting from a map while ranging over it is well-defined in Go, so
+	// expired entries are pruned in the same pass that collects the rest.
+	cookies := make([]string, 0, len(cc.jar))
+	for id, c := range cc.jar {
+		if !c.Expires.After(now) {
+			delete(cc.jar, id)
+			continue
+		}
+
+		// only the name=value pair is sent back; attributes such as Path,
+		// Expires or Max-Age do not belong in a Cookie request header.
+		cookies = append(cookies, (&http.Cookie{Name: c.Name, Value: c.Value}).String())
+	}
+
+	if len(cookies) == 0 {
+		return ctx
+	}
+
+	// RFC 6265 section 4.2.1: cookie-string = cookie-pair *( ";" SP cookie-pair )
+	return metadata.AppendToOutgoingContext(ctx, "Cookie", strings.Join(cookies, "; "))
+}
+
+// processCookieExpire turns the lifetime information of c into an absolute
+// c.Expires relative to now, and reports whether c must be removed from the
+// jar instead of stored. Max-Age takes precedence over Expires:
+//   - MaxAge < 0 means the cookie is removed right away;
+//   - MaxAge > 0 sets Expires to now + MaxAge seconds;
+//   - otherwise a zero Expires marks a session cookie, which is given
+//     endOfTime, and an Expires that is not after now means removal.
+func processCookieExpire(c *http.Cookie, now time.Time) (remove bool) {
+	switch {
+	case c.MaxAge < 0:
+		return true
+	case c.MaxAge > 0:
+		c.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
+	case c.Expires.IsZero():
+		c.Expires = endOfTime
+	case !c.Expires.After(now):
+		return true
+	}
+	return false
+}
+
+// HeadersReceived stores the cookies the server set through Set-Cookie
+// response headers. Parsing is delegated to net/http by wrapping the header
+// values in a synthetic *http.Response, since a correct Set-Cookie parser is
+// far from trivial. A cookie that processCookieExpire flags for removal
+// (already expired, or MaxAge < 0) deletes any jar entry with the same name
+// and path; every other cookie is inserted or replaces the existing entry.
+func (cc *clientCookieMiddleware) HeadersReceived(ctx context.Context, md metadata.MD) {
+	setCookies := md.Get("set-cookie")
+	if len(setCookies) == 0 {
+		// nothing to store, so there is no need to take the lock
+		return
+	}
+
+	cookies := (&http.Response{
+		Header: http.Header{"Set-Cookie": setCookies},
+	}).Cookies()
+
+	now := time.Now()
+
+	cc.mx.Lock()
+	defer cc.mx.Unlock()
+
+	for _, c := range cookies {
+		// Key by name and path. The ';' separator keeps the key unambiguous:
+		// it can appear in neither a valid cookie name nor a parsed Path value.
+		// Plain concatenation would let e.g. ("a", "b/c") and ("ab", "/c")
+		// collide and overwrite each other.
+		id := c.Name + ";" + c.Path
+		if processCookieExpire(c, now) {
+			delete(cc.jar, id)
+			continue
+		}
+
+		cc.jar[id] = *c
+	}
+}
diff --git a/go/arrow/flight/cookie_middleware_test.go b/go/arrow/flight/cookie_middleware_test.go
new file mode 100644
index 0000000000..e48e9e6577
--- /dev/null
+++ b/go/arrow/flight/cookie_middleware_test.go
@@ -0,0 +1,241 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flight_test
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"net/textproto"
+	"reflect"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/apache/arrow/go/v13/arrow/flight"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+	"google.golang.org/grpc/metadata"
+)
+
+// cut is a local copy of strings.Cut, which only exists from Go 1.18 on,
+// kept here because Go 1.17 is still supported. It splits s around the
+// first occurrence of sep and returns the text before and after sep plus
+// whether sep was found; when sep is absent it returns s, "", false.
+func cut(s, sep string) (before, after string, found bool) {
+	idx := strings.Index(s, sep)
+	if idx < 0 {
+		return s, "", false
+	}
+	before, after = s[:idx], s[idx+len(sep):]
+	return before, after, true
+}
+
+// serverAddCookieMiddleware is a test server middleware with two modes,
+// selected by expectedCookies:
+//   - nil: each call answers with one Set-Cookie header per entry of cookies;
+//   - non-nil: the Cookie headers of each incoming call are parsed and must
+//     match expectedCookies exactly, otherwise the middleware panics.
+type serverAddCookieMiddleware struct {
+	// cookies are sent to the client while expectedCookies is nil.
+	cookies []*http.Cookie
+
+	// expectedCookies maps cookie name to cookie value.
+	expectedCookies map[string]string
+}
+
+// StartCall implements the test middleware. While expectedCookies is nil it
+// sends one Set-Cookie response header per entry in cookies. Otherwise it
+// parses every incoming Cookie header and panics unless the resulting
+// name/value map equals expectedCookies. The incoming ctx is returned
+// unchanged in both modes, rather than a nil context.
+func (s *serverAddCookieMiddleware) StartCall(ctx context.Context) context.Context {
+	if s.expectedCookies == nil {
+		md := make(metadata.MD)
+		for _, c := range s.cookies {
+			md.Append("Set-Cookie", c.String())
+		}
+		// surface a failure here instead of letting the test fail later
+		// with a confusing cookie mismatch
+		if err := grpc.SetHeader(ctx, md); err != nil {
+			panic(fmt.Sprintf("setting cookie headers: %v", err))
+		}
+		return ctx
+	}
+
+	cookies := metadata.ValueFromIncomingContext(ctx, "cookie")
+
+	got := make(map[string]string)
+	for _, line := range cookies {
+		line = textproto.TrimString(line)
+
+		// one Cookie header may carry several name=value pairs separated
+		// by ';' (RFC 6265 section 4.2.1)
+		var part string
+		for len(line) > 0 {
+			part, line, _ = cut(line, ";")
+			part = textproto.TrimString(part)
+			if part == "" {
+				continue
+			}
+
+			name, val, _ := cut(part, "=")
+			name = textproto.TrimString(name)
+			// values may arrive as a quoted string; compare the bare value
+			if len(val) > 1 && val[0] == '"' && val[len(val)-1] == '"' {
+				val = val[1 : len(val)-1]
+			}
+
+			got[name] = val
+		}
+	}
+
+	if !reflect.DeepEqual(s.expectedCookies, got) {
+		panic(fmt.Sprintf("did not get expected cookies, expected %+v, got %+v", s.expectedCookies, got))
+	}
+
+	return ctx
+}
+
+// CallCompleted is a no-op: this middleware only acts when a call starts.
+func (s *serverAddCookieMiddleware) CallCompleted(_ context.Context, _ error) {}
+
+// TestClientCookieMiddleware checks, per case, that cookies set by the
+// server on one call are sent back by the client middleware on the next
+// call, and that expired or deleted cookies are not sent back.
+func TestClientCookieMiddleware(t *testing.T) {
+	cookieMiddleware := &serverAddCookieMiddleware{}
+
+	s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
+		flight.CreateServerMiddleware(cookieMiddleware),
+	})
+	require.NoError(t, s.Init("localhost:0"))
+	f := &flightServer{}
+	s.RegisterFlightService(f)
+
+	go s.Serve()
+	defer s.Shutdown()
+
+	credsOpt := grpc.WithTransportCredentials(insecure.NewCredentials())
+
+	tests := []struct {
+		testname string
+		cookies  []*http.Cookie
+		expected map[string]string
+	}{
+		{"single cookie", []*http.Cookie{{Name: "Cookie-1", Value: "v$1", Raw: "Cookie-1=v$1"}},
+			map[string]string{"Cookie-1": "v$1"}},
+		{"expired", []*http.Cookie{{
+			Name: "NID", Value: "99=YsDT5", Expires: time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC),
+			RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT", Raw: "NID=99=YsDT5; expires=Wed, 23-Nov-11 01:05:03 GMT"}},
+			map[string]string{}},
+		{"multiple", []*http.Cookie{
+			// the name must be a valid token (no spaces), otherwise
+			// http.Cookie.String() returns "" and the negative Max-Age path
+			// would never actually be exercised
+			{Name: "negative-maxage", Value: "foobar", MaxAge: -1},
+			{Name: "special-1", Value: " z"},
+			{Name: "cookie-2", Value: "v$2"},
+		},
+			map[string]string{"special-1": " z", "cookie-2": "v$2"}},
+	}
+
+	// makeReq issues a ListFlights call and drains the stream. It returns
+	// on the first error so that a failed call neither dereferences a nil
+	// stream nor spins forever on a Recv that keeps returning that error.
+	makeReq := func(c flight.Client, t *testing.T) {
+		flightStream, err := c.ListFlights(context.Background(), &flight.Criteria{})
+		if !assert.NoError(t, err) {
+			return
+		}
+
+		for {
+			_, err := flightStream.Recv()
+			if errors.Is(err, io.EOF) {
+				return
+			}
+			if !assert.NoError(t, err) {
+				return
+			}
+		}
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.testname, func(t *testing.T) {
+			// first call: the server only sets the cookies
+			cookieMiddleware.expectedCookies = nil
+
+			// a fresh client per case, so every case starts with an empty jar
+			client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil,
+				[]flight.ClientMiddleware{flight.NewClientCookieMiddleware()}, credsOpt)
+			require.NoError(t, err)
+			defer client.Close()
+
+			cookieMiddleware.cookies = tt.cookies
+			makeReq(client, t)
+
+			// second call: the server verifies what the client sent back
+			cookieMiddleware.expectedCookies = tt.expected
+			makeReq(client, t)
+		})
+	}
+}
+
+// TestCookieExpiration drives a single client through a sequence of cookie
+// updates: Max-Age based expiry, value replacement and forced deletion via
+// a negative Max-Age.
+func TestCookieExpiration(t *testing.T) {
+	cookieMiddleware := &serverAddCookieMiddleware{}
+
+	s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
+		flight.CreateServerMiddleware(cookieMiddleware),
+	})
+	require.NoError(t, s.Init("localhost:0"))
+	f := &flightServer{}
+	s.RegisterFlightService(f)
+
+	go s.Serve()
+	defer s.Shutdown()
+
+	// makeReq issues a ListFlights call and drains the stream. It returns
+	// on the first error so that a failed call neither dereferences a nil
+	// stream nor spins forever on a Recv that keeps returning that error.
+	makeReq := func(c flight.Client, t *testing.T) {
+		flightStream, err := c.ListFlights(context.Background(), &flight.Criteria{})
+		if !assert.NoError(t, err) {
+			return
+		}
+
+		for {
+			_, err := flightStream.Recv()
+			if errors.Is(err, io.EOF) {
+				return
+			}
+			if !assert.NoError(t, err) {
+				return
+			}
+		}
+	}
+
+	credsOpt := grpc.WithTransportCredentials(insecure.NewCredentials())
+	client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil,
+		[]flight.ClientMiddleware{flight.NewClientCookieMiddleware()}, credsOpt)
+	require.NoError(t, err)
+	defer client.Close()
+
+	// set cookies: foo is a session cookie, foo2 lives for one second
+	cookieMiddleware.cookies = []*http.Cookie{
+		{Name: "foo", Value: "bar"},
+		{Name: "foo2", Value: "bar2", MaxAge: 1},
+	}
+	makeReq(client, t)
+
+	// validate set
+	cookieMiddleware.expectedCookies = map[string]string{
+		"foo": "bar", "foo2": "bar2",
+	}
+	makeReq(client, t)
+
+	// wait for foo2 to expire and validate it doesn't get sent
+	time.Sleep(1 * time.Second)
+	cookieMiddleware.expectedCookies = map[string]string{
+		"foo": "bar",
+	}
+	makeReq(client, t)
+
+	// update value
+	cookieMiddleware.cookies = []*http.Cookie{
+		{Name: "foo", Value: "baz"},
+	}
+	cookieMiddleware.expectedCookies = nil
+	makeReq(client, t)
+
+	// validate updated value is sent
+	cookieMiddleware.expectedCookies = map[string]string{
+		"foo": "baz",
+	}
+	makeReq(client, t)
+
+	// force delete cookie
+	cookieMiddleware.expectedCookies = nil
+	cookieMiddleware.cookies = []*http.Cookie{
+		{Name: "foo", MaxAge: -1}, // delete now!
+	}
+	makeReq(client, t)
+
+	// verify it's been deleted
+	cookieMiddleware.expectedCookies = map[string]string{}
+	makeReq(client, t)
+}