You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ze...@apache.org on 2023/01/24 21:01:15 UTC

[arrow-adbc] branch main updated: feat(go/adbc/driver/flightsql): implement more connection options (#381)

This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new bc29a37  feat(go/adbc/driver/flightsql): implement more connection options (#381)
bc29a37 is described below

commit bc29a377d225e75944b20b80e104095cd05c352e
Author: David Li <li...@gmail.com>
AuthorDate: Tue Jan 24 16:01:09 2023 -0500

    feat(go/adbc/driver/flightsql): implement more connection options (#381)
    
    * feat(go/adbc/driver/flightsql): implement more connection options
    
    Part of #380.
    
    * Use switch
    
    * Document user/pass
    
    * Update docs/source/driver/go/flight_sql.rst
    
    Co-authored-by: Matt Topol <zo...@gmail.com>
    
    Co-authored-by: Matt Topol <zo...@gmail.com>
---
 docs/source/driver/go/flight_sql.rst            | 117 ++++++++--------------
 go/adbc/driver/flightsql/flightsql_adbc.go      |  99 +++++++++++++-----
 go/adbc/driver/flightsql/flightsql_adbc_test.go | 127 ++++++++++++++++++++++++
 go/adbc/driver/flightsql/flightsql_statement.go |   4 +-
 4 files changed, 245 insertions(+), 102 deletions(-)

diff --git a/docs/source/driver/go/flight_sql.rst b/docs/source/driver/go/flight_sql.rst
index 76c543e..1768708 100644
--- a/docs/source/driver/go/flight_sql.rst
+++ b/docs/source/driver/go/flight_sql.rst
@@ -32,6 +32,10 @@ The Flight SQL driver is shipped as a standalone library.
    .. tab-item:: Go
       :sync: go
 
+      .. code-block:: shell
+
+         go get github.com/apache/arrow-adbc/go
+
    .. tab-item:: Python
       :sync: python
 
@@ -57,6 +61,7 @@ the :cpp:class:`AdbcDatabase`.
          // Ignoring error handling
          struct AdbcDatabase database;
          AdbcDatabaseNew(&database, nullptr);
+         AdbcDatabaseSetOption(&database, "driver", "adbc_driver_flightsql", nullptr);
          AdbcDatabaseSetOption(&database, "uri", "grpc://localhost:8080", nullptr);
          AdbcDatabaseInit(&database, nullptr);
 
@@ -65,10 +70,9 @@ the :cpp:class:`AdbcDatabase`.
 
       .. code-block:: python
 
-         import pyarrow.flight_sql
-
+         import adbc_driver_flightsql.dbapi
 
-         with pyarrow.flight_sql.connect("grpc://localhost:8080") as conn:
+         with adbc_driver_flightsql.dbapi.connect("grpc://localhost:8080") as conn:
              pass
 
 Supported Features
@@ -80,33 +84,28 @@ API specification 1.0.0, as well as some additional, custom options.
 Authentication
 --------------
 
-The driver does no authentication by default.
+The driver does no authentication by default.  The driver implements a
+few optional authentication schemes:
 
-The driver implements one optional authentication scheme that mimics
-the Arrow Flight SQL JDBC driver.  This can be enabled by setting the
-option ``arrow.flight.sql.authorization_header`` on the
-:cpp:class:`AdbcDatabase`.  The client provides credentials by setting
-the option value to the value of the ``authorization`` header sent
-from client to server.  The server then responds with an
-``authorization`` header on the first request.  The value of this
-header will then be sent back as the ``authorization`` header on all
-future requests.
+- Mutual TLS (mTLS): see "Client Options" below.
+- An HTTP-style scheme mimicking the Arrow Flight SQL JDBC driver.
+
+  Set the options ``username`` and ``password`` on the
+  :cpp:class:`AdbcDatabase`.  Alternatively, set the option
+  ``arrow.flight.sql.authorization_header`` for full control.
+
+  The client provides credentials by sending an ``authorization``
+  header from client to server.  The server then responds with an
+  ``authorization`` header on the first request.  The value of this
+  header will then be sent back as the ``authorization`` header on all
+  future requests.
 
 Bulk Ingestion
 --------------
 
 Flight SQL does not have a dedicated API for bulk ingestion of Arrow
-data into a given table.  The driver instead constructs SQL statements
-to create and insert into the table.
-
-.. warning:: The driver does not escape or validate the names of
-             tables or columns.  As a precaution, it instead limits
-             identifier names to letters, numbers, and underscores.
-             Bulk ingestion should not be used with untrusted user
-             input.
-
-The driver binds a batch of data at a time for efficiency.  Also, the
-generated SQL statements hardcode ``?`` as the parameter identifier.
+data into a given table.  The driver does not currently implement bulk
+ingestion as a result.
 
 Client Options
 --------------
@@ -114,18 +113,22 @@ Client Options
 The options used for creating the Flight RPC client can be customized.
 These options map 1:1 with the options in FlightClientOptions:
 
-``arrow.flight.sql.client_option.tls_root_certs``
-    Override the root certificates used to validate the server's TLS
-    certificate.
+``adbc.flight.sql.client_option.mtls_cert_chain``
+    The certificate chain to use for mTLS.
+
+``adbc.flight.sql.client_option.mtls_private_key``
+    The private key to use for mTLS.
 
-``arrow.flight.sql.client_option.override_hostname``
+``adbc.flight.sql.client_option.tls_override_hostname``
     Override the hostname used to verify the server's TLS certificate.
 
-``arrow.flight.sql.client_option.cert_chain``
-    The certificate chain to use for mTLS.
+``adbc.flight.sql.client_option.tls_skip_verify``
+    Disable verification of the server's TLS certificate.  Value
+    should be ``true`` or ``false``.
 
-``arrow.flight.sql.client_option.private_key``
-    The private key to use for mTLS.
+``adbc.flight.sql.client_option.tls_root_certs``
+    Override the root certificates used to validate the server's TLS
+    certificate.
 
 ``arrow.flight.sql.client_option.generic_int_option.<OPTION_NAME>``
     Option prefixes used to specify generic transport-layer options.
@@ -133,10 +136,6 @@ These options map 1:1 with the options in FlightClientOptions:
 ``arrow.flight.sql.client_option.generic_string_option.<OPTION_NAME>``
     Option prefixes used to specify generic transport-layer options.
 
-``arrow.flight.sql.client_option.disable_server_verification``
-    Disable verification of the server's TLS certificate.  Value
-    should be ``true`` or ``false``.
-
 Custom Call Headers
 -------------------
 
@@ -221,47 +220,9 @@ The options are as follows:
 Transactions
 ------------
 
-The driver will issue transaction RPCs, but the driver will not check
-the server's SqlInfo to determine whether this is supported first.
-
-Type Mapping
-------------
-
-When executing a bulk ingestion operation, the driver needs to be able
-to construct appropriate SQL queries for the database.  (The driver
-does not currently support using Substrait plans instead.)  In
-particular, a mapping from Arrow types to SQL type names is required.
-While a default mapping is provided, the client may wish to override
-this mapping, which can be done by setting special options on
-:cpp:class:`AdbcDatabase`.  (The driver does not currently inspect
-Flight SQL metadata to construct this mapping.)
-
-All such options begin with ``arrow.flight.sql.quirks.ingest_type.``
-and are followed by a type name below.
-
-.. warning:: The driver does **not** escape or validate the values
-             here.  They should not come from untrusted user input, or
-             a SQL injection vulnerability may result.
-
-.. csv-table:: Type Names
-   :header: "Arrow Type Name", "Default SQL Type Name"
-
-   binary,BLOB
-   bool,BOOLEAN
-   date32,DATE
-   date64,DATE
-   decimal128,NUMERIC
-   decimal256,NUMERIC
-   double,DOUBLE PRECISION
-   float,REAL
-   int16,SMALLINT
-   int32,INT
-   int64,BIGINT
-   large_binary,BLOB
-   large_string,TEXT
-   string,TEXT
-   time32,TIME
-   time64,TIME
-   timestamp,TIMESTAMP
+The driver supports transactions.  It will first check the server's
+SqlInfo to determine whether transactions are supported.  If they are
+not, transaction-related ADBC APIs will return
+:c:type:`ADBC_STATUS_NOT_IMPLEMENTED`.
 
 .. _DBAPI 2.0: https://peps.python.org/pep-0249/
diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go
index b959f32..6c46e9d 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -34,6 +34,7 @@ package flightsql
 import (
 	"context"
 	"crypto/tls"
+	"crypto/x509"
 	"fmt"
 	"io"
 	"net/url"
@@ -57,8 +58,11 @@ import (
 )
 
 const (
-	OptionSSLSkipVerify = "adbc.flight.sql.client_option.tls_skip_verify"
-	OptionSSLCertFile   = "adbc.flight.sql.client_option.tls_root_certs"
+	OptionMTLSCertChain       = "adbc.flight.sql.client_option.mtls_cert_chain"
+	OptionMTLSPrivateKey      = "adbc.flight.sql.client_option.mtls_private_key"
+	OptionSSLOverrideHostname = "adbc.flight.sql.client_option.tls_override_hostname"
+	OptionSSLSkipVerify       = "adbc.flight.sql.client_option.tls_skip_verify"
+	OptionSSLRootCerts        = "adbc.flight.sql.client_option.tls_root_certs"
 
 	infoDriverName = "ADBC Flight SQL Driver - Go"
 )
@@ -86,6 +90,7 @@ type Driver struct {
 }
 
 func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
+	opts = maps.Clone(opts)
 	uri, ok := opts[adbc.OptionKeyURI]
 	if !ok {
 		return nil, adbc.Error{
@@ -93,6 +98,7 @@ func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
 			Code: adbc.StatusInvalidArgument,
 		}
 	}
+	delete(opts, adbc.OptionKeyURI)
 
 	db := &database{alloc: d.Alloc}
 	if db.alloc == nil {
@@ -116,47 +122,82 @@ type database struct {
 }
 
 func (d *database) SetOptions(cnOptions map[string]string) error {
-	if d.uri.Scheme == "grpc+tls" {
-		d.creds = credentials.NewTLS(&tls.Config{})
-	} else {
-		d.creds = insecure.NewCredentials()
-	}
+	var tlsConfig tls.Config
 
-	if val, ok := cnOptions[OptionSSLSkipVerify]; ok && val == adbc.OptionValueEnabled {
-		if d.uri.Scheme != "grpc+tls" {
+	mtlsCert := cnOptions[OptionMTLSCertChain]
+	mtlsKey := cnOptions[OptionMTLSPrivateKey]
+	switch {
+	case mtlsCert != "" && mtlsKey != "":
+		cert, err := tls.X509KeyPair([]byte(mtlsCert), []byte(mtlsKey))
+		if err != nil {
 			return adbc.Error{
-				Msg:  "Connection is not TLS-enabled",
+				Msg:  fmt.Sprintf("Invalid mTLS certificate: %#v", err),
 				Code: adbc.StatusInvalidArgument,
 			}
 		}
-		d.creds = credentials.NewTLS(&tls.Config{InsecureSkipVerify: true})
+		tlsConfig.Certificates = []tls.Certificate{cert}
+		delete(cnOptions, OptionMTLSCertChain)
+		delete(cnOptions, OptionMTLSPrivateKey)
+	case mtlsCert != "":
+		return adbc.Error{
+			Msg:  fmt.Sprintf("Must provide both '%s' and '%s', only provided '%s'", OptionMTLSCertChain, OptionMTLSPrivateKey, OptionMTLSCertChain),
+			Code: adbc.StatusInvalidArgument,
+		}
+	case mtlsKey != "":
+		return adbc.Error{
+			Msg:  fmt.Sprintf("Must provide both '%s' and '%s', only provided '%s'", OptionMTLSCertChain, OptionMTLSPrivateKey, OptionMTLSPrivateKey),
+			Code: adbc.StatusInvalidArgument,
+		}
+	}
+
+	if hostname, ok := cnOptions[OptionSSLOverrideHostname]; ok {
+		tlsConfig.ServerName = hostname
+		delete(cnOptions, OptionSSLOverrideHostname)
 	}
 
-	// option specified path to certificate file
-	if cert, ok := cnOptions[OptionSSLCertFile]; ok {
-		if d.uri.Scheme != "grpc+tls" {
+	if val, ok := cnOptions[OptionSSLSkipVerify]; ok {
+		if val == adbc.OptionValueEnabled {
+			tlsConfig.InsecureSkipVerify = true
+		} else if val == adbc.OptionValueDisabled {
+			tlsConfig.InsecureSkipVerify = false
+		} else {
 			return adbc.Error{
-				Msg:  "Connection is not TLS-enabled",
+				Msg:  fmt.Sprintf("Invalid value for database option '%s': '%s'", OptionSSLSkipVerify, val),
 				Code: adbc.StatusInvalidArgument,
 			}
 		}
+		delete(cnOptions, OptionSSLSkipVerify)
+	}
 
-		c, err := credentials.NewClientTLSFromFile(cert, "")
-		if err != nil {
+	if cert, ok := cnOptions[OptionSSLRootCerts]; ok {
+		cp := x509.NewCertPool()
+		if !cp.AppendCertsFromPEM([]byte(cert)) {
 			return adbc.Error{
-				Msg:  "invalid SSL certificate passed",
+				Msg:  fmt.Sprintf("Invalid value for database option '%s': failed to append certificates", OptionSSLRootCerts),
 				Code: adbc.StatusInvalidArgument,
 			}
 		}
-		d.creds = c
+		tlsConfig.RootCAs = cp
+		delete(cnOptions, OptionSSLRootCerts)
 	}
 
+	d.creds = credentials.NewTLS(&tlsConfig)
+
 	if u, ok := cnOptions[adbc.OptionKeyUsername]; ok {
 		d.user = u
+		delete(cnOptions, adbc.OptionKeyUsername)
 	}
 
 	if p, ok := cnOptions[adbc.OptionKeyPassword]; ok {
 		d.pass = p
+		delete(cnOptions, adbc.OptionKeyPassword)
+	}
+
+	for key := range cnOptions {
+		return adbc.Error{
+			Msg:  fmt.Sprintf("Unknown database option '%s'", key),
+			Code: adbc.StatusInvalidArgument,
+		}
 	}
 
 	return nil
@@ -181,9 +222,13 @@ func getFlightClient(ctx context.Context, loc string, d *database) (*flightsql.C
 	if err != nil {
 		return nil, adbc.Error{Msg: fmt.Sprintf("Invalid URI '%s': %s", loc, err), Code: adbc.StatusInvalidArgument}
 	}
+	creds := d.creds
+	if uri.Scheme == "grpc" || uri.Scheme == "grpc+tcp" {
+		creds = insecure.NewCredentials()
+	}
 
 	cl, err := flightsql.NewClient(uri.Host, nil, []flight.ClientMiddleware{
-		flight.CreateClientMiddleware(authMiddle)}, grpc.WithTransportCredentials(d.creds))
+		flight.CreateClientMiddleware(authMiddle)}, grpc.WithTransportCredentials(creds))
 	if err != nil {
 		return nil, adbc.Error{
 			Msg:  err.Error(),
@@ -611,15 +656,23 @@ func (c *cnxn) Close() error {
 //
 // A partition can be retrieved by using ExecutePartitions on a statement.
 func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte) (rdr array.RecordReader, err error) {
-	var endpoint flight.FlightEndpoint
-	if err := proto.Unmarshal(serializedPartition, &endpoint); err != nil {
+	var info flight.FlightInfo
+	if err := proto.Unmarshal(serializedPartition, &info); err != nil {
 		return nil, adbc.Error{
 			Msg:  err.Error(),
 			Code: adbc.StatusInvalidArgument,
 		}
 	}
 
-	rdr, err = doGet(ctx, c.cl, &endpoint, c.clientCache)
+	// The driver only ever returns one endpoint.
+	if len(info.Endpoint) != 1 {
+		return nil, adbc.Error{
+			Msg:  fmt.Sprintf("Invalid partition: expected 1 endpoint, got %d", len(info.Endpoint)),
+			Code: adbc.StatusInvalidArgument,
+		}
+	}
+
+	rdr, err = doGet(ctx, c.cl, info.Endpoint[0], c.clientCache)
 	if err != nil {
 		return nil, adbcFromFlightStatus(err)
 	}
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index bdbfacc..2a860be 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -36,6 +36,7 @@ import (
 	"github.com/apache/arrow/go/v11/arrow/memory"
 	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/suite"
+	"google.golang.org/protobuf/proto"
 )
 
 type FlightSQLQuirks struct {
@@ -203,4 +204,130 @@ func TestADBCFlightSQL(t *testing.T) {
 	suite.Run(t, &validation.DatabaseTests{Quirks: q})
 	suite.Run(t, &validation.ConnectionTests{Quirks: q})
 	suite.Run(t, &validation.StatementTests{Quirks: q})
+
+	suite.Run(t, &PartitionTests{Quirks: q})
+	suite.Run(t, &SSLTests{Quirks: q})
+}
+
+// Driver-specific tests
+
+type SSLTests struct {
+	suite.Suite
+
+	Driver adbc.Driver
+	Quirks validation.DriverQuirks
+}
+
+func (suite *SSLTests) SetupTest() {
+	suite.Driver = suite.Quirks.SetupDriver(suite.T())
+}
+
+func (suite *SSLTests) TearDownTest() {
+	suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
+	suite.Driver = nil
+}
+
+func (suite *SSLTests) TestMutualTLS() {
+	// Checks that the mTLS options are validated - doesn't actually configure TLS
+	options := suite.Quirks.DatabaseOptions()
+
+	options["adbc.flight.sql.client_option.mtls_cert_chain"] = "certs"
+	_, err := suite.Driver.NewDatabase(options)
+	suite.Require().ErrorContains(err, "Must provide both")
+
+	options["adbc.flight.sql.client_option.mtls_private_key"] = "key"
+	_, err = suite.Driver.NewDatabase(options)
+	suite.Require().ErrorContains(err, "Invalid mTLS certificate")
+
+	delete(options, "adbc.flight.sql.client_option.mtls_cert_chain")
+	_, err = suite.Driver.NewDatabase(options)
+	suite.Require().ErrorContains(err, "Must provide both")
+}
+
+func (suite *SSLTests) TestOverrideHostname() {
+	// Just checks that the option is accepted - doesn't actually configure TLS
+	options := suite.Quirks.DatabaseOptions()
+	options["adbc.flight.sql.client_option.tls_override_hostname"] = "hostname"
+	_, err := suite.Driver.NewDatabase(options)
+	suite.Require().NoError(err)
+}
+
+func (suite *SSLTests) TestRootCerts() {
+	// Checks that invalid root certificates are rejected - doesn't actually configure TLS
+	options := suite.Quirks.DatabaseOptions()
+	options["adbc.flight.sql.client_option.tls_root_certs"] = "these are not valid certs"
+	_, err := suite.Driver.NewDatabase(options)
+	suite.Require().ErrorContains(err, "Invalid value for database option 'adbc.flight.sql.client_option.tls_root_certs': failed to append certificates")
+}
+
+func (suite *SSLTests) TestSkipVerify() {
+	options := suite.Quirks.DatabaseOptions()
+	options["adbc.flight.sql.client_option.tls_skip_verify"] = "true"
+	_, err := suite.Driver.NewDatabase(options)
+	suite.Require().NoError(err)
+
+	options = suite.Quirks.DatabaseOptions()
+	options["adbc.flight.sql.client_option.tls_skip_verify"] = "false"
+	_, err = suite.Driver.NewDatabase(options)
+	suite.Require().NoError(err)
+
+	options = suite.Quirks.DatabaseOptions()
+	options["adbc.flight.sql.client_option.tls_skip_verify"] = "invalid"
+	_, err = suite.Driver.NewDatabase(options)
+	suite.Require().ErrorContains(err, "Invalid value for database option 'adbc.flight.sql.client_option.tls_skip_verify': 'invalid'")
+}
+
+func (suite *SSLTests) TestUnknownOption() {
+	options := suite.Quirks.DatabaseOptions()
+	options["unknown option"] = "unknown value"
+	_, err := suite.Driver.NewDatabase(options)
+	suite.Require().ErrorContains(err, "Unknown database option 'unknown option'")
+}
+
+type PartitionTests struct {
+	suite.Suite
+
+	Driver adbc.Driver
+	Quirks validation.DriverQuirks
+
+	DB   adbc.Database
+	Cnxn adbc.Connection
+	ctx  context.Context
+}
+
+func (suite *PartitionTests) SetupTest() {
+	suite.Driver = suite.Quirks.SetupDriver(suite.T())
+	var err error
+	suite.DB, err = suite.Driver.NewDatabase(suite.Quirks.DatabaseOptions())
+	suite.Require().NoError(err)
+	suite.ctx = context.Background()
+	suite.Cnxn, err = suite.DB.Open(suite.ctx)
+	suite.Require().NoError(err)
+}
+
+func (suite *PartitionTests) TearDownTest() {
+	suite.Require().NoError(suite.Cnxn.Close())
+	suite.Quirks.TearDownDriver(suite.T(), suite.Driver)
+	suite.Cnxn = nil
+	suite.DB = nil
+	suite.Driver = nil
+}
+
+func (suite *PartitionTests) TestIntrospectPartitions() {
+	stmt, err := suite.Cnxn.NewStatement()
+	suite.Require().NoError(err)
+	defer stmt.Close()
+
+	suite.Require().NoError(stmt.SetSqlQuery("SELECT 42"))
+
+	_, partitions, _, err := stmt.ExecutePartitions(context.Background())
+	suite.Require().NoError(err)
+	suite.Require().Equal(uint64(1), partitions.NumPartitions)
+
+	info := &flight.FlightInfo{}
+	suite.Require().NoError(proto.Unmarshal(partitions.PartitionIDs[0], info))
+	suite.Require().Equal(int64(-1), info.TotalBytes)
+	suite.Require().Equal(int64(-1), info.TotalRecords)
+	suite.Require().Equal(1, len(info.Endpoint))
+	suite.Require().Equal(0, len(info.Endpoint[0].Location))
 }
diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go
index 8d26df2..b251e68 100644
--- a/go/adbc/driver/flightsql/flightsql_statement.go
+++ b/go/adbc/driver/flightsql/flightsql_statement.go
@@ -279,7 +279,9 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
 	out.NumPartitions = uint64(len(info.Endpoint))
 	out.PartitionIDs = make([][]byte, out.NumPartitions)
 	for i, e := range info.Endpoint {
-		data, err := proto.Marshal(e)
+		partition := proto.Clone(info).(*flight.FlightInfo)
+		partition.Endpoint = []*flight.FlightEndpoint{e}
+		data, err := proto.Marshal(partition)
 		if err != nil {
 			return sc, out, -1, adbc.Error{
 				Msg:  err.Error(),