Posted to commits@arrow.apache.org by ze...@apache.org on 2023/01/24 21:01:15 UTC
[arrow-adbc] branch main updated: feat(go/adbc/driver/flightsql): implement more connection options (#381)
This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new bc29a37 feat(go/adbc/driver/flightsql): implement more connection options (#381)
bc29a37 is described below
commit bc29a377d225e75944b20b80e104095cd05c352e
Author: David Li <li...@gmail.com>
AuthorDate: Tue Jan 24 16:01:09 2023 -0500
feat(go/adbc/driver/flightsql): implement more connection options (#381)
* feat(go/adbc/driver/flightsql): implement more connection options
Part of #380.
* Use switch
* Document user/pass
* Update docs/source/driver/go/flight_sql.rst
Co-authored-by: Matt Topol <zo...@gmail.com>
---
docs/source/driver/go/flight_sql.rst | 117 ++++++++--------------
go/adbc/driver/flightsql/flightsql_adbc.go | 99 +++++++++++++-----
go/adbc/driver/flightsql/flightsql_adbc_test.go | 127 ++++++++++++++++++++++++
go/adbc/driver/flightsql/flightsql_statement.go | 4 +-
4 files changed, 245 insertions(+), 102 deletions(-)
diff --git a/docs/source/driver/go/flight_sql.rst b/docs/source/driver/go/flight_sql.rst
index 76c543e..1768708 100644
--- a/docs/source/driver/go/flight_sql.rst
+++ b/docs/source/driver/go/flight_sql.rst
@@ -32,6 +32,10 @@ The Flight SQL driver is shipped as a standalone library.
.. tab-item:: Go
:sync: go
+ .. code-block:: shell
+
+ go get github.com/apache/arrow-adbc/go
+
.. tab-item:: Python
:sync: python
@@ -57,6 +61,7 @@ the :cpp:class:`AdbcDatabase`.
// Ignoring error handling
struct AdbcDatabase database;
AdbcDatabaseNew(&database, nullptr);
+ AdbcDatabaseSetOption(&database, "driver", "adbc_driver_flightsql", nullptr);
AdbcDatabaseSetOption(&database, "uri", "grpc://localhost:8080", nullptr);
AdbcDatabaseInit(&database, nullptr);
@@ -65,10 +70,9 @@ the :cpp:class:`AdbcDatabase`.
.. code-block:: python
- import pyarrow.flight_sql
-
+ import adbc_driver_flightsql.dbapi
- with pyarrow.flight_sql.connect("grpc://localhost:8080") as conn:
+ with adbc_driver_flightsql.dbapi.connect("grpc://localhost:8080") as conn:
pass
Supported Features
@@ -80,33 +84,28 @@ API specification 1.0.0, as well as some additional, custom options.
Authentication
--------------
-The driver does no authentication by default.
+The driver does no authentication by default. The driver implements a
+few optional authentication schemes:
-The driver implements one optional authentication scheme that mimics
-the Arrow Flight SQL JDBC driver. This can be enabled by setting the
-option ``arrow.flight.sql.authorization_header`` on the
-:cpp:class:`AdbcDatabase`. The client provides credentials by setting
-the option value to the value of the ``authorization`` header sent
-from client to server. The server then responds with an
-``authorization`` header on the first request. The value of this
-header will then be sent back as the ``authorization`` header on all
-future requests.
+- Mutual TLS (mTLS): see "Client Options" below.
+- An HTTP-style scheme mimicking the Arrow Flight SQL JDBC driver.
+
+ Set the options ``username`` and ``password`` on the
+ :cpp:class:`AdbcDatabase`. Alternatively, set the option
+ ``arrow.flight.sql.authorization_header`` for full control.
+
+ The client provides credentials by sending an ``authorization`` header from
+ client to server. The server then responds with an
+ ``authorization`` header on the first request. The value of this
+ header will then be sent back as the ``authorization`` header on all
+ future requests.
Bulk Ingestion
--------------
Flight SQL does not have a dedicated API for bulk ingestion of Arrow
-data into a given table. The driver instead constructs SQL statements
-to create and insert into the table.
-
-.. warning:: The driver does not escape or validate the names of
- tables or columns. As a precaution, it instead limits
- identifier names to letters, numbers, and underscores.
- Bulk ingestion should not be used with untrusted user
- input.
-
-The driver binds a batch of data at a time for efficiency. Also, the
-generated SQL statements hardcode ``?`` as the parameter identifier.
+data into a given table. The driver does not currently implement bulk
+ingestion as a result.
Client Options
--------------
@@ -114,18 +113,22 @@ Client Options
The options used for creating the Flight RPC client can be customized.
These options map 1:1 with the options in FlightClientOptions:
-``arrow.flight.sql.client_option.tls_root_certs``
- Override the root certificates used to validate the server's TLS
- certificate.
+``arrow.flight.sql.client_option.mtls_cert_chain``
+ The certificate chain to use for mTLS.
+
+``arrow.flight.sql.client_option.mtls_private_key``
+ The private key to use for mTLS.
-``arrow.flight.sql.client_option.override_hostname``
+``arrow.flight.sql.client_option.tls_override_hostname``
Override the hostname used to verify the server's TLS certificate.
-``arrow.flight.sql.client_option.cert_chain``
- The certificate chain to use for mTLS.
+``arrow.flight.sql.client_option.tls_skip_verify``
+ Disable verification of the server's TLS certificate. Value
+ should be ``true`` or ``false``.
-``arrow.flight.sql.client_option.private_key``
- The private key to use for mTLS.
+``arrow.flight.sql.client_option.tls_root_certs``
+ Override the root certificates used to validate the server's TLS
+ certificate.
``arrow.flight.sql.client_option.generic_int_option.<OPTION_NAME>``
Option prefixes used to specify generic transport-layer options.
@@ -133,10 +136,6 @@ These options map 1:1 with the options in FlightClientOptions:
``arrow.flight.sql.client_option.generic_string_option.<OPTION_NAME>``
Option prefixes used to specify generic transport-layer options.
-``arrow.flight.sql.client_option.disable_server_verification``
- Disable verification of the server's TLS certificate. Value
- should be ``true`` or ``false``.
-
Custom Call Headers
-------------------
@@ -221,47 +220,9 @@ The options are as follows:
Transactions
------------
-The driver will issue transaction RPCs, but the driver will not check
-the server's SqlInfo to determine whether this is supported first.
-
-Type Mapping
-------------
-
-When executing a bulk ingestion operation, the driver needs to be able
-to construct appropriate SQL queries for the database. (The driver
-does not currently support using Substrait plans instead.) In
-particular, a mapping from Arrow types to SQL type names is required.
-While a default mapping is provided, the client may wish to override
-this mapping, which can be done by setting special options on
-:cpp:class:`AdbcDatabase`. (The driver does not currently inspect
-Flight SQL metadata to construct this mapping.)
-
-All such options begin with ``arrow.flight.sql.quirks.ingest_type.``
-and are followed by a type name below.
-
-.. warning:: The driver does **not** escape or validate the values
- here. They should not come from untrusted user input, or
- a SQL injection vulnerability may result.
-
-.. csv-table:: Type Names
- :header: "Arrow Type Name", "Default SQL Type Name"
-
- binary,BLOB
- bool,BOOLEAN
- date32,DATE
- date64,DATE
- decimal128,NUMERIC
- decimal256,NUMERIC
- double,DOUBLE PRECISION
- float,REAL
- int16,SMALLINT
- int32,INT
- int64,BIGINT
- large_binary,BLOB
- large_string,TEXT
- string,TEXT
- time32,TIME
- time64,TIME
- timestamp,TIMESTAMP
+The driver supports transactions. It will first check the server's
+SqlInfo to determine whether this is supported. Otherwise,
+transaction-related ADBC APIs will return
+:c:type:`ADBC_STATUS_NOT_IMPLEMENTED`.
.. _DBAPI 2.0: https://peps.python.org/pep-0249/
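The HTTP-style scheme described in the authentication section above starts from the ``username`` and ``password`` options and then replays whatever ``authorization`` header the server returns. The following is a minimal sketch of that handshake. It assumes the common Flight convention of HTTP Basic credentials (``Basic`` followed by base64 of ``user:pass``) and a server that typically answers with a bearer token. ``basicAuthHeader`` and ``session`` are hypothetical names for illustration, not part of the driver.

```go
package main

import (
	"encoding/base64"
	"fmt"
)

// basicAuthHeader is a hypothetical helper (assumption: HTTP Basic encoding)
// that builds the initial "authorization" header value from the
// username/password options.
func basicAuthHeader(user, pass string) string {
	return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
}

// session is a hypothetical holder for the header sent on each request.
type session struct {
	authorization string
}

// onFirstResponse records the server's "authorization" header, if any, so
// that later requests send it back instead of the original credentials.
func (s *session) onFirstResponse(serverHeader string) {
	if serverHeader != "" {
		s.authorization = serverHeader
	}
}

func main() {
	s := &session{authorization: basicAuthHeader("user", "pass")}
	fmt.Println(s.authorization) // Basic dXNlcjpwYXNz
	s.onFirstResponse("Bearer abc123")
	fmt.Println(s.authorization) // Bearer abc123
}
```

Setting ``arrow.flight.sql.authorization_header`` directly, as the documentation allows, corresponds to skipping ``basicAuthHeader`` and seeding ``session`` with a caller-supplied value.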
diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go
index b959f32..6c46e9d 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -34,6 +34,7 @@ package flightsql
import (
"context"
"crypto/tls"
+ "crypto/x509"
"fmt"
"io"
"net/url"
@@ -57,8 +58,11 @@ import (
)
const (
- OptionSSLSkipVerify = "adbc.flight.sql.client_option.tls_skip_verify"
- OptionSSLCertFile = "adbc.flight.sql.client_option.tls_root_certs"
+ OptionMTLSCertChain = "adbc.flight.sql.client_option.mtls_cert_chain"
+ OptionMTLSPrivateKey = "adbc.flight.sql.client_option.mtls_private_key"
+ OptionSSLOverrideHostname = "adbc.flight.sql.client_option.tls_override_hostname"
+ OptionSSLSkipVerify = "adbc.flight.sql.client_option.tls_skip_verify"
+ OptionSSLRootCerts = "adbc.flight.sql.client_option.tls_root_certs"
infoDriverName = "ADBC Flight SQL Driver - Go"
)
@@ -86,6 +90,7 @@ type Driver struct {
}
func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
+ opts = maps.Clone(opts)
uri, ok := opts[adbc.OptionKeyURI]
if !ok {
return nil, adbc.Error{
@@ -93,6 +98,7 @@ func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
Code: adbc.StatusInvalidArgument,
}
}
+ delete(opts, adbc.OptionKeyURI)
db := &database{alloc: d.Alloc}
if db.alloc == nil {
@@ -116,47 +122,82 @@ type database struct {
}
func (d *database) SetOptions(cnOptions map[string]string) error {
- if d.uri.Scheme == "grpc+tls" {
- d.creds = credentials.NewTLS(&tls.Config{})
- } else {
- d.creds = insecure.NewCredentials()
- }
+ var tlsConfig tls.Config
- if val, ok := cnOptions[OptionSSLSkipVerify]; ok && val == adbc.OptionValueEnabled {
- if d.uri.Scheme != "grpc+tls" {
+ mtlsCert := cnOptions[OptionMTLSCertChain]
+ mtlsKey := cnOptions[OptionMTLSPrivateKey]
+ switch {
+ case mtlsCert != "" && mtlsKey != "":
+ cert, err := tls.X509KeyPair([]byte(mtlsCert), []byte(mtlsKey))
+ if err != nil {
return adbc.Error{
- Msg: "Connection is not TLS-enabled",
+ Msg: fmt.Sprintf("Invalid mTLS certificate: %#v", err),
Code: adbc.StatusInvalidArgument,
}
}
- d.creds = credentials.NewTLS(&tls.Config{InsecureSkipVerify: true})
+ tlsConfig.Certificates = []tls.Certificate{cert}
+ delete(cnOptions, OptionMTLSCertChain)
+ delete(cnOptions, OptionMTLSPrivateKey)
+ case mtlsCert != "":
+ return adbc.Error{
+ Msg: fmt.Sprintf("Must provide both '%s' and '%s', only provided '%s'", OptionMTLSCertChain, OptionMTLSPrivateKey, OptionMTLSCertChain),
+ Code: adbc.StatusInvalidArgument,
+ }
+ case mtlsKey != "":
+ return adbc.Error{
+ Msg: fmt.Sprintf("Must provide both '%s' and '%s', only provided '%s'", OptionMTLSCertChain, OptionMTLSPrivateKey, OptionMTLSPrivateKey),
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+
+ if hostname, ok := cnOptions[OptionSSLOverrideHostname]; ok {
+ tlsConfig.ServerName = hostname
+ delete(cnOptions, OptionSSLOverrideHostname)
}
- // option specified path to certificate file
- if cert, ok := cnOptions[OptionSSLCertFile]; ok {
- if d.uri.Scheme != "grpc+tls" {
+ if val, ok := cnOptions[OptionSSLSkipVerify]; ok {
+ if val == adbc.OptionValueEnabled {
+ tlsConfig.InsecureSkipVerify = true
+ } else if val == adbc.OptionValueDisabled {
+ tlsConfig.InsecureSkipVerify = false
+ } else {
return adbc.Error{
- Msg: "Connection is not TLS-enabled",
+ Msg: fmt.Sprintf("Invalid value for database option '%s': '%s'", OptionSSLSkipVerify, val),
Code: adbc.StatusInvalidArgument,
}
}
+ delete(cnOptions, OptionSSLSkipVerify)
+ }
- c, err := credentials.NewClientTLSFromFile(cert, "")
- if err != nil {
+ if cert, ok := cnOptions[OptionSSLRootCerts]; ok {
+ cp := x509.NewCertPool()
+ if !cp.AppendCertsFromPEM([]byte(cert)) {
return adbc.Error{
- Msg: "invalid SSL certificate passed",
+ Msg: fmt.Sprintf("Invalid value for database option '%s': failed to append certificates", OptionSSLRootCerts),
Code: adbc.StatusInvalidArgument,
}
}
- d.creds = c
+ tlsConfig.RootCAs = cp
+ delete(cnOptions, OptionSSLRootCerts)
}
+ d.creds = credentials.NewTLS(&tlsConfig)
+
if u, ok := cnOptions[adbc.OptionKeyUsername]; ok {
d.user = u
+ delete(cnOptions, adbc.OptionKeyUsername)
}
if p, ok := cnOptions[adbc.OptionKeyPassword]; ok {
d.pass = p
+ delete(cnOptions, adbc.OptionKeyPassword)
+ }
+
+ for key := range cnOptions {
+ return adbc.Error{
+ Msg: fmt.Sprintf("Unknown database option '%s'", key),
+ Code: adbc.StatusInvalidArgument,
+ }
}
return nil
@@ -181,9 +222,13 @@ func getFlightClient(ctx context.Context, loc string, d *database) (*flightsql.C
if err != nil {
return nil, adbc.Error{Msg: fmt.Sprintf("Invalid URI '%s': %s", loc, err), Code: adbc.StatusInvalidArgument}
}
+ creds := d.creds
+ if uri.Scheme == "grpc" || uri.Scheme == "grpc+tcp" {
+ creds = insecure.NewCredentials()
+ }
cl, err := flightsql.NewClient(uri.Host, nil, []flight.ClientMiddleware{
- flight.CreateClientMiddleware(authMiddle)}, grpc.WithTransportCredentials(d.creds))
+ flight.CreateClientMiddleware(authMiddle)}, grpc.WithTransportCredentials(creds))
if err != nil {
return nil, adbc.Error{
Msg: err.Error(),
@@ -611,15 +656,23 @@ func (c *cnxn) Close() error {
//
// A partition can be retrieved by using ExecutePartitions on a statement.
func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte) (rdr array.RecordReader, err error) {
- var endpoint flight.FlightEndpoint
- if err := proto.Unmarshal(serializedPartition, &endpoint); err != nil {
+ var info flight.FlightInfo
+ if err := proto.Unmarshal(serializedPartition, &info); err != nil {
return nil, adbc.Error{
Msg: err.Error(),
Code: adbc.StatusInvalidArgument,
}
}
- rdr, err = doGet(ctx, c.cl, &endpoint, c.clientCache)
+ // The driver only ever returns one endpoint.
+ if len(info.Endpoint) != 1 {
+ return nil, adbc.Error{
+ Msg: fmt.Sprintf("Invalid partition: expected 1 endpoint, got %d", len(info.Endpoint)),
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+
+ rdr, err = doGet(ctx, c.cl, info.Endpoint[0], c.clientCache)
if err != nil {
return nil, adbcFromFlightStatus(err)
}
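The reworked ``SetOptions`` consumes each recognized option from a copy of the map (hence the ``maps.Clone`` in ``NewDatabase``), folds all TLS settings into one ``tls.Config``, and rejects any key left over. ``getFlightClient`` then discards that config for ``grpc://`` and ``grpc+tcp://`` URIs. The stdlib-only sketch below restates this flow under simplifying assumptions. ``buildTLSConfig`` and ``useInsecure`` are hypothetical names. The URI, username, and password keys are not handled, so this sketch would report them as unknown. The error strings are paraphrased from the driver's messages. The key strings are the ``adbc.``-prefixed constants from this file; the documentation in this commit lists the same options under an ``arrow.`` prefix.

```go
package main

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"net/url"
	"sort"
)

// Option keys, copied from the constants in flightsql_adbc.go.
const (
	optMTLSCertChain  = "adbc.flight.sql.client_option.mtls_cert_chain"
	optMTLSPrivateKey = "adbc.flight.sql.client_option.mtls_private_key"
	optOverrideHost   = "adbc.flight.sql.client_option.tls_override_hostname"
	optSkipVerify     = "adbc.flight.sql.client_option.tls_skip_verify"
	optRootCerts      = "adbc.flight.sql.client_option.tls_root_certs"
)

// buildTLSConfig is a hypothetical, simplified restatement of SetOptions.
func buildTLSConfig(opts map[string]string) (*tls.Config, error) {
	// Copy first so the caller's map is untouched.
	rest := make(map[string]string, len(opts))
	for k, v := range opts {
		rest[k] = v
	}
	cfg := &tls.Config{}

	// mTLS needs both the certificate chain and the private key, or neither.
	cert, key := rest[optMTLSCertChain], rest[optMTLSPrivateKey]
	switch {
	case cert != "" && key != "":
		pair, err := tls.X509KeyPair([]byte(cert), []byte(key))
		if err != nil {
			return nil, fmt.Errorf("invalid mTLS certificate: %w", err)
		}
		cfg.Certificates = []tls.Certificate{pair}
	case cert != "" || key != "":
		return nil, fmt.Errorf("must provide both %q and %q", optMTLSCertChain, optMTLSPrivateKey)
	}
	delete(rest, optMTLSCertChain)
	delete(rest, optMTLSPrivateKey)

	if host, ok := rest[optOverrideHost]; ok {
		cfg.ServerName = host // hostname used to verify the server certificate
		delete(rest, optOverrideHost)
	}

	if v, ok := rest[optSkipVerify]; ok {
		switch v {
		case "true":
			cfg.InsecureSkipVerify = true
		case "false":
			cfg.InsecureSkipVerify = false
		default:
			return nil, fmt.Errorf("invalid value for %q: %q", optSkipVerify, v)
		}
		delete(rest, optSkipVerify)
	}

	// Root certificates are passed as PEM text, not as a file path.
	if pem, ok := rest[optRootCerts]; ok {
		pool := x509.NewCertPool()
		if !pool.AppendCertsFromPEM([]byte(pem)) {
			return nil, fmt.Errorf("invalid value for %q: failed to append certificates", optRootCerts)
		}
		cfg.RootCAs = pool
		delete(rest, optRootCerts)
	}

	// Reject leftovers; sort so the reported key is deterministic.
	if len(rest) > 0 {
		keys := make([]string, 0, len(rest))
		for k := range rest {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		return nil, fmt.Errorf("unknown database option %q", keys[0])
	}
	return cfg, nil
}

// useInsecure mirrors getFlightClient: plaintext for grpc:// and grpc+tcp://,
// the TLS config for every other scheme (for example grpc+tls://).
func useInsecure(rawURI string) (bool, error) {
	u, err := url.Parse(rawURI)
	if err != nil {
		return false, err
	}
	return u.Scheme == "grpc" || u.Scheme == "grpc+tcp", nil
}

func main() {
	cfg, err := buildTLSConfig(map[string]string{optSkipVerify: "true", optOverrideHost: "db.example"})
	if err != nil {
		panic(err)
	}
	fmt.Println(cfg.InsecureSkipVerify, cfg.ServerName) // true db.example
	insecure, _ := useInsecure("grpc+tls://localhost:8080")
	fmt.Println(insecure) // false
}
```

Collecting everything into a single ``tls.Config`` is what lets the options combine freely, for example mTLS together with custom root certificates, where the previous code replaced ``d.creds`` wholesale for each option.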
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index bdbfacc..2a860be 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -36,6 +36,7 @@ import (
"github.com/apache/arrow/go/v11/arrow/memory"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
+ "google.golang.org/protobuf/proto"
)
type FlightSQLQuirks struct {
@@ -203,4 +204,130 @@ func TestADBCFlightSQL(t *testing.T) {
suite.Run(t, &validation.DatabaseTests{Quirks: q})
suite.Run(t, &validation.ConnectionTests{Quirks: q})
suite.Run(t, &validation.StatementTests{Quirks: q})
+
+ suite.Run(t, &PartitionTests{Quirks: q})
+ suite.Run(t, &SSLTests{Quirks: q})
+}
+
+// Driver-specific tests
+
+type SSLTests struct {
+ suite.Suite
+
+ Driver adbc.Driver
+ Quirks validation.DriverQuirks
+}
+
+func (suite *SSLTests) SetupTest() {
+ suite.Driver = suite.Quirks.SetupDriver(suite.T())
+}
+
+func (suite *SSLTests) TearDownTest() {
+ suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
+ suite.Driver = nil
+}
+
+func (suite *SSLTests) TestMutualTLS() {
+ // Just checks that the option is accepted - doesn't actually configure TLS
+ options := suite.Quirks.DatabaseOptions()
+
+ options["adbc.flight.sql.client_option.mtls_cert_chain"] = "certs"
+ _, err := suite.Driver.NewDatabase(options)
+ suite.Require().ErrorContains(err, "Must provide both")
+
+ options["adbc.flight.sql.client_option.mtls_private_key"] = "key"
+ _, err = suite.Driver.NewDatabase(options)
+ suite.Require().ErrorContains(err, "Invalid mTLS certificate")
+
+ delete(options, "adbc.flight.sql.client_option.mtls_cert_chain")
+ _, err = suite.Driver.NewDatabase(options)
+ suite.Require().ErrorContains(err, "Must provide both")
+}
+
+func (suite *SSLTests) TestOverrideHostname() {
+ // Just checks that the option is accepted - doesn't actually configure TLS
+ options := suite.Quirks.DatabaseOptions()
+ options["adbc.flight.sql.client_option.tls_override_hostname"] = "hostname"
+ _, err := suite.Driver.NewDatabase(options)
+ suite.Require().NoError(err)
+}
+
+func (suite *SSLTests) TestRootCerts() {
+ // Just checks that the option is accepted - doesn't actually configure TLS
+ options := suite.Quirks.DatabaseOptions()
+ options["adbc.flight.sql.client_option.tls_root_certs"] = "these are not valid certs"
+ _, err := suite.Driver.NewDatabase(options)
+ suite.Require().ErrorContains(err, "Invalid value for database option 'adbc.flight.sql.client_option.tls_root_certs': failed to append certificates")
+}
+
+func (suite *SSLTests) TestSkipVerify() {
+ options := suite.Quirks.DatabaseOptions()
+ options["adbc.flight.sql.client_option.tls_skip_verify"] = "true"
+ _, err := suite.Driver.NewDatabase(options)
+ suite.Require().NoError(err)
+
+ options = suite.Quirks.DatabaseOptions()
+ options["adbc.flight.sql.client_option.tls_skip_verify"] = "false"
+ _, err = suite.Driver.NewDatabase(options)
+ suite.Require().NoError(err)
+
+ options = suite.Quirks.DatabaseOptions()
+ options["adbc.flight.sql.client_option.tls_skip_verify"] = "invalid"
+ _, err = suite.Driver.NewDatabase(options)
+ suite.Require().ErrorContains(err, "Invalid value for database option 'adbc.flight.sql.client_option.tls_skip_verify': 'invalid'")
+}
+
+func (suite *SSLTests) TestUnknownOption() {
+ options := suite.Quirks.DatabaseOptions()
+ options["unknown option"] = "unknown value"
+ _, err := suite.Driver.NewDatabase(options)
+ suite.Require().ErrorContains(err, "Unknown database option 'unknown option'")
+}
+
+type PartitionTests struct {
+ suite.Suite
+
+ Driver adbc.Driver
+ Quirks validation.DriverQuirks
+
+ DB adbc.Database
+ Cnxn adbc.Connection
+ ctx context.Context
+}
+
+func (suite *PartitionTests) SetupTest() {
+ suite.Driver = suite.Quirks.SetupDriver(suite.T())
+ var err error
+ suite.DB, err = suite.Driver.NewDatabase(suite.Quirks.DatabaseOptions())
+ suite.Require().NoError(err)
+ suite.ctx = context.Background()
+ suite.Cnxn, err = suite.DB.Open(suite.ctx)
+ suite.Require().NoError(err)
+}
+
+func (suite *PartitionTests) TearDownTest() {
+ suite.Require().NoError(suite.Cnxn.Close())
+ suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
+ suite.Cnxn = nil
+ suite.DB = nil
+ suite.Driver = nil
+}
+
+func (suite *PartitionTests) TestIntrospectPartitions() {
+ stmt, err := suite.Cnxn.NewStatement()
+ suite.Require().NoError(err)
+ defer stmt.Close()
+
+ suite.Require().NoError(stmt.SetSqlQuery("SELECT 42"))
+
+ _, partitions, _, err := stmt.ExecutePartitions(context.Background())
+ suite.Require().NoError(err)
+ suite.Require().Equal(uint64(1), partitions.NumPartitions)
+
+ info := &flight.FlightInfo{}
+ suite.Require().NoError(proto.Unmarshal(partitions.PartitionIDs[0], info))
+ suite.Require().Equal(int64(-1), info.TotalBytes)
+ suite.Require().Equal(int64(-1), info.TotalRecords)
+ suite.Require().Equal(1, len(info.Endpoint))
+ suite.Require().Equal(0, len(info.Endpoint[0].Location))
}
diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go
index 8d26df2..b251e68 100644
--- a/go/adbc/driver/flightsql/flightsql_statement.go
+++ b/go/adbc/driver/flightsql/flightsql_statement.go
@@ -279,7 +279,9 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
out.NumPartitions = uint64(len(info.Endpoint))
out.PartitionIDs = make([][]byte, out.NumPartitions)
for i, e := range info.Endpoint {
- data, err := proto.Marshal(e)
+ partition := proto.Clone(info).(*flight.FlightInfo)
+ partition.Endpoint = []*flight.FlightEndpoint{e}
+ data, err := proto.Marshal(partition)
if err != nil {
return sc, out, -1, adbc.Error{
Msg: err.Error(),
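With this change a partition ID is no longer a bare serialized ``FlightEndpoint``. ``ExecutePartitions`` clones the whole ``FlightInfo`` once per endpoint and keeps only that endpoint, so each ID also carries the schema and size metadata, and ``ReadPartition`` rejects any ID that does not decode to exactly one endpoint. The sketch below shows the same clone-and-narrow pattern with a deliberate swap: it uses hypothetical ``Info``/``Endpoint`` structs and ``encoding/json`` instead of the real ``flight`` protobuf types and ``proto.Marshal``, so it runs with the standard library alone.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Endpoint and Info are hypothetical stand-ins for flight.FlightEndpoint and
// flight.FlightInfo; JSON replaces protobuf purely for illustration.
type Endpoint struct {
	Ticket   string   `json:"ticket"`
	Location []string `json:"location,omitempty"`
}

type Info struct {
	Schema       string     `json:"schema"`
	TotalRecords int64      `json:"total_records"`
	TotalBytes   int64      `json:"total_bytes"`
	Endpoints    []Endpoint `json:"endpoints"`
}

// partitionIDs mirrors ExecutePartitions: one serialized copy of info per
// endpoint, each copy narrowed to exactly that endpoint.
func partitionIDs(info Info) ([][]byte, error) {
	ids := make([][]byte, 0, len(info.Endpoints))
	for _, e := range info.Endpoints {
		p := info                   // copy of the top-level fields; info itself is not modified
		p.Endpoints = []Endpoint{e} // keep only this endpoint
		data, err := json.Marshal(p)
		if err != nil {
			return nil, err
		}
		ids = append(ids, data)
	}
	return ids, nil
}

// readPartition mirrors ReadPartition's validation: decode the ID, then
// require exactly one endpoint before fetching from it.
func readPartition(id []byte) (Endpoint, error) {
	var info Info
	if err := json.Unmarshal(id, &info); err != nil {
		return Endpoint{}, err
	}
	if len(info.Endpoints) != 1 {
		return Endpoint{}, fmt.Errorf("invalid partition: expected 1 endpoint, got %d", len(info.Endpoints))
	}
	return info.Endpoints[0], nil
}

func main() {
	info := Info{Schema: "x: int64", TotalRecords: -1, TotalBytes: -1,
		Endpoints: []Endpoint{{Ticket: "a"}, {Ticket: "b"}}}
	ids, err := partitionIDs(info)
	if err != nil {
		panic(err)
	}
	for _, id := range ids {
		ep, err := readPartition(id)
		if err != nil {
			panic(err)
		}
		fmt.Println(ep.Ticket)
	}
	_, err = readPartition([]byte(`{"endpoints":[]}`))
	fmt.Println(err) // invalid partition: expected 1 endpoint, got 0
}
```

The cost of this design is some duplicated metadata in every partition ID; the benefit is that each ID is self-describing, which is what ``TestIntrospectPartitions`` checks when it unmarshals an ID as a ``FlightInfo``.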