You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ze...@apache.org on 2023/04/11 15:46:07 UTC
[arrow] branch main updated: GH-34332: [Go][FlightRPC] Add driver for `database/sql` framework (#34331)
This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new c40e658fbb GH-34332: [Go][FlightRPC] Add driver for `database/sql` framework (#34331)
c40e658fbb is described below
commit c40e658fbbd6201132c4378eb0fefb746ff5915f
Author: Sven Rebhan <36...@users.noreply.github.com>
AuthorDate: Tue Apr 11 17:45:58 2023 +0200
GH-34332: [Go][FlightRPC] Add driver for `database/sql` framework (#34331)
### Rationale for this change
Using Go's `database/sql` framework is well known, offers goodies like connection pooling, and is easy to use. Therefore, using FlightSQL through this framework is a good starting point for users performing simple queries, inserts, etc.
### What changes are included in this PR?
This PR adds a `database/sql/driver` implementation, currently supporting `sqlite` and `InfluxData IOx` (query only). Unit tests are added using the SQLite server example implementation, and the driver and its settings are documented.
### Are these changes tested?
Yes, a test suite is added for the driver. Furthermore, the IOx backend is additionally tested against a real local instance using [this code](https://github.com/srebhan/go-flightsql-example).
### Are there any user-facing changes?
This PR does not contain breaking changes. All modifications to the FlightSQL client code are transparent to the user.
* Closes: #34332
Authored-by: Sven Rebhan <sr...@influxdata.com>
Signed-off-by: Matt Topol <zo...@gmail.com>
---
go/arrow/flight/client.go | 6 +-
go/arrow/flight/flightsql/client.go | 10 +-
go/arrow/flight/flightsql/driver/README.md | 151 +++++
go/arrow/flight/flightsql/driver/config.go | 125 ++++
go/arrow/flight/flightsql/driver/driver.go | 492 ++++++++++++++
go/arrow/flight/flightsql/driver/driver_test.go | 816 ++++++++++++++++++++++++
go/arrow/flight/flightsql/driver/utils.go | 272 ++++++++
7 files changed, 1869 insertions(+), 3 deletions(-)
diff --git a/go/arrow/flight/client.go b/go/arrow/flight/client.go
index 5ad3c9be07..da6b60c89b 100644
--- a/go/arrow/flight/client.go
+++ b/go/arrow/flight/client.go
@@ -271,6 +271,10 @@ func NewFlightClient(addr string, auth ClientAuthHandler, opts ...grpc.DialOptio
// being the inner most wrapper around the actual call. It also passes along the dialoptions passed in such
// as TLS certs and so on.
func NewClientWithMiddleware(addr string, auth ClientAuthHandler, middleware []ClientMiddleware, opts ...grpc.DialOption) (Client, error) {
+ return NewClientWithMiddlewareCtx(context.Background(), addr, auth, middleware, opts...)
+}
+
+// NewClientWithMiddlewareCtx is like NewClientWithMiddleware, but uses ctx
+// when dialing the server (via grpc.DialContext) instead of a background
+// context.
+func NewClientWithMiddlewareCtx(ctx context.Context, addr string, auth ClientAuthHandler, middleware []ClientMiddleware, opts ...grpc.DialOption) (Client, error) {
unary := make([]grpc.UnaryClientInterceptor, 0, len(middleware))
stream := make([]grpc.StreamClientInterceptor, 0, len(middleware))
if auth != nil {
@@ -288,7 +292,7 @@ func NewClientWithMiddleware(addr string, auth ClientAuthHandler, middleware []C
}
}
opts = append(opts, grpc.WithChainUnaryInterceptor(unary...), grpc.WithChainStreamInterceptor(stream...))
- conn, err := grpc.Dial(addr, opts...)
+ conn, err := grpc.DialContext(ctx, addr, opts...)
if err != nil {
return nil, err
}
diff --git a/go/arrow/flight/flightsql/client.go b/go/arrow/flight/flightsql/client.go
index a73fc4657c..a148f83e96 100644
--- a/go/arrow/flight/flightsql/client.go
+++ b/go/arrow/flight/flightsql/client.go
@@ -39,7 +39,11 @@ import (
// its arguments to flight.NewClientWithMiddleware to create the
// underlying Flight Client.
func NewClient(addr string, auth flight.ClientAuthHandler, middleware []flight.ClientMiddleware, opts ...grpc.DialOption) (*Client, error) {
- cl, err := flight.NewClientWithMiddleware(addr, auth, middleware, opts...)
+ return NewClientCtx(context.Background(), addr, auth, middleware, opts...)
+}
+
+// NewClientCtx is like NewClient, but passes ctx on to
+// flight.NewClientWithMiddlewareCtx so that it is used for dialing.
+func NewClientCtx(ctx context.Context, addr string, auth flight.ClientAuthHandler, middleware []flight.ClientMiddleware, opts ...grpc.DialOption) (*Client, error) {
+ cl, err := flight.NewClientWithMiddlewareCtx(ctx, addr, auth, middleware, opts...)
if err != nil {
return nil, err
}
@@ -1110,7 +1114,9 @@ func (p *PreparedStatement) clearParameters() {
func (p *PreparedStatement) SetParameters(binding arrow.Record) {
p.clearParameters()
p.paramBinding = binding
- p.paramBinding.Retain()
+ // binding may be nil (callers use that to bind no parameters), in which
+ // case there is nothing to retain.
+ if p.paramBinding != nil {
+ p.paramBinding.Retain()
+ }
}
// SetRecordReader takes a RecordReader to send as the parameter bindings when
diff --git a/go/arrow/flight/flightsql/driver/README.md b/go/arrow/flight/flightsql/driver/README.md
new file mode 100644
index 0000000000..cfb33ba2c6
--- /dev/null
+++ b/go/arrow/flight/flightsql/driver/README.md
@@ -0,0 +1,151 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+# FlightSQL driver
+
+A FlightSQL driver for Go's [database/sql](https://golang.org/pkg/database/sql/)
+package. This driver is a lightweight wrapper around the FlightSQL client in
+pure Go. It provides all the advantages of a `database/sql` driver, like automatic
+connection pooling and transactions, combined with ease of use (see [Usage](#usage)).
+
+---------------------------------------
+
+* [Prerequisites](#prerequisites)
+* [Usage](#usage)
+* [Data Source Name (DSN)](#data-source-name-dsn)
+* [Driver config usage](#driver-config-usage)
+* [TLS setup](#tls-setup)
+
+---------------------------------------
+
+## Prerequisites
+
+* Go 1.19+
+* Installation via `go get -u github.com/apache/arrow/go/v12/arrow/flight/flightsql/driver`
+* Backend speaking FlightSQL
+
+---------------------------------------
+
+## Usage
+
+_Go FlightSQL Driver_ is an implementation of Go's `database/sql/driver`
+interface to use the [`database/sql`](https://golang.org/pkg/database/sql/)
+framework. The driver is registered as `flightsql` and configured using a
+[data-source name (DSN)](#data-source-name-dsn).
+
+A basic example using a SQLite backend looks like this:
+
+```go
+import (
+	"database/sql"
+
+	// Importing the driver package registers the "flightsql" driver.
+	_ "github.com/apache/arrow/go/v12/arrow/flight/flightsql/driver"
+)
+
+// Open the connection to an SQLite backend
+db, err := sql.Open("flightsql", "flightsql://localhost:12345?timeout=5s")
+if err != nil {
+ panic(err)
+}
+// Make sure we close the connection to the database
+defer db.Close()
+
+// Use the connection e.g. for querying
+rows, err := db.Query("SELECT * FROM mytable")
+if err != nil {
+ panic(err)
+}
+// ...
+```
+
+## Data Source Name (DSN)
+
+A Data Source Name has the following format:
+
+```text
+flightsql://[user[:password]@]<address>[:port][?param1=value1&...&paramN=valueN]
+```
+
+The data-source name (DSN) requires the `address` of the backend with an
+optional port setting. The `user` and `password` parameters are passed to the
+backend as gRPC Basic-Auth headers. If your backend requires token-based
+authentication, please use the `token` parameter (see
+[common parameters](#common-parameters) below).
+
+**Please note**: All parameters are case-sensitive!
+
+Alternatively to specifying the DSN directly you can use the `DriverConfig`
+structure to generate the DSN string. See the
+[Driver config usage section](#driver-config-usage) for details.
+
+### Common parameters
+
+The following common parameters exist:
+
+#### `token`
+
+The `token` parameter can be used to specify the token for token-based
+authentication. The value is passed on to the backend as a GRPC Bearer-Auth
+header.
+
+#### `timeout`
+
+The `timeout` parameter can be set using a duration string e.g. `timeout=5s`
+to limit the maximum time an operation can take. This prevents calls that wait
+forever, e.g. if the backend is down or a query is taking very long. When
+not set, the driver will use an _infinite_ timeout.
+
+## Driver config usage
+
+Alternatively to specifying the DSN directly, you can fill the `DriverConfig`
+structure and generate the DSN from it. Here is an example:
+
+```golang
+package main
+
+import (
+	"database/sql"
+	"log"
+	"time"
+
+	"github.com/apache/arrow/go/v12/arrow/flight/flightsql/driver"
+)
+
+func main() {
+	config := driver.DriverConfig{
+		Address: "localhost:12345",
+		Token:   "your token",
+		Timeout: 10 * time.Second,
+		Params: map[string]string{
+			"my-custom-parameter": "foobar",
+		},
+	}
+	db, err := sql.Open("flightsql", config.DSN())
+	if err != nil {
+		log.Fatalf("open failed: %v", err)
+	}
+	defer db.Close()
+
+	// ...
+}
+```
+
+## TLS setup
+
+TLS cannot be configured via the DSN yet; support will be added later.
diff --git a/go/arrow/flight/flightsql/driver/config.go b/go/arrow/flight/flightsql/driver/config.go
new file mode 100644
index 0000000000..d4a785dc6b
--- /dev/null
+++ b/go/arrow/flight/flightsql/driver/config.go
@@ -0,0 +1,125 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package driver
+
+import (
+ "crypto/tls"
+ "fmt"
+ "net/url"
+ "time"
+)
+
+// DriverConfig holds the settings of the FlightSQL driver. It can be parsed
+// from a DSN via NewDriverConfigFromDSN and rendered back into one via DSN.
+type DriverConfig struct {
+ // Address is the "host[:port]" of the FlightSQL backend (the DSN host part).
+ Address string
+ // Username and Password are taken from the DSN user-info part and are
+ // handed to the per-RPC credentials of the connection.
+ Username string
+ Password string
+ // Token is the value of the "token" DSN parameter, used for token-based
+ // authentication.
+ Token string
+ // Timeout bounds each operation; values <= 0 mean no timeout.
+ Timeout time.Duration
+ // Params holds all other DSN parameters. They are handed to the per-RPC
+ // credentials as well — presumably sent as request metadata; verify in utils.go.
+ Params map[string]string
+
+ // TLSEnabled and TLSConfig select TLS transport credentials when the
+ // Connector is configured. They are neither read from nor written to the
+ // DSN, so they only take effect when a Connector is configured directly.
+ TLSEnabled bool
+ TLSConfig *tls.Config
+}
+
+// NewDriverConfigFromDSN parses a data-source name of the form
+//
+//	flightsql://[user[:password]@]address[:port][/][?param=value&...]
+//
+// into a DriverConfig. The "token" and "timeout" parameters are mapped to
+// the corresponding fields; all other parameters end up in Params. Each
+// parameter may be given at most once.
+func NewDriverConfigFromDSN(dsn string) (*DriverConfig, error) {
+	u, err := url.Parse(dsn)
+	if err != nil {
+		return nil, err
+	}
+
+	// Sanity checks on the given connection string. A bare trailing slash
+	// (e.g. "flightsql://host:1234/?timeout=5s") carries no information and
+	// is therefore accepted.
+	if u.Scheme != "flightsql" {
+		return nil, fmt.Errorf("invalid scheme %q", u.Scheme)
+	}
+	if u.Path != "" && u.Path != "/" {
+		return nil, fmt.Errorf("unexpected path %q", u.Path)
+	}
+
+	config := &DriverConfig{
+		Address: u.Host,
+		Params:  make(map[string]string),
+	}
+	if u.User != nil {
+		config.Username = u.User.Username()
+		// Password reports "" when no password is set, which is exactly the
+		// zero value we want, so the "is set" flag is not needed.
+		config.Password, _ = u.User.Password()
+	}
+
+	// Determine the parameters
+	for key, values := range u.Query() {
+		// We only support single instances
+		if len(values) > 1 {
+			return nil, fmt.Errorf("too many values for %q", key)
+		}
+		var v string
+		if len(values) > 0 {
+			v = values[0]
+		}
+
+		switch key {
+		case "token":
+			config.Token = v
+		case "timeout":
+			timeout, err := time.ParseDuration(v)
+			if err != nil {
+				return nil, fmt.Errorf("invalid timeout %q: %w", v, err)
+			}
+			config.Timeout = timeout
+		default:
+			config.Params[key] = v
+		}
+	}
+
+	return config, nil
+}
+
+// DSN renders the configuration as a "flightsql://" data-source name that
+// NewDriverConfigFromDSN can parse. Credentials go into the user-info part.
+// Token, a positive Timeout, and all Params become query parameters. TLS
+// settings are not represented.
+func (config *DriverConfig) DSN() string {
+	query := url.Values{}
+	if config.Token != "" {
+		query.Add("token", config.Token)
+	}
+	if config.Timeout > 0 {
+		query.Add("timeout", config.Timeout.String())
+	}
+	for name, value := range config.Params {
+		query.Add(name, value)
+	}
+
+	target := &url.URL{Scheme: "flightsql", Host: config.Address}
+	switch {
+	case config.Username == "":
+		// No credentials to encode.
+	case config.Password == "":
+		target.User = url.User(config.Username)
+	default:
+		target.User = url.UserPassword(config.Username, config.Password)
+	}
+
+	// Only emit a query string when there is something to encode.
+	if len(query) > 0 {
+		target.RawQuery = query.Encode()
+	}
+	return target.String()
+}
diff --git a/go/arrow/flight/flightsql/driver/driver.go b/go/arrow/flight/flightsql/driver/driver.go
new file mode 100644
index 0000000000..970d7a4dfe
--- /dev/null
+++ b/go/arrow/flight/flightsql/driver/driver.go
@@ -0,0 +1,492 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package driver
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "errors"
+ "fmt"
+ "io"
+ "sort"
+ "time"
+
+ "github.com/apache/arrow/go/v12/arrow"
+ "github.com/apache/arrow/go/v12/arrow/array"
+ "github.com/apache/arrow/go/v12/arrow/flight/flightsql"
+ "github.com/apache/arrow/go/v12/arrow/memory"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/credentials/insecure"
+)
+
+var (
+ // ErrNotSupported is returned when the driver cannot provide a value,
+ // e.g. by Result.LastInsertId when no insert ID is available.
+ ErrNotSupported = errors.New("not supported")
+ // ErrOutOfRange is returned by Rows.Next when a row or column position
+ // lies outside the current record.
+ ErrOutOfRange = errors.New("index out of range")
+ // ErrTransactionInProgress is returned by Connection.Close while the
+ // connection still holds a valid (uncommitted) transaction.
+ ErrTransactionInProgress = errors.New("transaction still in progress")
+)
+
+// Rows is the driver.Rows implementation returned by queries. It holds all
+// records fetched for a query, each retained until Close is called, and
+// iterates over them row by row.
+type Rows struct {
+	schema        *arrow.Schema
+	records       []arrow.Record
+	currentRecord int
+	currentRow    int
+}
+
+// Columns returns the names of the columns.
+//
+// The names come from the result schema, so they are available even when
+// the result set contains no records at all.
+func (r *Rows) Columns() []string {
+	if r.schema == nil {
+		return nil
+	}
+
+	fields := r.schema.Fields()
+	cols := make([]string, 0, len(fields))
+	for _, f := range fields {
+		cols = append(cols, f.Name)
+	}
+	return cols
+}
+
+// Close releases all records held by the iterator. Calling Close more than
+// once is safe; afterwards Next returns io.EOF.
+func (r *Rows) Close() error {
+	for _, rec := range r.records {
+		rec.Release()
+	}
+	// Drop the references so a repeated Close cannot release twice.
+	r.records = nil
+	r.currentRecord = 0
+	r.currentRow = 0
+
+	return nil
+}
+
+// Next is called to populate the next row of data into the provided slice.
+// database/sql sizes dest according to Columns().
+//
+// Next returns io.EOF when there are no more rows. Records without any rows
+// (empty batches, which a backend may legitimately send) are skipped.
+//
+// The dest should not be written to outside of Next. Care should be taken
+// when closing Rows not to modify a buffer held in dest.
+func (r *Rows) Next(dest []driver.Value) error {
+	// Advance past exhausted and empty records.
+	for r.currentRecord < len(r.records) &&
+		int64(r.currentRow) >= r.records[r.currentRecord].NumRows() {
+		r.currentRecord++
+		r.currentRow = 0
+	}
+	if r.currentRecord >= len(r.records) {
+		return io.EOF
+	}
+	record := r.records[r.currentRecord]
+
+	if n := int(record.NumCols()); n > len(dest) {
+		return fmt.Errorf("%w: record has %d columns but only %d destinations", ErrOutOfRange, n, len(dest))
+	}
+
+	for i, arr := range record.Columns() {
+		v, err := fromArrowType(arr, r.currentRow)
+		if err != nil {
+			return err
+		}
+		dest[i] = v
+	}
+	r.currentRow++
+
+	return nil
+}
+
+// Result is the driver.Result of an executed statement. Negative values
+// mark information the backend did not provide.
+type Result struct {
+	affected   int64
+	lastinsert int64
+}
+
+// LastInsertId returns the database's auto-generated ID after, for example,
+// an INSERT into a table with primary key. It reports ErrNotSupported (and
+// -1) when no ID is known.
+func (r *Result) LastInsertId() (int64, error) {
+	if r.lastinsert >= 0 {
+		return r.lastinsert, nil
+	}
+	return -1, ErrNotSupported
+}
+
+// RowsAffected returns the number of rows affected by the query. It reports
+// ErrNotSupported (and -1) when the count is unknown.
+func (r *Result) RowsAffected() (int64, error) {
+	if r.affected >= 0 {
+		return r.affected, nil
+	}
+	return -1, ErrNotSupported
+}
+
+// Stmt is a prepared statement on a FlightSQL server, implementing
+// driver.Stmt.
+type Stmt struct {
+	stmt   *flightsql.PreparedStatement
+	client *flightsql.Client
+
+	// timeout bounds each server call; values <= 0 disable it.
+	timeout time.Duration
+}
+
+// Close closes the prepared statement on the server, bounded by the
+// statement's timeout if one is configured.
+func (s *Stmt) Close() error {
+	if s.timeout <= 0 {
+		return s.stmt.Close(context.Background())
+	}
+	ctx, cancel := context.WithTimeout(context.Background(), s.timeout)
+	defer cancel()
+	return s.stmt.Close(ctx)
+}
+
+// NumInput returns the number of placeholder parameters, as announced by
+// the server's parameter schema.
+func (s *Stmt) NumInput() int {
+	if schema := s.stmt.ParameterSchema(); schema != nil {
+		// A value >= 0 makes database/sql check the argument count of
+		// callers before Exec or Query is invoked.
+		return len(schema.Fields())
+	}
+	// Without a schema the count is unknown; -1 tells database/sql to skip
+	// its argument-count sanity check.
+	return -1
+}
+
+// Exec executes a query that doesn't return rows, such
+// as an INSERT or UPDATE.
+func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
+ var params []driver.NamedValue
+ for i, arg := range args {
+ params = append(params, driver.NamedValue{
+ Ordinal: i,
+ Value: arg,
+ })
+ }
+
+ return s.ExecContext(context.Background(), params)
+}
+
+// ExecContext executes a query that doesn't return rows, such as an INSERT
+// or UPDATE. Unless ctx already has a deadline, the call is bounded by the
+// statement's timeout. The insert ID is never known, so the result reports
+// it as unsupported.
+func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+	err := s.setParameters(args)
+	if err != nil {
+		return nil, err
+	}
+
+	_, hasDeadline := ctx.Deadline()
+	if s.timeout > 0 && !hasDeadline {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, s.timeout)
+		defer cancel()
+	}
+
+	affected, err := s.stmt.ExecuteUpdate(ctx)
+	if err != nil {
+		return nil, err
+	}
+	return &Result{affected: affected, lastinsert: -1}, nil
+}
+
+// Query executes a query that may return rows, such as a SELECT.
+func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
+ var params []driver.NamedValue
+ for i, arg := range args {
+ params = append(params, driver.NamedValue{
+ Ordinal: i,
+ Value: arg,
+ })
+ }
+
+ return s.QueryContext(context.Background(), params)
+}
+
+// QueryContext executes a query that may return rows, such as a SELECT.
+//
+// All endpoints of the result are fetched eagerly. Their records are
+// retained in the returned Rows, which must be closed by the caller. Unless
+// ctx already has a deadline, execution and fetching are bounded by the
+// statement's timeout.
+func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+	if err := s.setParameters(args); err != nil {
+		return nil, err
+	}
+
+	if _, set := ctx.Deadline(); !set && s.timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, s.timeout)
+		defer cancel()
+	}
+
+	info, err := s.stmt.Execute(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	rows := &Rows{}
+	for _, endpoint := range info.Endpoint {
+		reader, err := s.client.DoGet(ctx, endpoint.GetTicket())
+		if err != nil {
+			// database/sql discards rows returned alongside an error without
+			// closing them, so release what has been collected so far.
+			rows.Close()
+			return nil, fmt.Errorf("getting ticket failed: %w", err)
+		}
+
+		rows.schema = reader.Schema()
+		for reader.Next() {
+			record := reader.Record()
+			// The reader reuses/releases its current record on Next, so keep
+			// our own reference.
+			record.Retain()
+			rows.records = append(rows.records, record)
+		}
+		err = reader.Err()
+		// The stream reader owns resources of its own; the retained records
+		// stay valid after it is released.
+		reader.Release()
+		if err != nil {
+			rows.Close()
+			return nil, err
+		}
+	}
+
+	return rows, nil
+}
+
+func (s *Stmt) setParameters(args []driver.NamedValue) error {
+ if len(args) == 0 {
+ s.stmt.SetParameters(nil)
+ return nil
+ }
+
+ sort.SliceStable(args, func(i, j int) bool {
+ return args[i].Ordinal < args[j].Ordinal
+ })
+
+ schema := s.stmt.ParameterSchema()
+ if schema == nil {
+ var fields []arrow.Field
+ for _, arg := range args {
+ dt, err := toArrowDataType(arg.Value)
+ if err != nil {
+ return fmt.Errorf("schema: %w", err)
+ }
+ fields = append(fields, arrow.Field{
+ Name: arg.Name,
+ Type: dt,
+ })
+ }
+ schema = arrow.NewSchema(fields, nil)
+ }
+
+ recBuilder := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+ defer recBuilder.Release()
+
+ for i, arg := range args {
+ fieldBuilder := recBuilder.Field(i)
+ if err := setFieldValue(fieldBuilder, arg.Value); err != nil {
+ return err
+ }
+ }
+
+ rec := recBuilder.NewRecord()
+ defer rec.Release()
+
+ s.stmt.SetParameters(rec)
+
+ return nil
+}
+
+// Tx is a FlightSQL transaction as returned by Connection.BeginTx.
+type Tx struct {
+	tx      *flightsql.Txn
+	timeout time.Duration
+}
+
+// newContext returns a background context bounded by the transaction's
+// timeout. Without a timeout, the context is unbounded and the returned
+// cancel function is a no-op.
+func (t *Tx) newContext() (context.Context, context.CancelFunc) {
+	if t.timeout > 0 {
+		return context.WithTimeout(context.Background(), t.timeout)
+	}
+	return context.Background(), func() {}
+}
+
+// Commit commits the transaction on the server.
+func (t *Tx) Commit() error {
+	ctx, cancel := t.newContext()
+	defer cancel()
+	return t.tx.Commit(ctx)
+}
+
+// Rollback rolls the transaction back on the server.
+func (t *Tx) Rollback() error {
+	ctx, cancel := t.newContext()
+	defer cancel()
+	return t.tx.Rollback(ctx)
+}
+
+// Driver is the FlightSQL database/sql driver. Besides Open it provides
+// OpenConnector, so database/sql parses the DSN only once per pool.
+type Driver struct{}
+
+// Open returns a new connection to the database described by the DSN name.
+func (d *Driver) Open(name string) (driver.Conn, error) {
+	connector, err := d.OpenConnector(name)
+	if err != nil {
+		return nil, err
+	}
+	return connector.Connect(context.Background())
+}
+
+// OpenConnector parses the DSN name (the same format Open accepts) and
+// returns a Connector configured from it.
+func (d *Driver) OpenConnector(name string) (driver.Connector, error) {
+	config, err := NewDriverConfigFromDSN(name)
+	if err != nil {
+		return nil, err
+	}
+
+	connector := new(Connector)
+	if err := connector.Configure(config); err != nil {
+		return nil, err
+	}
+	return connector, nil
+}
+
+// Connector creates connections to a FlightSQL backend from a prepared
+// configuration. It implements database/sql/driver.Connector.
+type Connector struct {
+	addr    string
+	timeout time.Duration
+	options []grpc.DialOption
+}
+
+// Configure prepares the connector from config: the target address, the
+// operation timeout and the gRPC dial options. The options are a blocking
+// dial, transport security, and per-RPC credentials. It currently never
+// fails.
+func (c *Connector) Configure(config *DriverConfig) error {
+	c.addr = config.Address
+	c.timeout = config.Timeout
+
+	// Plain-text transport unless TLS was explicitly requested.
+	transportCreds := insecure.NewCredentials()
+	if config.TLSEnabled {
+		transportCreds = credentials.NewTLS(config.TLSConfig)
+	}
+
+	// Authentication and custom parameters sent with every RPC.
+	rpcCreds := grpcCredentials{
+		username: config.Username,
+		password: config.Password,
+		token:    config.Token,
+		params:   config.Params,
+	}
+
+	c.options = []grpc.DialOption{
+		grpc.WithBlock(),
+		grpc.WithTransportCredentials(transportCreds),
+		grpc.WithPerRPCCredentials(rpcCreds),
+	}
+	return nil
+}
+
+// Connect dials the backend and returns a new Connection. Because the dial
+// blocks, it is bounded by the configured timeout unless ctx already
+// carries a deadline.
+func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
+	_, hasDeadline := ctx.Deadline()
+	if c.timeout > 0 && !hasDeadline {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, c.timeout)
+		defer cancel()
+	}
+
+	client, err := flightsql.NewClientCtx(ctx, c.addr, nil, nil, c.options...)
+	if err != nil {
+		return nil, err
+	}
+	conn := &Connection{client: client, timeout: c.timeout}
+	return conn, nil
+}
+
+// Driver returns the underlying Driver of the Connector, mainly to maintain
+// compatibility with the Driver method on sql.DB.
+func (c *Connector) Driver() driver.Driver {
+	return new(Driver)
+}
+
+// Connection is a single connection to a FlightSQL backend. While a
+// transaction is open, statements are prepared within that transaction.
+type Connection struct {
+	client *flightsql.Client
+	txn    *flightsql.Txn
+
+	// timeout bounds server calls that have no caller deadline; values <= 0
+	// disable it.
+	timeout time.Duration
+}
+
+// Prepare returns a prepared statement, bound to this connection.
+func (c *Connection) Prepare(query string) (driver.Stmt, error) {
+	return c.PrepareContext(context.Background(), query)
+}
+
+// PrepareContext returns a prepared statement, bound to this connection.
+// context is for the preparation of the statement,
+// it must not store the context within the statement itself.
+func (c *Connection) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
+	if _, set := ctx.Deadline(); !set && c.timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, c.timeout)
+		defer cancel()
+	}
+
+	var err error
+	var stmt *flightsql.PreparedStatement
+	if c.txn != nil && c.txn.ID().IsValid() {
+		stmt, err = c.txn.Prepare(ctx, query)
+	} else {
+		stmt, err = c.client.Prepare(ctx, query)
+		// The transaction (if any) has ended; forget it.
+		c.txn = nil
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	return &Stmt{
+		stmt:    stmt,
+		client:  c.client,
+		timeout: c.timeout,
+	}, nil
+}
+
+// Close invalidates and potentially stops any current
+// prepared statements and transactions, marking this
+// connection as no longer in use. It refuses to close (returning
+// ErrTransactionInProgress) while a valid transaction is open.
+func (c *Connection) Close() error {
+	if c.txn != nil && c.txn.ID().IsValid() {
+		return ErrTransactionInProgress
+	}
+
+	if c.client == nil {
+		return nil
+	}
+
+	err := c.client.Close()
+	c.client = nil
+
+	return err
+}
+
+// Begin starts and returns a new transaction.
+func (c *Connection) Begin() (driver.Tx, error) {
+	return c.BeginTx(context.Background(), sql.TxOptions{})
+}
+
+// BeginTx starts and returns a new transaction. opts is currently ignored.
+// Unless ctx already has a deadline, the call is bounded by the connection
+// timeout, like every other server call of the driver.
+//
+// NOTE(review): database/sql only uses BeginTx if the connection implements
+// driver.ConnBeginTx, whose signature takes driver.TxOptions rather than
+// sql.TxOptions. With this signature database/sql falls back to Begin and
+// the caller's context is not used. Changing the parameter type is an API
+// change and is therefore left for a separate decision.
+func (c *Connection) BeginTx(ctx context.Context, opts sql.TxOptions) (driver.Tx, error) {
+	if _, set := ctx.Deadline(); !set && c.timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, c.timeout)
+		defer cancel()
+	}
+
+	tx, err := c.client.BeginTransaction(ctx)
+	if err != nil {
+		return nil, err
+	}
+	c.txn = tx
+
+	return &Tx{tx: tx, timeout: c.timeout}, nil
+}
+
+// Register the driver under the name "flightsql" when the package is
+// loaded, so that sql.Open("flightsql", dsn) works after importing it.
+func init() {
+	drv := &Driver{}
+	sql.Register("flightsql", drv)
+}
diff --git a/go/arrow/flight/flightsql/driver/driver_test.go b/go/arrow/flight/flightsql/driver/driver_test.go
new file mode 100644
index 0000000000..60cfb32364
--- /dev/null
+++ b/go/arrow/flight/flightsql/driver/driver_test.go
@@ -0,0 +1,816 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+// +build go1.18
+
+package driver_test
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "os"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "github.com/stretchr/testify/suite"
+
+ "github.com/apache/arrow/go/v12/arrow"
+ "github.com/apache/arrow/go/v12/arrow/array"
+ "github.com/apache/arrow/go/v12/arrow/flight"
+ "github.com/apache/arrow/go/v12/arrow/flight/flightsql"
+ "github.com/apache/arrow/go/v12/arrow/flight/flightsql/driver"
+ "github.com/apache/arrow/go/v12/arrow/flight/flightsql/example"
+ "github.com/apache/arrow/go/v12/arrow/memory"
+)
+
+// defaultTableName is the table used by the tests unless
+// SqlTestSuite.TableName is set.
+const defaultTableName = "drivertest"
+
+// defaultStatements are the fmt.Sprintf templates used by the tests. The
+// first verb is always the table name. SetupSuite keeps entries already
+// present in SqlTestSuite.Statements and only fills in missing ones. The
+// dialect appears to be SQLite (AUTOINCREMENT) — override for other backends.
+var defaultStatements = map[string]string{
+ "create table": `
+CREATE TABLE %s (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name varchar(100),
+ value int
+);`,
+ // Verbs: table, name (string), value (integer).
+ "insert": `INSERT INTO %s (name, value) VALUES ('%s', %d);`,
+ "query": `SELECT * FROM %s;`,
+ // Verbs: table, substring that the name must contain.
+ "constraint query": `SELECT * FROM %s WHERE name LIKE '%%%s%%'`,
+ // Verbs: table; the name pattern is bound to the single placeholder.
+ "placeholder query": `SELECT * FROM %s WHERE name LIKE ?`,
+}
+
+// SqlTestSuite runs the driver tests against a FlightSQL server supplied by
+// the create/start/stop hooks, so the same tests can target different
+// backends.
+type SqlTestSuite struct {
+ suite.Suite
+
+ // Config is the base driver configuration. Each test copies it and sets
+ // Address to the address returned by createServer.
+ Config driver.DriverConfig
+ // TableName defaults to defaultTableName when empty.
+ TableName string
+ // Statements overrides entries of defaultStatements by name.
+ Statements map[string]string
+
+ // createServer returns a new server together with its address.
+ createServer func() (flight.Server, string, error)
+ // startServer runs the server; the tests call it in a separate goroutine,
+ // presumably because it blocks until the server stops.
+ startServer func(flight.Server) error
+ // stopServer shuts the server down; it is called both explicitly and via
+ // defer, so it should tolerate repeated calls.
+ stopServer func(flight.Server)
+}
+
+// SetupSuite applies defaults: the table name and every statement from
+// defaultStatements that was not provided explicitly. It then checks that
+// all statements needed by the tests are present.
+func (s *SqlTestSuite) SetupSuite() {
+	if s.TableName == "" {
+		s.TableName = defaultTableName
+	}
+
+	// Statements already defined, e.g. by the user or a suite generator,
+	// take precedence over the defaults.
+	if s.Statements == nil {
+		s.Statements = make(map[string]string, len(defaultStatements))
+	}
+	for name, stmt := range defaultStatements {
+		if _, exists := s.Statements[name]; exists {
+			continue
+		}
+		s.Statements[name] = stmt
+	}
+
+	required := []string{"create table", "insert", "query", "constraint query", "placeholder query"}
+	for _, name := range required {
+		require.Contains(s.T(), s.Statements, name)
+	}
+}
+
+// TestOpenClose checks that a database handle can be opened and closed
+// against a running server.
+func (s *SqlTestSuite) TestOpenClose() {
+	t := s.T()
+
+	// Create and start the server. The serve error is checked on the test
+	// goroutine, since require (t.FailNow) must not be called from other
+	// goroutines.
+	server, addr, err := s.createServer()
+	require.NoError(t, err)
+
+	var serveErr error
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		serveErr = s.startServer(server)
+	}()
+	defer s.stopServer(server)
+	time.Sleep(100 * time.Millisecond)
+
+	// Configure client
+	cfg := s.Config
+	cfg.Address = addr
+	db, err := sql.Open("flightsql", cfg.DSN())
+	require.NoError(t, err)
+	require.NoError(t, db.Close())
+
+	// Tear-down server; wg.Wait makes serveErr safe to read.
+	s.stopServer(server)
+	wg.Wait()
+	require.NoError(t, serveErr)
+}
+
+// TestCreateTable creates the test table and checks the reported result.
+func (s *SqlTestSuite) TestCreateTable() {
+	t := s.T()
+
+	// Create and start the server. The serve error is checked on the test
+	// goroutine, since require (t.FailNow) must not be called from other
+	// goroutines.
+	server, addr, err := s.createServer()
+	require.NoError(t, err)
+
+	var serveErr error
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		serveErr = s.startServer(server)
+	}()
+	defer s.stopServer(server)
+	time.Sleep(100 * time.Millisecond)
+
+	// Configure client
+	cfg := s.Config
+	cfg.Address = addr
+	db, err := sql.Open("flightsql", cfg.DSN())
+	require.NoError(t, err)
+	defer db.Close()
+
+	result, err := db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+	require.NoError(t, err)
+
+	// Check errors before values so failures report their cause.
+	affected, err := result.RowsAffected()
+	require.NoError(t, err)
+	require.Equal(t, int64(0), affected)
+
+	// The driver never knows the last insert ID.
+	last, err := result.LastInsertId()
+	require.ErrorIs(t, err, driver.ErrNotSupported)
+	require.Equal(t, int64(-1), last)
+
+	require.NoError(t, db.Close())
+
+	// Tear-down server; wg.Wait makes serveErr safe to read.
+	s.stopServer(server)
+	wg.Wait()
+	require.NoError(t, serveErr)
+}
+
+// TestInsert inserts several rows with a single multi-statement Exec.
+func (s *SqlTestSuite) TestInsert() {
+	t := s.T()
+
+	// Create and start the server. The serve error is checked on the test
+	// goroutine, since require (t.FailNow) must not be called from other
+	// goroutines.
+	server, addr, err := s.createServer()
+	require.NoError(t, err)
+
+	var serveErr error
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		serveErr = s.startServer(server)
+	}()
+	defer s.stopServer(server)
+	time.Sleep(100 * time.Millisecond)
+
+	// Configure client
+	cfg := s.Config
+	cfg.Address = addr
+	db, err := sql.Open("flightsql", cfg.DSN())
+	require.NoError(t, err)
+	defer db.Close()
+
+	// Create the table
+	_, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+	require.NoError(t, err)
+
+	// Insert data
+	values := map[string]int{
+		"zero":      0,
+		"one":       1,
+		"minus one": -1,
+		"twelve":    12,
+	}
+	stmts := make([]string, 0, len(values))
+	for k, v := range values {
+		stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, k, v))
+	}
+	result, err := db.Exec(strings.Join(stmts, "\n"))
+	require.NoError(t, err)
+
+	// NOTE(review): only one affected row is expected although four rows are
+	// inserted — presumably the backend reports the count of the last
+	// statement only; confirm.
+	affected, err := result.RowsAffected()
+	require.NoError(t, err)
+	require.Equal(t, int64(1), affected)
+
+	require.NoError(t, db.Close())
+
+	// Tear-down server; wg.Wait makes serveErr safe to read.
+	s.stopServer(server)
+	wg.Wait()
+	require.NoError(t, serveErr)
+}
+
+// TestQuery inserts rows and reads them back with a direct query.
+func (s *SqlTestSuite) TestQuery() {
+	t := s.T()
+
+	// Create and start the server. The serve error is checked on the test
+	// goroutine, since require (t.FailNow) must not be called from other
+	// goroutines.
+	server, addr, err := s.createServer()
+	require.NoError(t, err)
+
+	var serveErr error
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		serveErr = s.startServer(server)
+	}()
+	defer s.stopServer(server)
+	time.Sleep(100 * time.Millisecond)
+
+	// Configure client
+	cfg := s.Config
+	cfg.Address = addr
+	db, err := sql.Open("flightsql", cfg.DSN())
+	require.NoError(t, err)
+	defer db.Close()
+
+	// Create the table
+	_, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+	require.NoError(t, err)
+
+	// Insert data
+	expected := map[string]int{
+		"zero":      0,
+		"one":       1,
+		"minus one": -1,
+		"twelve":    12,
+	}
+	stmts := make([]string, 0, len(expected))
+	for k, v := range expected {
+		stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, k, v))
+	}
+	_, err = db.Exec(strings.Join(stmts, "\n"))
+	require.NoError(t, err)
+
+	rows, err := db.Query(fmt.Sprintf(s.Statements["query"], s.TableName))
+	require.NoError(t, err)
+
+	// Check result
+	actual := make(map[string]int, len(expected))
+	for rows.Next() {
+		var name string
+		var id, value int
+		require.NoError(t, rows.Scan(&id, &name, &value))
+		actual[name] = value
+	}
+	// rows.Next returning false may hide an iteration error.
+	require.NoError(t, rows.Err())
+	require.NoError(t, rows.Close())
+
+	require.NoError(t, db.Close())
+	require.EqualValues(t, expected, actual)
+
+	// Tear-down server; wg.Wait makes serveErr safe to read.
+	s.stopServer(server)
+	wg.Wait()
+	require.NoError(t, serveErr)
+}
+
+func (s *SqlTestSuite) TestPreparedQuery() {
+	t := s.T()
+
+	// Create and start the server. Serve blocks until the server is stopped,
+	// so it runs in the background. The deferred calls run LIFO: on every
+	// exit path the server is stopped first and then its goroutine is awaited.
+	server, addr, err := s.createServer()
+	require.NoError(t, err)
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		// require.* calls t.FailNow, which is only valid on the test
+		// goroutine; report failures from here via t.Errorf instead.
+		if err := s.startServer(server); err != nil {
+			t.Errorf("serving: %v", err)
+		}
+	}()
+	defer wg.Wait()
+	defer s.stopServer(server)
+	time.Sleep(100 * time.Millisecond)
+
+	// Configure client
+	cfg := s.Config
+	cfg.Address = addr
+	db, err := sql.Open("flightsql", cfg.DSN())
+	require.NoError(t, err)
+	defer db.Close()
+
+	// Create the table
+	_, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+	require.NoError(t, err)
+
+	// Insert data, sending all INSERTs in a single newline-separated Exec
+	expected := map[string]int{
+		"zero":      0,
+		"one":       1,
+		"minus one": -1,
+		"twelve":    12,
+	}
+	stmts := make([]string, 0, len(expected))
+	for k, v := range expected {
+		stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, k, v))
+	}
+	_, err = db.Exec(strings.Join(stmts, "\n"))
+	require.NoError(t, err)
+
+	// Do query through a prepared statement
+	stmt, err := db.Prepare(fmt.Sprintf(s.Statements["query"], s.TableName))
+	require.NoError(t, err)
+	defer stmt.Close()
+
+	rows, err := stmt.Query()
+	require.NoError(t, err)
+	defer rows.Close()
+
+	// Check result
+	actual := make(map[string]int, len(expected))
+	for rows.Next() {
+		var name string
+		var id, value int
+		require.NoError(t, rows.Scan(&id, &name, &value))
+		actual[name] = value
+	}
+	// Next also returns false on failure; make sure iteration really completed.
+	require.NoError(t, rows.Err())
+	require.NoError(t, rows.Close())
+	require.NoError(t, stmt.Close())
+	require.NoError(t, db.Close())
+	require.EqualValues(t, expected, actual)
+}
+
+func (s *SqlTestSuite) TestPreparedQueryWithConstraint() {
+	t := s.T()
+
+	// Create and start the server. Serve blocks until the server is stopped,
+	// so it runs in the background. The deferred calls run LIFO: on every
+	// exit path the server is stopped first and then its goroutine is awaited.
+	server, addr, err := s.createServer()
+	require.NoError(t, err)
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		// require.* calls t.FailNow, which is only valid on the test
+		// goroutine; report failures from here via t.Errorf instead.
+		if err := s.startServer(server); err != nil {
+			t.Errorf("serving: %v", err)
+		}
+	}()
+	defer wg.Wait()
+	defer s.stopServer(server)
+	time.Sleep(100 * time.Millisecond)
+
+	// Configure client
+	cfg := s.Config
+	cfg.Address = addr
+	db, err := sql.Open("flightsql", cfg.DSN())
+	require.NoError(t, err)
+	defer db.Close()
+
+	// Create the table
+	_, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+	require.NoError(t, err)
+
+	// Insert data, sending all INSERTs in a single newline-separated Exec
+	data := map[string]int{
+		"zero":      0,
+		"one":       1,
+		"minus one": -1,
+		"twelve":    12,
+	}
+	stmts := make([]string, 0, len(data))
+	for k, v := range data {
+		stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, k, v))
+	}
+	_, err = db.Exec(strings.Join(stmts, "\n"))
+	require.NoError(t, err)
+
+	// Do query; the constraint value "one" is formatted into the statement text
+	stmt, err := db.Prepare(fmt.Sprintf(s.Statements["constraint query"], s.TableName, "one"))
+	require.NoError(t, err)
+	defer stmt.Close()
+
+	rows, err := stmt.Query()
+	require.NoError(t, err)
+	defer rows.Close()
+
+	// Check result
+	expected := map[string]int{
+		"one":       1,
+		"minus one": -1,
+	}
+	actual := make(map[string]int, len(expected))
+	for rows.Next() {
+		var name string
+		var id, value int
+		require.NoError(t, rows.Scan(&id, &name, &value))
+		actual[name] = value
+	}
+	// Next also returns false on failure; make sure iteration really completed.
+	require.NoError(t, rows.Err())
+	require.NoError(t, rows.Close())
+	require.NoError(t, stmt.Close())
+	require.NoError(t, db.Close())
+	require.EqualValues(t, expected, actual)
+}
+
+func (s *SqlTestSuite) TestPreparedQueryWithPlaceholder() {
+	t := s.T()
+
+	// Create and start the server. Serve blocks until the server is stopped,
+	// so it runs in the background. The deferred calls run LIFO: on every
+	// exit path the server is stopped first and then its goroutine is awaited.
+	server, addr, err := s.createServer()
+	require.NoError(t, err)
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		// require.* calls t.FailNow, which is only valid on the test
+		// goroutine; report failures from here via t.Errorf instead.
+		if err := s.startServer(server); err != nil {
+			t.Errorf("serving: %v", err)
+		}
+	}()
+	defer wg.Wait()
+	defer s.stopServer(server)
+	time.Sleep(100 * time.Millisecond)
+
+	// Configure client
+	cfg := s.Config
+	cfg.Address = addr
+	db, err := sql.Open("flightsql", cfg.DSN())
+	require.NoError(t, err)
+	defer db.Close()
+
+	// Create the table
+	_, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+	require.NoError(t, err)
+
+	// Insert data, sending all INSERTs in a single newline-separated Exec
+	data := map[string]int{
+		"zero":      0,
+		"one":       1,
+		"minus one": -1,
+		"twelve":    12,
+	}
+	stmts := make([]string, 0, len(data))
+	for k, v := range data {
+		stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, k, v))
+	}
+	_, err = db.Exec(strings.Join(stmts, "\n"))
+	require.NoError(t, err)
+
+	// Do query
+	query := fmt.Sprintf(s.Statements["placeholder query"], s.TableName)
+	stmt, err := db.Prepare(query)
+	require.NoError(t, err)
+	defer stmt.Close()
+
+	// The pattern is bound as an argument and never passes through fmt, so
+	// "%" must not be doubled here.
+	rows, err := stmt.Query("%one%")
+	require.NoError(t, err)
+	defer rows.Close()
+
+	// Check result
+	expected := map[string]int{
+		"one":       1,
+		"minus one": -1,
+	}
+	actual := make(map[string]int, len(expected))
+	for rows.Next() {
+		var name string
+		var id, value int
+		require.NoError(t, rows.Scan(&id, &name, &value))
+		actual[name] = value
+	}
+	// Next also returns false on failure; make sure iteration really completed.
+	require.NoError(t, rows.Err())
+	require.NoError(t, rows.Close())
+	require.NoError(t, stmt.Close())
+	require.NoError(t, db.Close())
+	require.EqualValues(t, expected, actual)
+}
+
+func (s *SqlTestSuite) TestTxRollback() {
+	t := s.T()
+
+	// Create and start the server. Serve blocks until the server is stopped,
+	// so it runs in the background. The deferred calls run LIFO: on every
+	// exit path the server is stopped first and then its goroutine is awaited.
+	server, addr, err := s.createServer()
+	require.NoError(t, err)
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		// require.* calls t.FailNow, which is only valid on the test
+		// goroutine; report failures from here via t.Errorf instead.
+		if err := s.startServer(server); err != nil {
+			t.Errorf("serving: %v", err)
+		}
+	}()
+	defer wg.Wait()
+	defer s.stopServer(server)
+	time.Sleep(100 * time.Millisecond)
+
+	// Configure client
+	cfg := s.Config
+	cfg.Address = addr
+	db, err := sql.Open("flightsql", cfg.DSN())
+	require.NoError(t, err)
+	defer db.Close()
+
+	tx, err := db.Begin()
+	require.NoError(t, err)
+	// Safety net in case a require below aborts the test. After the explicit
+	// Rollback this only returns sql.ErrTxDone, which is deliberately ignored.
+	defer func() { _ = tx.Rollback() }()
+
+	// Create the table
+	_, err = tx.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+	require.NoError(t, err)
+
+	// Insert data, one statement per Exec within the transaction
+	data := map[string]int{
+		"zero":      0,
+		"one":       1,
+		"minus one": -1,
+		"twelve":    12,
+	}
+	for k, v := range data {
+		stmt := fmt.Sprintf(s.Statements["insert"], s.TableName, k, v)
+		_, err = tx.Exec(stmt)
+		require.NoError(t, err)
+	}
+
+	// Rollback the transaction
+	require.NoError(t, tx.Rollback())
+
+	// Check result: after the rollback no user table may be left over
+	tbls := `SELECT name FROM sqlite_schema WHERE type ='table' AND name NOT LIKE 'sqlite_%';`
+	rows, err := db.Query(tbls)
+	require.NoError(t, err)
+	defer rows.Close()
+	count := 0
+	for rows.Next() {
+		count++
+	}
+	// Next also returns false on failure; make sure iteration really completed.
+	require.NoError(t, rows.Err())
+	require.NoError(t, rows.Close())
+	require.Equal(t, 0, count)
+	require.NoError(t, db.Close())
+}
+
+func (s *SqlTestSuite) TestTxCommit() {
+	t := s.T()
+
+	// Create and start the server. Serve blocks until the server is stopped,
+	// so it runs in the background. The deferred calls run LIFO: on every
+	// exit path the server is stopped first and then its goroutine is awaited.
+	server, addr, err := s.createServer()
+	require.NoError(t, err)
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		// require.* calls t.FailNow, which is only valid on the test
+		// goroutine; report failures from here via t.Errorf instead.
+		if err := s.startServer(server); err != nil {
+			t.Errorf("serving: %v", err)
+		}
+	}()
+	defer wg.Wait()
+	defer s.stopServer(server)
+	time.Sleep(100 * time.Millisecond)
+
+	// Configure client
+	cfg := s.Config
+	cfg.Address = addr
+	db, err := sql.Open("flightsql", cfg.DSN())
+	require.NoError(t, err)
+	defer db.Close()
+
+	tx, err := db.Begin()
+	require.NoError(t, err)
+	// Safety net in case a require below aborts the test. After the explicit
+	// Commit this only returns sql.ErrTxDone, which is deliberately ignored.
+	defer func() { _ = tx.Rollback() }()
+
+	// Create the table
+	_, err = tx.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+	require.NoError(t, err)
+
+	// Insert data, one statement per Exec within the transaction
+	data := map[string]int{
+		"zero":      0,
+		"one":       1,
+		"minus one": -1,
+		"twelve":    12,
+	}
+	for k, v := range data {
+		stmt := fmt.Sprintf(s.Statements["insert"], s.TableName, k, v)
+		_, err = tx.Exec(stmt)
+		require.NoError(t, err)
+	}
+
+	// Commit the transaction
+	require.NoError(t, tx.Commit())
+
+	// Check if the table exists
+	tbls := `SELECT name FROM sqlite_schema WHERE type ='table' AND name NOT LIKE 'sqlite_%';`
+	tblRows, err := db.Query(tbls)
+	require.NoError(t, err)
+	defer tblRows.Close()
+
+	var tables []string
+	for tblRows.Next() {
+		var name string
+		require.NoError(t, tblRows.Scan(&name))
+		tables = append(tables, name)
+	}
+	require.NoError(t, tblRows.Err())
+	require.NoError(t, tblRows.Close())
+	// Look for the table the suite actually created instead of a hard-coded name.
+	require.Contains(t, tables, s.TableName)
+
+	// Check the actual data
+	stmt, err := db.Prepare(fmt.Sprintf(s.Statements["query"], s.TableName))
+	require.NoError(t, err)
+	defer stmt.Close()
+
+	rows, err := stmt.Query()
+	require.NoError(t, err)
+	defer rows.Close()
+
+	// Check result
+	actual := make(map[string]int, len(data))
+	for rows.Next() {
+		var name string
+		var id, value int
+		require.NoError(t, rows.Scan(&id, &name, &value))
+		actual[name] = value
+	}
+	// Next also returns false on failure; make sure iteration really completed.
+	require.NoError(t, rows.Err())
+	require.NoError(t, rows.Close())
+	require.NoError(t, stmt.Close())
+	require.NoError(t, db.Close())
+	require.EqualValues(t, data, actual)
+}
+
+/*** BACKEND tests ***/
+
+// TestSqliteBackend runs the SqlTestSuite against the example SQLite
+// FlightSQL server. Every createServer call builds a fresh server backed by
+// a new in-memory SQLite database.
+func TestSqliteBackend(t *testing.T) {
+	mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+	s := &SqlTestSuite{
+		Config: driver.DriverConfig{
+			Timeout: 5 * time.Second,
+		},
+	}
+
+	s.createServer = func() (flight.Server, string, error) {
+		server := flight.NewServerWithMiddleware(nil)
+
+		// Setup the SQLite backend
+		db, err := sql.Open("sqlite", ":memory:")
+		if err != nil {
+			return nil, "", err
+		}
+		sqliteServer, err := example.NewSQLiteFlightSQLServer(db)
+		if err != nil {
+			// Best-effort cleanup; the construction error is what matters.
+			_ = db.Close()
+			return nil, "", err
+		}
+		sqliteServer.Alloc = mem
+
+		// Connect the FlightSQL frontend to the backend
+		server.RegisterFlightService(flightsql.NewFlightServer(sqliteServer))
+		if err := server.Init("localhost:0"); err != nil {
+			// Best-effort cleanup; the init error is what matters.
+			_ = db.Close()
+			return nil, "", err
+		}
+		// os.Kill (SIGKILL) cannot be caught, so only os.Interrupt is registered.
+		server.SetShutdownOnSignals(os.Interrupt)
+		return server, server.Addr().String(), nil
+	}
+	s.startServer = func(server flight.Server) error { return server.Serve() }
+	s.stopServer = func(server flight.Server) { server.Shutdown() }
+
+	suite.Run(t, s)
+}
+
+// TestPreparedStatementSchema checks the driver's behaviour when the server
+// advertises a parameter schema (a single string) for a prepared statement.
+func TestPreparedStatementSchema(t *testing.T) {
+	// Setup the expected test
+	backend := &MockServer{
+		PreparedStatementParameterSchema: arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}, Nullable: false}}, nil),
+		DataSchema: arrow.NewSchema([]arrow.Field{
+			{Name: "time", Type: &arrow.Time64Type{Unit: arrow.Nanosecond}, Nullable: true},
+			{Name: "value", Type: &arrow.Int64Type{}, Nullable: false},
+		}, nil),
+		Data: "[]",
+	}
+
+	// Instantiate a mock server
+	server := flight.NewServerWithMiddleware(nil)
+	server.RegisterFlightService(flightsql.NewFlightServer(backend))
+	require.NoError(t, server.Init("localhost:0"))
+	// os.Kill (SIGKILL) cannot be caught, so only os.Interrupt is registered.
+	server.SetShutdownOnSignals(os.Interrupt)
+	go server.Serve()
+	defer server.Shutdown()
+
+	// Configure client
+	cfg := driver.DriverConfig{
+		Timeout: 5 * time.Second,
+		Address: server.Addr().String(),
+	}
+	db, err := sql.Open("flightsql", cfg.DSN())
+	require.NoError(t, err)
+	defer db.Close()
+
+	// Do query
+	stmt, err := db.Prepare("SELECT * FROM foo WHERE name LIKE ?")
+	require.NoError(t, err)
+	defer stmt.Close()
+
+	// With a parameter schema the expected argument count is known and enforced
+	_, err = stmt.Query()
+	require.ErrorContains(t, err, "expected 1 arguments, got 0")
+
+	// Test for errors issued by the driver
+	_, err = stmt.Query(23)
+	require.ErrorContains(t, err, "invalid value type int64 for builder *array.StringBuilder")
+
+	rows, err := stmt.Query("master")
+	require.NoError(t, err)
+	require.NotNil(t, rows)
+	require.NoError(t, rows.Close())
+}
+
+// TestPreparedStatementNoSchema checks the driver's behaviour when the server
+// does not advertise a parameter schema but still validates the parameters
+// it receives against ExpectedPreparedStatementSchema.
+func TestPreparedStatementNoSchema(t *testing.T) {
+	// Setup the expected test
+	backend := &MockServer{
+		DataSchema: arrow.NewSchema([]arrow.Field{
+			{Name: "time", Type: &arrow.Time64Type{Unit: arrow.Nanosecond}, Nullable: true},
+			{Name: "value", Type: &arrow.Int64Type{}, Nullable: false},
+		}, nil),
+		Data:                            "[]",
+		ExpectedPreparedStatementSchema: arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}, Nullable: false}}, nil),
+	}
+
+	// Instantiate a mock server
+	server := flight.NewServerWithMiddleware(nil)
+	server.RegisterFlightService(flightsql.NewFlightServer(backend))
+	require.NoError(t, server.Init("localhost:0"))
+	// os.Kill (SIGKILL) cannot be caught, so only os.Interrupt is registered.
+	server.SetShutdownOnSignals(os.Interrupt)
+	go server.Serve()
+	defer server.Shutdown()
+
+	// Configure client
+	cfg := driver.DriverConfig{
+		Timeout: 5 * time.Second,
+		Address: server.Addr().String(),
+	}
+	db, err := sql.Open("flightsql", cfg.DSN())
+	require.NoError(t, err)
+	defer db.Close()
+
+	// Do query
+	stmt, err := db.Prepare("SELECT * FROM foo WHERE name LIKE ?")
+	require.NoError(t, err)
+	defer stmt.Close()
+
+	// Without an advertised parameter schema the argument count is not
+	// enforced client-side, so a call without arguments succeeds.
+	rows, err := stmt.Query()
+	require.NoError(t, err, "query without arguments must succeed when no parameter schema is advertised")
+	require.NoError(t, rows.Close())
+
+	// Test for error issued by the server due to a mismatching parameter schema
+	_, err = stmt.Query(23)
+	require.ErrorContains(t, err, "parameter schema: unexpected")
+
+	rows, err = stmt.Query("master")
+	require.NoError(t, err)
+	require.NotNil(t, rows)
+	require.NoError(t, rows.Close())
+}
+
+// Mockup database server
+//
+// MockServer is a minimal FlightSQL backend for the driver tests. It hands
+// out a single prepared statement whose behaviour is controlled by the
+// fields below.
+type MockServer struct {
+	flightsql.BaseServer
+	// DataSchema is the result-set schema reported on statement creation and
+	// used to decode Data.
+	DataSchema *arrow.Schema
+	// PreparedStatementParameterSchema is the parameter schema reported to
+	// the client on statement creation; when set, parameters sent via DoPut
+	// must match it.
+	PreparedStatementParameterSchema *arrow.Schema
+	// PreparedStatementError, if non-empty, makes statement creation fail
+	// with this message.
+	PreparedStatementError string
+	// Data is the JSON-encoded record returned for every query.
+	Data string
+
+	// ExpectedPreparedStatementSchema, if set, is the parameter schema the
+	// server requires on DoPut without advertising it to the client. It is
+	// checked instead of PreparedStatementParameterSchema.
+	ExpectedPreparedStatementSchema *arrow.Schema
+}
+
+func (s *MockServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (flightsql.ActionCreatePreparedStatementResult, error) {
+ if s.PreparedStatementError != "" {
+ return flightsql.ActionCreatePreparedStatementResult{}, errors.New(s.PreparedStatementError)
+ }
+ return flightsql.ActionCreatePreparedStatementResult{
+ Handle: []byte("prepared"),
+ DatasetSchema: s.DataSchema,
+ ParameterSchema: s.PreparedStatementParameterSchema,
+ }, nil
+}
+
+// DoPutPreparedStatementQuery validates the schema of the bound parameters.
+// ExpectedPreparedStatementSchema, when set, takes precedence. Otherwise the
+// advertised PreparedStatementParameterSchema is checked. With neither set,
+// anything is accepted.
+func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flightsql.PreparedStatementQuery, r flight.MessageReader, w flight.MetadataWriter) error {
+	switch {
+	case s.ExpectedPreparedStatementSchema != nil:
+		if !s.ExpectedPreparedStatementSchema.Equal(r.Schema()) {
+			return errors.New("parameter schema: unexpected")
+		}
+	case s.PreparedStatementParameterSchema != nil:
+		if !s.PreparedStatementParameterSchema.Equal(r.Schema()) {
+			return fmt.Errorf("parameter schema: %w", arrow.ErrInvalid)
+		}
+	}
+	return nil
+}
+
+// DoGetStatement decodes Data (JSON) into a single record of DataSchema and
+// streams it back as the only chunk, regardless of the ticket.
+func (s *MockServer) DoGetStatement(ctx context.Context, ticket flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+	record, _, err := array.RecordFromJSON(memory.DefaultAllocator, s.DataSchema, strings.NewReader(s.Data))
+	if err != nil {
+		return nil, nil, err
+	}
+	// Buffer the single chunk so the send never blocks. The previous
+	// goroutine-plus-unbuffered-channel version leaked its goroutine if the
+	// stream was abandoned before the chunk was received.
+	chunk := make(chan flight.StreamChunk, 1)
+	chunk <- flight.StreamChunk{Data: record}
+	close(chunk)
+	return s.DataSchema, chunk, nil
+}
+
+// GetFlightInfoPreparedStatement answers with a single endpoint whose ticket
+// wraps the prepared-statement handle. Record and byte totals are reported as
+// -1, which Flight uses for "unknown".
+func (s *MockServer) GetFlightInfoPreparedStatement(ctx context.Context, stmt flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
+	ticket, err := flightsql.CreateStatementQueryTicket(stmt.GetPreparedStatementHandle())
+	if err != nil {
+		return nil, err
+	}
+
+	endpoint := &flight.FlightEndpoint{Ticket: &flight.Ticket{Ticket: ticket}}
+	info := &flight.FlightInfo{
+		FlightDescriptor: desc,
+		Endpoint:         []*flight.FlightEndpoint{endpoint},
+		TotalRecords:     -1,
+		TotalBytes:       -1,
+	}
+	return info, nil
+}
diff --git a/go/arrow/flight/flightsql/driver/utils.go b/go/arrow/flight/flightsql/driver/utils.go
new file mode 100644
index 0000000000..f8f1a0e86a
--- /dev/null
+++ b/go/arrow/flight/flightsql/driver/utils.go
@@ -0,0 +1,272 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package driver
+
+import (
+ "context"
+ "encoding/base64"
+ "fmt"
+ "time"
+
+ "github.com/apache/arrow/go/v12/arrow"
+ "github.com/apache/arrow/go/v12/arrow/array"
+)
+
+// *** GRPC helpers ***
+// grpcCredentials supplies per-call request metadata (authentication and
+// extra headers). Its method set matches gRPC's PerRPCCredentials interface.
+type grpcCredentials struct {
+	username string            // user for HTTP basic auth; ignored when token is set
+	password string            // password paired with username for basic auth
+	token    string            // bearer token; takes precedence over username/password
+	params   map[string]string // extra metadata entries; applied last, so they can override "authorization"
+}
+
+func (g grpcCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
+ md := make(map[string]string, len(g.params)+1)
+
+ // Authentication parameters
+ switch {
+ case g.token != "":
+ md["authorization"] = "Bearer " + g.token
+ case g.username != "":
+
+ md["authorization"] = "Basic " + base64.StdEncoding.EncodeToString([]byte(g.username+":"+g.password))
+ }
+
+ for k, v := range g.params {
+ md[k] = v
+ }
+
+ return md, nil
+}
+
+// RequireTransportSecurity reports whether the credentials may only travel
+// over a secure connection. This is the case as soon as any authentication
+// material (token or username) is configured.
+func (g grpcCredentials) RequireTransportSecurity() bool {
+	noAuth := g.token == "" && g.username == ""
+	return !noAuth
+}
+
+// *** Type conversions ***
+// fromArrowType converts the element at position idx of arr into one of the
+// value types database/sql drivers may return: bool, int64, float64, string,
+// time.Time, or nil.
+//
+// Conversion rules:
+//   - Null elements become nil, so they surface as SQL NULL rather than as
+//     the type's zero value.
+//   - Narrower integer and float types are widened to int64 and float64.
+//   - Time types are converted with ToTime using the array's time unit.
+//   - Unsupported array types yield an error wrapping ErrNotSupported.
+func fromArrowType(arr arrow.Array, idx int) (interface{}, error) {
+	if arr.IsNull(idx) {
+		return nil, nil
+	}
+
+	switch c := arr.(type) {
+	case *array.Boolean:
+		return c.Value(idx), nil
+	case *array.Float16:
+		return float64(c.Value(idx).Float32()), nil
+	case *array.Float32:
+		return float64(c.Value(idx)), nil
+	case *array.Float64:
+		return c.Value(idx), nil
+	case *array.Int8:
+		return int64(c.Value(idx)), nil
+	case *array.Int16:
+		return int64(c.Value(idx)), nil
+	case *array.Int32:
+		return int64(c.Value(idx)), nil
+	case *array.Int64:
+		return c.Value(idx), nil
+	// Unsigned types up to 32 bits fit into int64 without loss. Uint64 is not
+	// handled because values above the int64 range cannot be represented.
+	case *array.Uint8:
+		return int64(c.Value(idx)), nil
+	case *array.Uint16:
+		return int64(c.Value(idx)), nil
+	case *array.Uint32:
+		return int64(c.Value(idx)), nil
+	case *array.String:
+		return c.Value(idx), nil
+	case *array.Time32:
+		dt, ok := arr.DataType().(*arrow.Time32Type)
+		if !ok {
+			return nil, fmt.Errorf("datatype %T not matching time32", arr.DataType())
+		}
+		return c.Value(idx).ToTime(dt.TimeUnit()), nil
+	case *array.Time64:
+		dt, ok := arr.DataType().(*arrow.Time64Type)
+		if !ok {
+			return nil, fmt.Errorf("datatype %T not matching time64", arr.DataType())
+		}
+		return c.Value(idx).ToTime(dt.TimeUnit()), nil
+	case *array.Timestamp:
+		dt, ok := arr.DataType().(*arrow.TimestampType)
+		if !ok {
+			return nil, fmt.Errorf("datatype %T not matching timestamp", arr.DataType())
+		}
+		return c.Value(idx).ToTime(dt.TimeUnit()), nil
+	}
+
+	return nil, fmt.Errorf("type %T: %w", arr, ErrNotSupported)
+}
+
+// toArrowDataType maps a Go value to the Arrow data type that represents it.
+// time.Time maps to a nanosecond Time64. Any other Go type not listed below
+// yields an error wrapping ErrNotSupported.
+func toArrowDataType(value interface{}) (arrow.DataType, error) {
+	var dt arrow.DataType
+	switch value.(type) {
+	case bool:
+		dt = &arrow.BooleanType{}
+	case float32:
+		dt = &arrow.Float32Type{}
+	case float64:
+		dt = &arrow.Float64Type{}
+	case int8:
+		dt = &arrow.Int8Type{}
+	case int16:
+		dt = &arrow.Int16Type{}
+	case int32:
+		dt = &arrow.Int32Type{}
+	case int64:
+		dt = &arrow.Int64Type{}
+	case uint8:
+		dt = &arrow.Uint8Type{}
+	case uint16:
+		dt = &arrow.Uint16Type{}
+	case uint32:
+		dt = &arrow.Uint32Type{}
+	case uint64:
+		dt = &arrow.Uint64Type{}
+	case string:
+		dt = &arrow.StringType{}
+	case time.Time:
+		dt = &arrow.Time64Type{Unit: arrow.Nanosecond}
+	default:
+		return nil, fmt.Errorf("type %T: %w", value, ErrNotSupported)
+	}
+	return dt, nil
+}
+
+// *** Field builder helpers ***
+// setFieldValue appends arg to builder.
+//
+// Accepted arguments:
+//   - a single value, or a slice of values, of the Go type matching the
+//     builder;
+//   - nil (SQL NULL), which is appended as a null entry.
+//
+// Time64 builders also accept raw int64/uint64 values in the column's unit
+// and time.Time values. A time.Time is stored as its wall-clock time of day
+// in its own location.
+//
+// A mismatching argument type, or an unsupported builder, yields an error.
+func setFieldValue(builder array.Builder, arg interface{}) error {
+	// A nil argument is SQL NULL, which every builder can represent.
+	if arg == nil {
+		builder.AppendNull()
+		return nil
+	}
+
+	switch b := builder.(type) {
+	case *array.BooleanBuilder:
+		switch v := arg.(type) {
+		case bool:
+			b.Append(v)
+		case []bool:
+			b.AppendValues(v, nil)
+		default:
+			return fmt.Errorf("invalid value type %T for builder %T", arg, builder)
+		}
+	case *array.Float32Builder:
+		switch v := arg.(type) {
+		case float32:
+			b.Append(v)
+		case []float32:
+			b.AppendValues(v, nil)
+		default:
+			return fmt.Errorf("invalid value type %T for builder %T", arg, builder)
+		}
+	case *array.Float64Builder:
+		switch v := arg.(type) {
+		case float64:
+			b.Append(v)
+		case []float64:
+			b.AppendValues(v, nil)
+		default:
+			return fmt.Errorf("invalid value type %T for builder %T", arg, builder)
+		}
+	case *array.Int8Builder:
+		switch v := arg.(type) {
+		case int8:
+			b.Append(v)
+		case []int8:
+			b.AppendValues(v, nil)
+		default:
+			return fmt.Errorf("invalid value type %T for builder %T", arg, builder)
+		}
+	case *array.Int16Builder:
+		switch v := arg.(type) {
+		case int16:
+			b.Append(v)
+		case []int16:
+			b.AppendValues(v, nil)
+		default:
+			return fmt.Errorf("invalid value type %T for builder %T", arg, builder)
+		}
+	case *array.Int32Builder:
+		switch v := arg.(type) {
+		case int32:
+			b.Append(v)
+		case []int32:
+			b.AppendValues(v, nil)
+		default:
+			return fmt.Errorf("invalid value type %T for builder %T", arg, builder)
+		}
+	case *array.Int64Builder:
+		switch v := arg.(type) {
+		case int64:
+			b.Append(v)
+		case []int64:
+			b.AppendValues(v, nil)
+		default:
+			return fmt.Errorf("invalid value type %T for builder %T", arg, builder)
+		}
+	case *array.Uint8Builder:
+		switch v := arg.(type) {
+		case uint8:
+			b.Append(v)
+		case []uint8:
+			b.AppendValues(v, nil)
+		default:
+			return fmt.Errorf("invalid value type %T for builder %T", arg, builder)
+		}
+	case *array.Uint16Builder:
+		switch v := arg.(type) {
+		case uint16:
+			b.Append(v)
+		case []uint16:
+			b.AppendValues(v, nil)
+		default:
+			return fmt.Errorf("invalid value type %T for builder %T", arg, builder)
+		}
+	case *array.Uint32Builder:
+		switch v := arg.(type) {
+		case uint32:
+			b.Append(v)
+		case []uint32:
+			b.AppendValues(v, nil)
+		default:
+			return fmt.Errorf("invalid value type %T for builder %T", arg, builder)
+		}
+	case *array.Uint64Builder:
+		switch v := arg.(type) {
+		case uint64:
+			b.Append(v)
+		case []uint64:
+			b.AppendValues(v, nil)
+		default:
+			return fmt.Errorf("invalid value type %T for builder %T", arg, builder)
+		}
+	case *array.StringBuilder:
+		switch v := arg.(type) {
+		case string:
+			b.Append(v)
+		case []string:
+			b.AppendValues(v, nil)
+		default:
+			return fmt.Errorf("invalid value type %T for builder %T", arg, builder)
+		}
+	case *array.Time64Builder:
+		switch v := arg.(type) {
+		case int64:
+			b.Append(arrow.Time64(v))
+		case []int64:
+			for _, x := range v {
+				b.Append(arrow.Time64(x))
+			}
+		case uint64:
+			b.Append(arrow.Time64(v))
+		case []uint64:
+			for _, x := range v {
+				b.Append(arrow.Time64(x))
+			}
+		case time.Time:
+			// Time64 stores the elapsed time since midnight in the column's
+			// unit. v.Nanosecond() alone would only be the sub-second part.
+			hour, minute, second := v.Clock()
+			sinceMidnight := time.Duration(hour)*time.Hour +
+				time.Duration(minute)*time.Minute +
+				time.Duration(second)*time.Second +
+				time.Duration(v.Nanosecond())
+			unit := arrow.Nanosecond
+			if dt, ok := b.Type().(*arrow.Time64Type); ok {
+				unit = dt.Unit
+			}
+			b.Append(arrow.Time64(sinceMidnight / unit.Multiplier()))
+		default:
+			return fmt.Errorf("invalid value type %T for builder %T", arg, builder)
+		}
+	default:
+		return fmt.Errorf("unknown builder type %T", builder)
+	}
+	return nil
+}