You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/06/12 16:48:40 UTC
[arrow-adbc] branch main updated: refactor(go/adbc/driver/flightsql): factor out server-based tests (#763)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 52439b91 refactor(go/adbc/driver/flightsql): factor out server-based tests (#763)
52439b91 is described below
commit 52439b9143d1321bf711de28281be487765a2c3c
Author: David Li <li...@gmail.com>
AuthorDate: Mon Jun 12 12:48:35 2023 -0400
refactor(go/adbc/driver/flightsql): factor out server-based tests (#763)
Make it slightly cleaner to set up tests that use a custom Flight SQL
server.
Fixes #699.
---
.../driver/flightsql/flightsql_adbc_server_test.go | 426 +++++++++++++++++++++
go/adbc/driver/flightsql/flightsql_adbc_test.go | 370 +-----------------
2 files changed, 428 insertions(+), 368 deletions(-)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
new file mode 100644
index 00000000..c4b5524f
--- /dev/null
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -0,0 +1,426 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Tests that use custom server implementations.
+
+package flightsql_test
+
+import (
+ "context"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/apache/arrow-adbc/go/adbc"
+ driver "github.com/apache/arrow-adbc/go/adbc/driver/flightsql"
+ "github.com/apache/arrow/go/v13/arrow"
+ "github.com/apache/arrow/go/v13/arrow/array"
+ "github.com/apache/arrow/go/v13/arrow/flight"
+ "github.com/apache/arrow/go/v13/arrow/flight/flightsql"
+ "github.com/apache/arrow/go/v13/arrow/memory"
+ "github.com/stretchr/testify/suite"
+ "golang.org/x/exp/maps"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/status"
+)
+
+// ---- Common Infra --------------------
+
+// ServerBasedTests is shared scaffolding for suites that run the ADBC Flight
+// SQL driver against an in-process, custom Flight SQL server. Embedding suites
+// call the Do* helpers from their testify lifecycle hooks.
+type ServerBasedTests struct {
+	suite.Suite
+
+	// s is the in-process Flight server started by DoSetupSuite.
+	s flight.Server
+	// db is the driver database whose URI points at s.
+	db adbc.Database
+	// cnxn is the per-test connection opened by DoSetupTest.
+	cnxn adbc.Connection
+}
+
+func (suite *ServerBasedTests) DoSetupSuite(srv flightsql.Server, srvMiddleware []flight.ServerMiddleware, dbArgs map[string]string) {
+ suite.s = flight.NewServerWithMiddleware(srvMiddleware)
+ suite.s.RegisterFlightService(flightsql.NewFlightServer(srv))
+ suite.Require().NoError(suite.s.Init("localhost:0"))
+ suite.s.SetShutdownOnSignals(os.Interrupt, os.Kill)
+ go func() {
+ _ = suite.s.Serve()
+ }()
+
+ uri := "grpc+tcp://" + suite.s.Addr().String()
+ var err error
+
+ args := map[string]string{
+ "uri": uri,
+ }
+ maps.Copy(args, dbArgs)
+ suite.db, err = (driver.Driver{}).NewDatabase(args)
+ suite.Require().NoError(err)
+}
+
+func (suite *ServerBasedTests) DoSetupTest() {
+ var err error
+ suite.cnxn, err = suite.db.Open(context.Background())
+ suite.Require().NoError(err)
+}
+
+// DoTearDownTest closes the connection opened by DoSetupTest and clears
+// suite.cnxn. It does nothing when no connection is held (e.g. if Open failed
+// in DoSetupTest and returned nil). Calling Close on a nil interface would
+// panic and hide the original setup failure.
+func (suite *ServerBasedTests) DoTearDownTest() {
+	cnxn := suite.cnxn
+	if cnxn == nil {
+		return
+	}
+	// Clear before closing so a later teardown never closes it a second time.
+	suite.cnxn = nil
+	suite.Require().NoError(cnxn.Close())
+}
+
+// DoTearDownSuite drops the suite's database handle and shuts down the
+// server started by DoSetupSuite.
+func (suite *ServerBasedTests) DoTearDownSuite() {
+	server := suite.s
+	suite.db = nil
+	server.Shutdown()
+}
+
+// ---- Tests --------------------
+
+// TestAuthn runs the bearer-token refresh suite (AuthnTests).
+func TestAuthn(t *testing.T) {
+	authn := new(AuthnTests)
+	suite.Run(t, authn)
+}
+
+// TestTimeout runs the per-RPC timeout option suite (TimeoutTests).
+func TestTimeout(t *testing.T) {
+	timeouts := new(TimeoutTests)
+	suite.Run(t, timeouts)
+}
+
+// ---- AuthN Tests --------------------
+
+// AuthnTestServer is a minimal Flight SQL server used by AuthnTests. Any RPC
+// it does not override is handled by the embedded flightsql.BaseServer.
+type AuthnTestServer struct {
+	flightsql.BaseServer
+}
+
+// GetFlightInfoStatement ignores the query text. It sends an
+// "authorization: Bearer final" response header, then returns a single
+// endpoint whose ticket carries an empty statement handle. The header is the
+// refreshed token that the stream interceptor (authnTestStream) insists on
+// for the follow-up DoGet.
+func (server *AuthnTestServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
+	if err := grpc.SendHeader(ctx, metadata.Pairs("authorization", "Bearer final")); err != nil {
+		return nil, err
+	}
+
+	// Previously ignored; surface ticket-encoding failures instead of
+	// handing the client an empty ticket.
+	tkt, err := flightsql.CreateStatementQueryTicket([]byte{})
+	if err != nil {
+		return nil, err
+	}
+	return &flight.FlightInfo{
+		FlightDescriptor: desc,
+		Endpoint: []*flight.FlightEndpoint{
+			{Ticket: &flight.Ticket{Ticket: tkt}},
+		},
+		TotalRecords: -1,
+		TotalBytes:   -1,
+	}, nil
+}
+
+// DoGetStatement ignores the ticket and streams one record batch with a single
+// nullable int32 column "a" holding the value 5.
+func (server *AuthnTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+	sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+	rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`))
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// The stream has exactly one chunk, so a one-slot buffer lets us fill and
+	// close the channel up front. No sender goroutine is needed, so none can
+	// be left blocked if the consumer stops reading early.
+	ch := make(chan flight.StreamChunk, 1)
+	ch <- flight.StreamChunk{Data: rec}
+	close(ch)
+	return sc, ch, nil
+}
+
+// authnTestUnary is a unary server interceptor that only admits calls carrying
+// "authorization: Bearer initial". Missing metadata yields InvalidArgument; a
+// missing or different token yields Unauthenticated.
+func authnTestUnary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
+	md, found := metadata.FromIncomingContext(ctx)
+	if !found {
+		return nil, status.Error(codes.InvalidArgument, "Could not get metadata")
+	}
+
+	switch tokens := md.Get("authorization"); {
+	case len(tokens) == 0:
+		return nil, status.Error(codes.Unauthenticated, "No token")
+	case tokens[0] != "Bearer initial":
+		return nil, status.Error(codes.Unauthenticated, "Invalid token for unary call: "+tokens[0])
+	}
+
+	// NOTE(review): this attaches *outgoing* metadata to the handler's
+	// context. In grpc-go that is a client-side concept, so it likely does not
+	// reach the client. The refreshed token the test relies on is sent via
+	// grpc.SendHeader in AuthnTestServer.GetFlightInfoStatement. Verify before
+	// removing.
+	md.Set("authorization", "Bearer final")
+	return handler(metadata.NewOutgoingContext(ctx, md), req)
+}
+
+// authnTestStream is a stream server interceptor that only admits calls
+// carrying "authorization: Bearer final", the token the server sends back from
+// GetFlightInfoStatement. Missing metadata yields InvalidArgument; a missing or
+// different token yields Unauthenticated.
+func authnTestStream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+	md, found := metadata.FromIncomingContext(ss.Context())
+	if !found {
+		return status.Error(codes.InvalidArgument, "Could not get metadata")
+	}
+
+	switch tokens := md.Get("authorization"); {
+	case len(tokens) == 0:
+		return status.Error(codes.Unauthenticated, "No token")
+	case tokens[0] != "Bearer final":
+		return status.Error(codes.Unauthenticated, "Invalid token for stream call: "+tokens[0])
+	default:
+		return handler(srv, ss)
+	}
+}
+
+// AuthnTests checks that the driver adopts a bearer token refreshed by the
+// server. The database is configured with "Bearer initial", which the server
+// accepts only on unary calls. Stream calls require "Bearer final", which the
+// server returns as a response header.
+type AuthnTests struct {
+	ServerBasedTests
+}
+
+// SetupSuite starts an AuthnTestServer behind the authn interceptors and
+// points the database at it with the initial bearer token.
+func (suite *AuthnTests) SetupSuite() {
+	middleware := []flight.ServerMiddleware{
+		{Stream: authnTestStream, Unary: authnTestUnary},
+	}
+	dbOptions := map[string]string{
+		driver.OptionAuthorizationHeader: "Bearer initial",
+	}
+	suite.DoSetupSuite(&AuthnTestServer{}, middleware, dbOptions)
+}
+
+// SetupTest opens a fresh connection for each test.
+func (suite *AuthnTests) SetupTest() { suite.DoSetupTest() }
+
+// TearDownTest closes the per-test connection.
+func (suite *AuthnTests) TearDownTest() { suite.DoTearDownTest() }
+
+// TearDownSuite drops the database handle and shuts the server down.
+func (suite *AuthnTests) TearDownSuite() { suite.DoTearDownSuite() }
+
+// TestBearerTokenUpdated covers apache/arrow-adbc#584. When the auth header is
+// set directly, the client should use any updated token value the server
+// returns. GetFlightInfo is authorized with the initial token; the DoGet
+// stream is only admitted with the server-provided "Bearer final" token.
+func (suite *AuthnTests) TestBearerTokenUpdated() {
+	stmt, err := suite.cnxn.NewStatement()
+	suite.Require().NoError(err)
+	defer stmt.Close()
+
+	// AuthnTestServer ignores the query text.
+	suite.Require().NoError(stmt.SetSqlQuery("timeout"))
+	reader, _, err := stmt.ExecuteQuery(context.Background())
+	// Must be fatal: on error the reader may be nil, and the deferred
+	// Release would then panic.
+	suite.Require().NoError(err)
+	defer reader.Release()
+
+	// Read the stream so the DoGet call, and its token check, definitely runs.
+	suite.True(reader.Next())
+	suite.NoError(reader.Err())
+}
+
+// ---- Timeout Tests --------------------
+
+// TimeoutTestServer is a Flight SQL server for TimeoutTests. Its handlers
+// stall, mostly until their context is cancelled, so that client-side RPC
+// timeouts can be observed. RPCs it does not override come from the embedded
+// flightsql.BaseServer.
+type TimeoutTestServer struct {
+	flightsql.BaseServer
+}
+
+// DoGetStatement serves the tickets issued by GetFlightInfoStatement:
+//   - A "sleep and succeed" handle waits one second and then streams a single
+//     record with one nullable int32 column "a" = 5.
+//   - Any other handle blocks until ctx is cancelled and returns ctx.Err().
+func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+	if string(tkt.GetStatementHandle()) != "sleep and succeed" {
+		// Wait until the context is cancelled (e.g. by the client's deadline).
+		<-ctx.Done()
+		return nil, nil, ctx.Err()
+	}
+
+	// Honor cancellation during the delay so a client that gives up early
+	// does not leave this handler sleeping.
+	select {
+	case <-time.After(1 * time.Second):
+	case <-ctx.Done():
+		return nil, nil, ctx.Err()
+	}
+
+	sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+	rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`))
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// Exactly one chunk: a one-slot buffer avoids a sender goroutine that
+	// could be left blocked if the consumer stops reading early.
+	ch := make(chan flight.StreamChunk, 1)
+	ch <- flight.StreamChunk{Data: rec}
+	close(ch)
+	return sc, ch, nil
+}
+
+func (ts *TimeoutTestServer) DoPutCommandStatementUpdate(ctx context.Context, cmd flightsql.StatementUpdate) (int64, error) {
+ if cmd.GetQuery() == "timeout" {
+ <-ctx.Done()
+ return -1, ctx.Err()
+ }
+ return -1, arrow.ErrNotImplemented
+}
+
+func (ts *TimeoutTestServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
+ switch cmd.GetQuery() {
+ case "timeout":
+ <-ctx.Done()
+ case "fetch":
+ tkt, _ := flightsql.CreateStatementQueryTicket([]byte("fetch"))
+ info := &flight.FlightInfo{
+ FlightDescriptor: desc,
+ Endpoint: []*flight.FlightEndpoint{
+ {Ticket: &flight.Ticket{Ticket: tkt}},
+ },
+ TotalRecords: -1,
+ TotalBytes: -1,
+ }
+ return info, nil
+ case "notimeout":
+ time.Sleep(1 * time.Second)
+ tkt, _ := flightsql.CreateStatementQueryTicket([]byte("sleep and succeed"))
+ info := &flight.FlightInfo{
+ FlightDescriptor: desc,
+ Endpoint: []*flight.FlightEndpoint{
+ {Ticket: &flight.Ticket{Ticket: tkt}},
+ },
+ TotalRecords: -1,
+ TotalBytes: -1,
+ }
+ return info, nil
+ }
+ return nil, ctx.Err()
+}
+
+// CreatePreparedStatement never succeeds. It waits for ctx to be cancelled
+// (e.g. by the client's deadline) and returns ctx.Err() with an empty result.
+func (ts *TimeoutTestServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) {
+	<-ctx.Done()
+	err = ctx.Err()
+	return result, err
+}
+
+// TimeoutTests exercises the adbc.flight.sql.rpc.timeout_seconds.* connection
+// options against a TimeoutTestServer, whose handlers mostly stall until
+// their context is cancelled.
+type TimeoutTests struct {
+	ServerBasedTests
+}
+
+// SetupSuite starts a TimeoutTestServer with no middleware and no extra
+// database options.
+func (suite *TimeoutTests) SetupSuite() {
+	var (
+		middleware []flight.ServerMiddleware
+		dbOptions  map[string]string
+	)
+	suite.DoSetupSuite(&TimeoutTestServer{}, middleware, dbOptions)
+}
+
+// SetupTest opens a fresh connection for each test.
+func (suite *TimeoutTests) SetupTest() { suite.DoSetupTest() }
+
+// TearDownTest closes the per-test connection.
+func (suite *TimeoutTests) TearDownTest() { suite.DoTearDownTest() }
+
+// TearDownSuite drops the database handle and shuts the server down.
+func (suite *TimeoutTests) TearDownSuite() { suite.DoTearDownSuite() }
+
+// TestInvalidValues checks that each timeout option rejects malformed,
+// non-finite and negative values with StatusInvalidArgument and an
+// "invalid timeout option value" message.
+func (ts *TimeoutTests) TestInvalidValues() {
+	keys := []string{
+		"adbc.flight.sql.rpc.timeout_seconds.fetch",
+		"adbc.flight.sql.rpc.timeout_seconds.query",
+		"adbc.flight.sql.rpc.timeout_seconds.update",
+	}
+	badValues := []string{"1.1f", "asdf", "inf", "NaN", "-1"}
+
+	for _, key := range keys {
+		for _, val := range badValues {
+			name := "key=" + key + ",val=" + val
+			ts.Run(name, func() {
+				opts := ts.cnxn.(adbc.PostInitOptions)
+				err := opts.SetOption(key, val)
+
+				var adbcErr adbc.Error
+				ts.ErrorAs(err, &adbcErr)
+				ts.Equal(adbc.StatusInvalidArgument, adbcErr.Code)
+				ts.ErrorContains(err, "invalid timeout option value")
+			})
+		}
+	}
+}
+
+// TestRemoveTimeout checks that each timeout option accepts "1.0" and then
+// accepts being set back to "0" (which, per the test name, removes the
+// timeout).
+func (ts *TimeoutTests) TestRemoveTimeout() {
+	for _, key := range []string{
+		"adbc.flight.sql.rpc.timeout_seconds.fetch",
+		"adbc.flight.sql.rpc.timeout_seconds.query",
+		"adbc.flight.sql.rpc.timeout_seconds.update",
+	} {
+		ts.Run(key, func() {
+			opts := ts.cnxn.(adbc.PostInitOptions)
+			ts.NoError(opts.SetOption(key, "1.0"))
+			ts.NoError(opts.SetOption(key, "0"))
+		})
+	}
+}
+
+// TestDoActionTimeout checks that Prepare, which this test expects to go
+// through DoAction, honors the update timeout. The server's
+// CreatePreparedStatement only returns once its context is cancelled, so the
+// call must fail with StatusTimeout.
+func (ts *TimeoutTests) TestDoActionTimeout() {
+	// Must be fatal: without the timeout, the server never answers and
+	// Prepare would block indefinitely instead of failing.
+	ts.Require().NoError(ts.cnxn.(adbc.PostInitOptions).
+		SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "0.1"))
+
+	stmt, err := ts.cnxn.NewStatement()
+	ts.Require().NoError(err)
+	defer stmt.Close()
+
+	ts.Require().NoError(stmt.SetSqlQuery("fetch"))
+	var adbcErr adbc.Error
+	ts.Require().ErrorAs(stmt.Prepare(context.Background()), &adbcErr)
+	ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
+}
+
+// TestDoGetTimeout checks that the fetch timeout applies to DoGet. The "fetch"
+// query yields a ticket whose DoGet only returns once its context is
+// cancelled, so ExecuteQuery must fail with StatusTimeout.
+func (ts *TimeoutTests) TestDoGetTimeout() {
+	// Must be fatal: without the timeout, DoGet never completes and the test
+	// would block indefinitely instead of failing.
+	ts.Require().NoError(ts.cnxn.(adbc.PostInitOptions).
+		SetOption("adbc.flight.sql.rpc.timeout_seconds.fetch", "0.1"))
+
+	stmt, err := ts.cnxn.NewStatement()
+	ts.Require().NoError(err)
+	defer stmt.Close()
+
+	ts.Require().NoError(stmt.SetSqlQuery("fetch"))
+	var adbcErr adbc.Error
+	_, _, err = stmt.ExecuteQuery(context.Background())
+	ts.Require().ErrorAs(err, &adbcErr)
+	ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
+}
+
+// TestDoPutTimeout checks that the update timeout applies to DoPut. The
+// server's handler for the "timeout" update only returns once its context is
+// cancelled, so ExecuteUpdate must fail with StatusTimeout.
+func (ts *TimeoutTests) TestDoPutTimeout() {
+	// Must be fatal: without the timeout, the update never completes and the
+	// test would block indefinitely instead of failing.
+	// NOTE(review): the sibling tests use 0.1s. This 1.1s value makes the test
+	// about a second slower; confirm the larger value is intentional.
+	ts.Require().NoError(ts.cnxn.(adbc.PostInitOptions).
+		SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "1.1"))
+
+	stmt, err := ts.cnxn.NewStatement()
+	ts.Require().NoError(err)
+	defer stmt.Close()
+
+	ts.Require().NoError(stmt.SetSqlQuery("timeout"))
+	var adbcErr adbc.Error
+	_, err = stmt.ExecuteUpdate(context.Background())
+	ts.Require().ErrorAs(err, &adbcErr)
+	ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
+}
+
+// TestGetFlightInfoTimeout checks that the query timeout applies to
+// GetFlightInfo. For the "timeout" query the server only returns once its
+// context is cancelled, so ExecuteQuery must fail.
+func (ts *TimeoutTests) TestGetFlightInfoTimeout() {
+	// Must be fatal: without the timeout, GetFlightInfo never completes and the
+	// test would block indefinitely instead of failing.
+	ts.Require().NoError(ts.cnxn.(adbc.PostInitOptions).
+		SetOption("adbc.flight.sql.rpc.timeout_seconds.query", "0.1"))
+
+	stmt, err := ts.cnxn.NewStatement()
+	ts.Require().NoError(err)
+	defer stmt.Close()
+
+	ts.Require().NoError(stmt.SetSqlQuery("timeout"))
+	var adbcErr adbc.Error
+	_, _, err = stmt.ExecuteQuery(context.Background())
+	ts.Require().ErrorAs(err, &adbcErr)
+	// NOTE(review): unlike the sibling tests, this does not pin StatusTimeout.
+	// It only checks that the error is not StatusNotImplemented; confirm
+	// whether a stricter assertion is possible.
+	ts.NotEqual(adbc.StatusNotImplemented, adbcErr.Code, adbcErr.Error())
+}
+
+// TestDontTimeout checks that the timeouts apply per RPC rather than to the
+// query as a whole. For "notimeout", GetFlightInfo and DoGet each take about
+// one second, which fits within the two-second timeout configured for each.
+// The query must therefore succeed and return the single record {"a": 5}.
+func (ts *TimeoutTests) TestDontTimeout() {
+	opts := ts.cnxn.(adbc.PostInitOptions)
+	ts.Require().NoError(opts.SetOption("adbc.flight.sql.rpc.timeout_seconds.fetch", "2.0"))
+	ts.Require().NoError(opts.SetOption("adbc.flight.sql.rpc.timeout_seconds.query", "2.0"))
+
+	stmt, err := ts.cnxn.NewStatement()
+	ts.Require().NoError(err)
+	defer stmt.Close()
+
+	ts.Require().NoError(stmt.SetSqlQuery("notimeout"))
+	// GetFlightInfo will sleep for one second and DoGet will also
+	// sleep for one second. But our timeout is 2 seconds, which is
+	// per-operation. So we shouldn't time out and all should succeed.
+	rr, _, err := stmt.ExecuteQuery(context.Background())
+	ts.Require().NoError(err)
+	defer rr.Release()
+
+	// Must be fatal: if Next() returns false there is no valid record, and
+	// RecordEqual below could panic on it.
+	hasBatch := rr.Next()
+	ts.Require().True(hasBatch, "expected a record batch (reader error: %v)", rr.Err())
+	rec := rr.Record()
+
+	sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+	expected, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`))
+	ts.Require().NoError(err)
+	defer expected.Release()
+	ts.Truef(array.RecordEqual(rec, expected), "expected: %s\nactual: %s", expected, rec)
+}
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index f8ae8469..87f859ea 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.
+// Tests that use the SQLite server example.
+
package flightsql_test
import (
@@ -49,10 +51,8 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
- "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
- "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
@@ -287,11 +287,9 @@ func TestADBCFlightSQL(t *testing.T) {
suite.Run(t, &DefaultDialOptionsTests{Quirks: q})
suite.Run(t, &HeaderTests{Quirks: q})
- suite.Run(t, &AuthnTests{})
suite.Run(t, &OptionTests{Quirks: q})
suite.Run(t, &PartitionTests{Quirks: q})
suite.Run(t, &StatementTests{Quirks: q})
- suite.Run(t, &TimeoutTestSuite{})
suite.Run(t, &TLSTests{Quirks: &FlightSQLQuirks{db: db}})
suite.Run(t, &ConnectionTests{})
suite.Run(t, &DomainSocketTests{db: db})
@@ -744,370 +742,6 @@ func (suite *HeaderTests) TestPrepared() {
suite.Contains(suite.Quirks.middle.recordedHeaders.Get("x-header-two"), "value 2")
}
-type AuthnTests struct {
- suite.Suite
-
- s flight.Server
- db adbc.Database
- cnxn adbc.Connection
-}
-
-type AuthnTestServer struct {
- flightsql.BaseServer
-}
-
-func (server *AuthnTestServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
- md := metadata.MD{}
- md.Set("authorization", "Bearer final")
- if err := grpc.SendHeader(ctx, md); err != nil {
- return nil, err
- }
- tkt, _ := flightsql.CreateStatementQueryTicket([]byte{})
- info := &flight.FlightInfo{
- FlightDescriptor: desc,
- Endpoint: []*flight.FlightEndpoint{
- {Ticket: &flight.Ticket{Ticket: tkt}},
- },
- TotalRecords: -1,
- TotalBytes: -1,
- }
- return info, nil
-}
-
-func (server *AuthnTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
- sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
- rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`))
- if err != nil {
- return nil, nil, err
- }
-
- ch := make(chan flight.StreamChunk)
- go func() {
- defer close(ch)
- ch <- flight.StreamChunk{
- Data: rec,
- Desc: nil,
- Err: nil,
- }
- }()
- return sc, ch, nil
-}
-
-func authnTestUnary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
- md, ok := metadata.FromIncomingContext(ctx)
- if !ok {
- return nil, status.Error(codes.InvalidArgument, "Could not get metadata")
- }
- auth := md.Get("authorization")
- if len(auth) == 0 {
- return nil, status.Error(codes.Unauthenticated, "No token")
- } else if auth[0] != "Bearer initial" {
- return nil, status.Error(codes.Unauthenticated, "Invalid token for unary call: "+auth[0])
- }
-
- md.Set("authorization", "Bearer final")
- ctx = metadata.NewOutgoingContext(ctx, md)
- return handler(ctx, req)
-}
-
-func authnTestStream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
- md, ok := metadata.FromIncomingContext(ss.Context())
- if !ok {
- return status.Error(codes.InvalidArgument, "Could not get metadata")
- }
- auth := md.Get("authorization")
- if len(auth) == 0 {
- return status.Error(codes.Unauthenticated, "No token")
- } else if auth[0] != "Bearer final" {
- return status.Error(codes.Unauthenticated, "Invalid token for stream call: "+auth[0])
- }
-
- return handler(srv, ss)
-}
-
-func (suite *AuthnTests) SetupSuite() {
- suite.s = flight.NewServerWithMiddleware([]flight.ServerMiddleware{
- {Stream: authnTestStream, Unary: authnTestUnary},
- })
- suite.s.RegisterFlightService(flightsql.NewFlightServer(&AuthnTestServer{}))
- suite.Require().NoError(suite.s.Init("localhost:0"))
- suite.s.SetShutdownOnSignals(os.Interrupt, os.Kill)
- go func() {
- _ = suite.s.Serve()
- }()
-
- uri := "grpc+tcp://" + suite.s.Addr().String()
- var err error
- suite.db, err = (driver.Driver{}).NewDatabase(map[string]string{
- "uri": uri,
- driver.OptionAuthorizationHeader: "Bearer initial",
- })
- suite.Require().NoError(err)
-}
-
-func (suite *AuthnTests) SetupTest() {
- var err error
- suite.cnxn, err = suite.db.Open(context.Background())
- suite.Require().NoError(err)
-}
-
-func (suite *AuthnTests) TearDownTest() {
- suite.Require().NoError(suite.cnxn.Close())
-}
-
-func (suite *AuthnTests) TearDownSuite() {
- suite.db = nil
- suite.s.Shutdown()
-}
-
-func (suite *AuthnTests) TestBearerTokenUpdated() {
- // apache/arrow-adbc#584: when setting the auth header directly, the client should use any updated token value from the server if given
- stmt, err := suite.cnxn.NewStatement()
- suite.Require().NoError(err)
- defer stmt.Close()
-
- suite.Require().NoError(stmt.SetSqlQuery("timeout"))
- reader, _, err := stmt.ExecuteQuery(context.Background())
- suite.NoError(err)
- defer reader.Release()
-}
-
-type TimeoutTestServer struct {
- flightsql.BaseServer
-}
-
-func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
- if string(tkt.GetStatementHandle()) == "sleep and succeed" {
- time.Sleep(1 * time.Second)
- sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
- rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`))
- if err != nil {
- return nil, nil, err
- }
-
- ch := make(chan flight.StreamChunk)
- go func() {
- defer close(ch)
- ch <- flight.StreamChunk{
- Data: rec,
- Desc: nil,
- Err: nil,
- }
- }()
- return sc, ch, nil
- }
-
- // wait till the context is cancelled
- <-ctx.Done()
- return nil, nil, ctx.Err()
-}
-
-func (ts *TimeoutTestServer) DoPutCommandStatementUpdate(ctx context.Context, cmd flightsql.StatementUpdate) (int64, error) {
- if cmd.GetQuery() == "timeout" {
- <-ctx.Done()
- return -1, ctx.Err()
- }
- return -1, arrow.ErrNotImplemented
-}
-
-func (ts *TimeoutTestServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
- switch cmd.GetQuery() {
- case "timeout":
- <-ctx.Done()
- case "fetch":
- tkt, _ := flightsql.CreateStatementQueryTicket([]byte("fetch"))
- info := &flight.FlightInfo{
- FlightDescriptor: desc,
- Endpoint: []*flight.FlightEndpoint{
- {Ticket: &flight.Ticket{Ticket: tkt}},
- },
- TotalRecords: -1,
- TotalBytes: -1,
- }
- return info, nil
- case "notimeout":
- time.Sleep(1 * time.Second)
- tkt, _ := flightsql.CreateStatementQueryTicket([]byte("sleep and succeed"))
- info := &flight.FlightInfo{
- FlightDescriptor: desc,
- Endpoint: []*flight.FlightEndpoint{
- {Ticket: &flight.Ticket{Ticket: tkt}},
- },
- TotalRecords: -1,
- TotalBytes: -1,
- }
- return info, nil
- }
- return nil, ctx.Err()
-}
-
-func (ts *TimeoutTestServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) {
- <-ctx.Done()
- return result, ctx.Err()
-}
-
-type TimeoutTestSuite struct {
- suite.Suite
-
- s flight.Server
- db adbc.Database
- cnxn adbc.Connection
-}
-
-func (ts *TimeoutTestSuite) SetupSuite() {
- ts.s = flight.NewServerWithMiddleware(nil)
- ts.s.RegisterFlightService(flightsql.NewFlightServer(&TimeoutTestServer{}))
- ts.Require().NoError(ts.s.Init("localhost:0"))
- ts.s.SetShutdownOnSignals(os.Interrupt, os.Kill)
- go func() {
- _ = ts.s.Serve()
- }()
-
- uri := "grpc+tcp://" + ts.s.Addr().String()
- var err error
- ts.db, err = (driver.Driver{}).NewDatabase(map[string]string{
- "uri": uri,
- })
- ts.Require().NoError(err)
-}
-
-func (ts *TimeoutTestSuite) SetupTest() {
- var err error
- ts.cnxn, err = ts.db.Open(context.Background())
- ts.Require().NoError(err)
-}
-
-func (ts *TimeoutTestSuite) TearDownTest() {
- ts.Require().NoError(ts.cnxn.Close())
-}
-
-func (ts *TimeoutTestSuite) TearDownSuite() {
- ts.db = nil
- ts.s.Shutdown()
-}
-
-func (ts *TimeoutTestSuite) TestInvalidValues() {
- keys := []string{
- "adbc.flight.sql.rpc.timeout_seconds.fetch",
- "adbc.flight.sql.rpc.timeout_seconds.query",
- "adbc.flight.sql.rpc.timeout_seconds.update",
- }
- values := []string{"1.1f", "asdf", "inf", "NaN", "-1"}
-
- for _, k := range keys {
- for _, v := range values {
- ts.Run("key="+k+",val="+v, func() {
- err := ts.cnxn.(adbc.PostInitOptions).SetOption(k, v)
- var adbcErr adbc.Error
- ts.ErrorAs(err, &adbcErr)
- ts.Equal(adbc.StatusInvalidArgument, adbcErr.Code)
- ts.ErrorContains(err, "invalid timeout option value")
- })
- }
- }
-}
-
-func (ts *TimeoutTestSuite) TestRemoveTimeout() {
- keys := []string{
- "adbc.flight.sql.rpc.timeout_seconds.fetch",
- "adbc.flight.sql.rpc.timeout_seconds.query",
- "adbc.flight.sql.rpc.timeout_seconds.update",
- }
- for _, k := range keys {
- ts.Run(k, func() {
- ts.NoError(ts.cnxn.(adbc.PostInitOptions).SetOption(k, "1.0"))
- ts.NoError(ts.cnxn.(adbc.PostInitOptions).SetOption(k, "0"))
- })
- }
-}
-
-func (ts *TimeoutTestSuite) TestDoActionTimeout() {
- ts.NoError(ts.cnxn.(adbc.PostInitOptions).
- SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "0.1"))
-
- stmt, err := ts.cnxn.NewStatement()
- ts.Require().NoError(err)
- defer stmt.Close()
-
- ts.Require().NoError(stmt.SetSqlQuery("fetch"))
- var adbcErr adbc.Error
- ts.ErrorAs(stmt.Prepare(context.Background()), &adbcErr)
- ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
-}
-
-func (ts *TimeoutTestSuite) TestDoGetTimeout() {
- ts.NoError(ts.cnxn.(adbc.PostInitOptions).
- SetOption("adbc.flight.sql.rpc.timeout_seconds.fetch", "0.1"))
-
- stmt, err := ts.cnxn.NewStatement()
- ts.Require().NoError(err)
- defer stmt.Close()
-
- ts.Require().NoError(stmt.SetSqlQuery("fetch"))
- var adbcErr adbc.Error
- _, _, err = stmt.ExecuteQuery(context.Background())
- ts.ErrorAs(err, &adbcErr)
- ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
-}
-
-func (ts *TimeoutTestSuite) TestDoPutTimeout() {
- ts.NoError(ts.cnxn.(adbc.PostInitOptions).
- SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "1.1"))
-
- stmt, err := ts.cnxn.NewStatement()
- ts.Require().NoError(err)
- defer stmt.Close()
-
- ts.Require().NoError(stmt.SetSqlQuery("timeout"))
- var adbcErr adbc.Error
- _, err = stmt.ExecuteUpdate(context.Background())
- ts.ErrorAs(err, &adbcErr)
- ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
-}
-
-func (ts *TimeoutTestSuite) TestGetFlightInfoTimeout() {
- ts.NoError(ts.cnxn.(adbc.PostInitOptions).
- SetOption("adbc.flight.sql.rpc.timeout_seconds.query", "0.1"))
-
- stmt, err := ts.cnxn.NewStatement()
- ts.Require().NoError(err)
- defer stmt.Close()
-
- ts.Require().NoError(stmt.SetSqlQuery("timeout"))
- var adbcErr adbc.Error
- _, _, err = stmt.ExecuteQuery(context.Background())
- ts.ErrorAs(err, &adbcErr)
- ts.NotEqual(adbc.StatusNotImplemented, adbcErr.Code, adbcErr.Error())
-}
-
-func (ts *TimeoutTestSuite) TestDontTimeout() {
- ts.NoError(ts.cnxn.(adbc.PostInitOptions).
- SetOption("adbc.flight.sql.rpc.timeout_seconds.fetch", "2.0"))
- ts.NoError(ts.cnxn.(adbc.PostInitOptions).
- SetOption("adbc.flight.sql.rpc.timeout_seconds.query", "2.0"))
-
- stmt, err := ts.cnxn.NewStatement()
- ts.Require().NoError(err)
- defer stmt.Close()
-
- ts.Require().NoError(stmt.SetSqlQuery("notimeout"))
- // GetFlightInfo will sleep for one second and DoGet will also
- // sleep for one second. But our timeout is 2 seconds, which is
- // per-operation. So we shouldn't time out and all should succeed.
- rr, _, err := stmt.ExecuteQuery(context.Background())
- ts.Require().NoError(err)
- defer rr.Release()
-
- ts.True(rr.Next())
- rec := rr.Record()
-
- sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
- expected, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`))
- ts.Require().NoError(err)
- defer expected.Release()
- ts.Truef(array.RecordEqual(rec, expected), "expected: %s\nactual: %s", expected, rec)
-}
-
type TLSTests struct {
suite.Suite