You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ze...@apache.org on 2023/06/14 19:49:47 UTC
[arrow] branch main updated: GH-36070: [Go][Flight] Add Flight Client Cookie Middleware (#36071)
This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new b4ac585ecb GH-36070: [Go][Flight] Add Flight Client Cookie Middleware (#36071)
b4ac585ecb is described below
commit b4ac585ecb4da610cc64e346e564ca86594aec53
Author: Matt Topol <zo...@gmail.com>
AuthorDate: Wed Jun 14 15:49:38 2023 -0400
GH-36070: [Go][Flight] Add Flight Client Cookie Middleware (#36071)
### Rationale for this change
See https://github.com/apache/arrow-adbc/issues/716
### What changes are included in this PR?
A `NewClientCookieMiddleware` function is added to the Flight package; it returns a `ClientMiddleware` that can be used with both Flight and Flight SQL clients.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
No.
* Closes: #36070
Authored-by: Matt Topol <zo...@gmail.com>
Signed-off-by: Matt Topol <zo...@gmail.com>
---
go/arrow/flight/cookie_middleware.go | 122 +++++++++++++++
go/arrow/flight/cookie_middleware_test.go | 241 ++++++++++++++++++++++++++++++
2 files changed, 363 insertions(+)
diff --git a/go/arrow/flight/cookie_middleware.go b/go/arrow/flight/cookie_middleware.go
new file mode 100644
index 0000000000..27754a13b8
--- /dev/null
+++ b/go/arrow/flight/cookie_middleware.go
@@ -0,0 +1,122 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flight
+
+import (
+ "context"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "google.golang.org/grpc/metadata"
+)
+
// endOfTime is the expiry given to session (non-persistent) cookies, that is,
// cookies that arrive with neither a Max-Age nor an Expires attribute. The
// value is borrowed from Go's net/http/cookiejar/jar.go. It is far enough in
// the future to outlive any client, and it is representable in most date/time
// formats, not just Go's time.Time.
var endOfTime = time.Date(9999, time.December, 31, 23, 59, 59, 0, time.UTC)
+
+// NewClientCookieMiddleware returns a go-routine safe middleware for flight
+// clients which properly handles Set-Cookie headers to store cookies
+// in a cookie jar, and then requests are sent with those cookies added
+// as a Cookie header.
+func NewClientCookieMiddleware() ClientMiddleware {
+ return CreateClientMiddleware(&clientCookieMiddleware{jar: make(map[string]http.Cookie)})
+}
+
// clientCookieMiddleware holds the cookie-jar state behind
// NewClientCookieMiddleware.
type clientCookieMiddleware struct {
	// jar maps cookie name+path to the most recently received cookie.
	// Access it only while holding mx.
	jar map[string]http.Cookie
	// mx serializes access to jar across concurrent calls.
	mx sync.Mutex
}
+
+func (cc *clientCookieMiddleware) StartCall(ctx context.Context) context.Context {
+ cc.mx.Lock()
+ defer cc.mx.Unlock()
+
+ if len(cc.jar) == 0 {
+ return ctx
+ }
+
+ now := time.Now()
+
+ // Per RFC 6265 section 5.4, rather than adding multiple cookie strings
+ // or multiple cookie headers, multiple cookies are all sent as a single
+ // header value separated by semicolons.
+
+ // we will also clear any expired cookies from the jar while we determine
+ // the cookies to send.
+ cookies := make([]string, 0, len(cc.jar))
+ for id, c := range cc.jar {
+ if !c.Expires.After(now) {
+ delete(cc.jar, id)
+ continue
+ }
+
+ cookies = append(cookies, (&http.Cookie{Name: c.Name, Value: c.Value}).String())
+ }
+
+ if len(cookies) == 0 {
+ return ctx
+ }
+
+ return metadata.AppendToOutgoingContext(ctx, "Cookie", strings.Join(cookies, ";"))
+}
+
+func processCookieExpire(c *http.Cookie, now time.Time) (remove bool) {
+ // MaxAge takes precedence over Expires
+ if c.MaxAge < 0 {
+ return true
+ } else if c.MaxAge > 0 {
+ c.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
+ } else {
+ if c.Expires.IsZero() {
+ c.Expires = endOfTime
+ } else {
+ if !c.Expires.After(now) {
+ return true
+ }
+ }
+ }
+
+ return
+}
+
+func (cc *clientCookieMiddleware) HeadersReceived(ctx context.Context, md metadata.MD) {
+ // instead of replicating the logic for processing the Set-Cookie
+ // header, let's just make a fake response and use the built-in
+ // cookie processing. It's very non-trivial
+ cookies := (&http.Response{
+ Header: http.Header{"Set-Cookie": md.Get("set-cookie")},
+ }).Cookies()
+
+ now := time.Now()
+
+ cc.mx.Lock()
+ defer cc.mx.Unlock()
+
+ for _, c := range cookies {
+ id := c.Name + c.Path
+ if processCookieExpire(c, now) {
+ delete(cc.jar, id)
+ continue
+ }
+
+ cc.jar[id] = *c
+ }
+}
diff --git a/go/arrow/flight/cookie_middleware_test.go b/go/arrow/flight/cookie_middleware_test.go
new file mode 100644
index 0000000000..e48e9e6577
--- /dev/null
+++ b/go/arrow/flight/cookie_middleware_test.go
@@ -0,0 +1,241 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flight_test
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/textproto"
+ "reflect"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/apache/arrow/go/v13/arrow/flight"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/insecure"
+ "google.golang.org/grpc/metadata"
+)
+
// cut splits s around the first occurrence of sep. It returns the text before
// and after sep, and whether sep was found at all. If sep does not occur, the
// result is s, "", false. This duplicates strings.Cut, which needs go1.18,
// because go1.17 is still supported here.
func cut(s, sep string) (before, after string, found bool) {
	idx := strings.Index(s, sep)
	if idx < 0 {
		return s, "", false
	}
	before, after = s[:idx], s[idx+len(sep):]
	return before, after, true
}
+
// serverAddCookieMiddleware is a test-only server middleware with two modes.
// While expectedCookies is nil, it sends each entry of cookies to the client
// as a Set-Cookie header. Otherwise, it checks that the incoming Cookie
// header carries exactly expectedCookies (name -> value).
type serverAddCookieMiddleware struct {
	// expectedCookies, when non-nil, is the exact name->value set the client
	// must send. Leaving it nil switches the middleware to sending mode.
	expectedCookies map[string]string

	// cookies are emitted as Set-Cookie headers while in sending mode.
	cookies []*http.Cookie
}
+
+func (s *serverAddCookieMiddleware) StartCall(ctx context.Context) context.Context {
+ if s.expectedCookies == nil {
+ md := make(metadata.MD)
+ for _, c := range s.cookies {
+ md.Append("Set-Cookie", c.String())
+ }
+ grpc.SetHeader(ctx, md)
+ return nil
+ }
+
+ cookies := metadata.ValueFromIncomingContext(ctx, "cookie")
+
+ got := make(map[string]string)
+ for _, line := range cookies {
+ line = textproto.TrimString(line)
+
+ var part string
+ for len(line) > 0 {
+ part, line, _ = cut(line, ";")
+ part = textproto.TrimString(part)
+ if part == "" {
+ continue
+ }
+
+ name, val, _ := cut(part, "=")
+ name = textproto.TrimString(name)
+ if len(val) > 1 && val[0] == '"' && val[len(val)-1] == '"' {
+ val = val[1 : len(val)-1]
+ }
+
+ got[name] = val
+ }
+ }
+
+ if !reflect.DeepEqual(s.expectedCookies, got) {
+ panic(fmt.Sprintf("did not get expected cookies, expected %+v, got %+v", s.expectedCookies, got))
+ }
+
+ return nil
+}
+
+func (s *serverAddCookieMiddleware) CallCompleted(ctx context.Context, err error) {}
+
+func TestClientCookieMiddleware(t *testing.T) {
+ cookieMiddleware := &serverAddCookieMiddleware{}
+
+ s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
+ flight.CreateServerMiddleware(cookieMiddleware),
+ })
+ s.Init("localhost:0")
+ f := &flightServer{}
+ s.RegisterFlightService(f)
+
+ go s.Serve()
+ defer s.Shutdown()
+
+ credsOpt := grpc.WithTransportCredentials(insecure.NewCredentials())
+
+ tests := []struct {
+ testname string
+ cookies []*http.Cookie
+ expected map[string]string
+ }{
+ {"single cookie", []*http.Cookie{{Name: "Cookie-1", Value: "v$1", Raw: "Cookie-1=v$1"}},
+ map[string]string{"Cookie-1": "v$1"}},
+ {"expired", []*http.Cookie{{
+ Name: "NID", Value: "99=YsDT5", Expires: time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC),
+ RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT", Raw: "NID=99=YsDT5; expires=Wed, 23-Nov-11 01:05:03 GMT"}},
+ map[string]string{}},
+ {"multiple", []*http.Cookie{
+ {Name: "negative maxage", Value: "foobar", MaxAge: -1},
+ {Name: "special-1", Value: " z"},
+ {Name: "cookie-2", Value: "v$2"},
+ },
+ map[string]string{"special-1": " z", "cookie-2": "v$2"}},
+ }
+
+ makeReq := func(c flight.Client, t *testing.T) {
+ flightStream, err := c.ListFlights(context.Background(), &flight.Criteria{})
+ assert.NoError(t, err)
+
+ for {
+ _, err := flightStream.Recv()
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ assert.NoError(t, err)
+ }
+ }
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.testname, func(t *testing.T) {
+ cookieMiddleware.expectedCookies = nil
+
+ client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil,
+ []flight.ClientMiddleware{flight.NewClientCookieMiddleware()}, credsOpt)
+ require.NoError(t, err)
+ defer client.Close()
+
+ cookieMiddleware.cookies = tt.cookies
+ makeReq(client, t)
+
+ cookieMiddleware.expectedCookies = tt.expected
+ makeReq(client, t)
+ })
+ }
+}
+
+func TestCookieExpiration(t *testing.T) {
+ cookieMiddleware := &serverAddCookieMiddleware{}
+
+ s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
+ flight.CreateServerMiddleware(cookieMiddleware),
+ })
+ s.Init("localhost:0")
+ f := &flightServer{}
+ s.RegisterFlightService(f)
+
+ go s.Serve()
+ defer s.Shutdown()
+
+ makeReq := func(c flight.Client, t *testing.T) {
+ flightStream, err := c.ListFlights(context.Background(), &flight.Criteria{})
+ assert.NoError(t, err)
+
+ for {
+ _, err := flightStream.Recv()
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ assert.NoError(t, err)
+ }
+ }
+ }
+
+ credsOpt := grpc.WithTransportCredentials(insecure.NewCredentials())
+ client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil,
+ []flight.ClientMiddleware{flight.NewClientCookieMiddleware()}, credsOpt)
+ require.NoError(t, err)
+ defer client.Close()
+
+ // set cookies
+ cookieMiddleware.cookies = []*http.Cookie{
+ {Name: "foo", Value: "bar"},
+ {Name: "foo2", Value: "bar2", MaxAge: 1},
+ }
+ makeReq(client, t)
+
+ // validate set
+ cookieMiddleware.expectedCookies = map[string]string{
+ "foo": "bar", "foo2": "bar2",
+ }
+ makeReq(client, t)
+
+ // wait for foo2 to expire and validate it doesn't get sent
+ time.Sleep(1 * time.Second)
+ cookieMiddleware.expectedCookies = map[string]string{
+ "foo": "bar",
+ }
+ makeReq(client, t)
+
+ // update value
+ cookieMiddleware.cookies = []*http.Cookie{
+ {Name: "foo", Value: "baz"},
+ }
+ cookieMiddleware.expectedCookies = nil
+ makeReq(client, t)
+
+ // validate updated value is sent
+ cookieMiddleware.expectedCookies = map[string]string{
+ "foo": "baz",
+ }
+ makeReq(client, t)
+
+ // force delete cookie
+ cookieMiddleware.expectedCookies = nil
+ cookieMiddleware.cookies = []*http.Cookie{
+ {Name: "foo", MaxAge: -1}, // delete now!
+ }
+ makeReq(client, t)
+
+ // verify it's been deleted
+ cookieMiddleware.expectedCookies = map[string]string{}
+ makeReq(client, t)
+}