You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/06/14 21:21:53 UTC

[arrow-adbc] branch main updated: fix(go/adbc/pkg): allow ConnectionSetOptions before Init (#789)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new cb016366 fix(go/adbc/pkg): allow ConnectionSetOptions before Init (#789)
cb016366 is described below

commit cb0163663cd927385b76bd8e67b0ad0ae6675501
Author: David Li <li...@gmail.com>
AuthorDate: Wed Jun 14 17:21:47 2023 -0400

    fix(go/adbc/pkg): allow ConnectionSetOptions before Init (#789)
    
    Fixes #713.
---
 go/adbc/pkg/_tmpl/driver.go.tmpl                 | 39 ++++++++++++++++++++++--
 go/adbc/pkg/flightsql/driver.go                  | 39 ++++++++++++++++++++++--
 go/adbc/pkg/panicdummy/driver.go                 | 39 ++++++++++++++++++++++--
 go/adbc/pkg/snowflake/driver.go                  | 39 ++++++++++++++++++++++--
 python/adbc_driver_flightsql/tests/conftest.py   | 36 ++++++++++++++--------
 python/adbc_driver_flightsql/tests/test_dbapi.py | 18 +++++++++++
 6 files changed, 186 insertions(+), 24 deletions(-)

diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl
index fbc80aab..03a94c02 100644
--- a/go/adbc/pkg/_tmpl/driver.go.tmpl
+++ b/go/adbc/pkg/_tmpl/driver.go.tmpl
@@ -243,7 +243,8 @@ func {{.Prefix}}DatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcErr
 }
 
 type cConn struct {
-	cnxn adbc.Connection
+	cnxn     adbc.Connection
+	initArgs map[string]string
 }
 
 func checkConnAlloc(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname string) bool {
@@ -308,7 +309,22 @@ func {{.Prefix}}ConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.c
 	}
 	conn := getFromHandle[cConn](cnxn.private_data)
 
-	rawCode := errToAdbcErr(err, conn.cnxn.(adbc.PostInitOptions).SetOption(C.GoString(key), C.GoString(val)))
+	if conn.cnxn == nil {
+		// not yet initialized
+		k, v := C.GoString(key), C.GoString(val)
+		if conn.initArgs == nil {
+			conn.initArgs = map[string]string{}
+		}
+		conn.initArgs[k] = v
+		return C.ADBC_STATUS_OK
+	}
+
+	opts, ok := conn.cnxn.(adbc.PostInitOptions)
+	if !ok {
+		setErr(err, "AdbcConnectionSetOption: not supported post-init")
+		return C.ADBC_STATUS_NOT_IMPLEMENTED
+	}
+	rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))
 	return C.AdbcStatusCode(rawCode)
 }
 
@@ -336,8 +352,25 @@ func {{.Prefix}}ConnectionInit(cnxn *C.struct_AdbcConnection, db *C.struct_AdbcD
 	if e != nil {
 		return C.AdbcStatusCode(errToAdbcErr(err, e))
 	}
-
 	conn.cnxn = c
+
+	if len(conn.initArgs) > 0 {
+		// C allows SetOption before Init, but Go doesn't accept options in Open, so set them now
+		opts, ok := conn.cnxn.(adbc.PostInitOptions)
+		if !ok {
+			setErr(err, "AdbcConnectionInit: options are not supported")
+			return C.ADBC_STATUS_NOT_IMPLEMENTED
+		}
+
+		for k, v := range conn.initArgs {
+			rawCode := errToAdbcErr(err, opts.SetOption(k, v))
+			if rawCode != adbc.StatusOK {
+				return C.AdbcStatusCode(rawCode)
+			}
+		}
+		conn.initArgs = nil
+	}
+
 	return C.ADBC_STATUS_OK
 }
 
diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go
index 976c28e1..6d5cf75b 100644
--- a/go/adbc/pkg/flightsql/driver.go
+++ b/go/adbc/pkg/flightsql/driver.go
@@ -247,7 +247,8 @@ func FlightSQLDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError
 }
 
 type cConn struct {
-	cnxn adbc.Connection
+	cnxn     adbc.Connection
+	initArgs map[string]string
 }
 
 func checkConnAlloc(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname string) bool {
@@ -312,7 +313,22 @@ func FlightSQLConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.cch
 	}
 	conn := getFromHandle[cConn](cnxn.private_data)
 
-	rawCode := errToAdbcErr(err, conn.cnxn.(adbc.PostInitOptions).SetOption(C.GoString(key), C.GoString(val)))
+	if conn.cnxn == nil {
+		// not yet initialized
+		k, v := C.GoString(key), C.GoString(val)
+		if conn.initArgs == nil {
+			conn.initArgs = map[string]string{}
+		}
+		conn.initArgs[k] = v
+		return C.ADBC_STATUS_OK
+	}
+
+	opts, ok := conn.cnxn.(adbc.PostInitOptions)
+	if !ok {
+		setErr(err, "AdbcConnectionSetOption: not supported post-init")
+		return C.ADBC_STATUS_NOT_IMPLEMENTED
+	}
+	rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))
 	return C.AdbcStatusCode(rawCode)
 }
 
@@ -340,8 +356,25 @@ func FlightSQLConnectionInit(cnxn *C.struct_AdbcConnection, db *C.struct_AdbcDat
 	if e != nil {
 		return C.AdbcStatusCode(errToAdbcErr(err, e))
 	}
-
 	conn.cnxn = c
+
+	if len(conn.initArgs) > 0 {
+		// C allows SetOption before Init, but Go doesn't accept options in Open, so set them now
+		opts, ok := conn.cnxn.(adbc.PostInitOptions)
+		if !ok {
+			setErr(err, "AdbcConnectionInit: options are not supported")
+			return C.ADBC_STATUS_NOT_IMPLEMENTED
+		}
+
+		for k, v := range conn.initArgs {
+			rawCode := errToAdbcErr(err, opts.SetOption(k, v))
+			if rawCode != adbc.StatusOK {
+				return C.AdbcStatusCode(rawCode)
+			}
+		}
+		conn.initArgs = nil
+	}
+
 	return C.ADBC_STATUS_OK
 }
 
diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go
index 915e70d2..374c3cb8 100644
--- a/go/adbc/pkg/panicdummy/driver.go
+++ b/go/adbc/pkg/panicdummy/driver.go
@@ -247,7 +247,8 @@ func PanicDummyDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcErro
 }
 
 type cConn struct {
-	cnxn adbc.Connection
+	cnxn     adbc.Connection
+	initArgs map[string]string
 }
 
 func checkConnAlloc(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname string) bool {
@@ -312,7 +313,22 @@ func PanicDummyConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.cc
 	}
 	conn := getFromHandle[cConn](cnxn.private_data)
 
-	rawCode := errToAdbcErr(err, conn.cnxn.(adbc.PostInitOptions).SetOption(C.GoString(key), C.GoString(val)))
+	if conn.cnxn == nil {
+		// not yet initialized
+		k, v := C.GoString(key), C.GoString(val)
+		if conn.initArgs == nil {
+			conn.initArgs = map[string]string{}
+		}
+		conn.initArgs[k] = v
+		return C.ADBC_STATUS_OK
+	}
+
+	opts, ok := conn.cnxn.(adbc.PostInitOptions)
+	if !ok {
+		setErr(err, "AdbcConnectionSetOption: not supported post-init")
+		return C.ADBC_STATUS_NOT_IMPLEMENTED
+	}
+	rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))
 	return C.AdbcStatusCode(rawCode)
 }
 
@@ -340,8 +356,25 @@ func PanicDummyConnectionInit(cnxn *C.struct_AdbcConnection, db *C.struct_AdbcDa
 	if e != nil {
 		return C.AdbcStatusCode(errToAdbcErr(err, e))
 	}
-
 	conn.cnxn = c
+
+	if len(conn.initArgs) > 0 {
+		// C allows SetOption before Init, but Go doesn't accept options in Open, so set them now
+		opts, ok := conn.cnxn.(adbc.PostInitOptions)
+		if !ok {
+			setErr(err, "AdbcConnectionInit: options are not supported")
+			return C.ADBC_STATUS_NOT_IMPLEMENTED
+		}
+
+		for k, v := range conn.initArgs {
+			rawCode := errToAdbcErr(err, opts.SetOption(k, v))
+			if rawCode != adbc.StatusOK {
+				return C.AdbcStatusCode(rawCode)
+			}
+		}
+		conn.initArgs = nil
+	}
+
 	return C.ADBC_STATUS_OK
 }
 
diff --git a/go/adbc/pkg/snowflake/driver.go b/go/adbc/pkg/snowflake/driver.go
index 4e7e5f6f..31e2f131 100644
--- a/go/adbc/pkg/snowflake/driver.go
+++ b/go/adbc/pkg/snowflake/driver.go
@@ -247,7 +247,8 @@ func SnowflakeDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError
 }
 
 type cConn struct {
-	cnxn adbc.Connection
+	cnxn     adbc.Connection
+	initArgs map[string]string
 }
 
 func checkConnAlloc(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname string) bool {
@@ -312,7 +313,22 @@ func SnowflakeConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.cch
 	}
 	conn := getFromHandle[cConn](cnxn.private_data)
 
-	rawCode := errToAdbcErr(err, conn.cnxn.(adbc.PostInitOptions).SetOption(C.GoString(key), C.GoString(val)))
+	if conn.cnxn == nil {
+		// not yet initialized
+		k, v := C.GoString(key), C.GoString(val)
+		if conn.initArgs == nil {
+			conn.initArgs = map[string]string{}
+		}
+		conn.initArgs[k] = v
+		return C.ADBC_STATUS_OK
+	}
+
+	opts, ok := conn.cnxn.(adbc.PostInitOptions)
+	if !ok {
+		setErr(err, "AdbcConnectionSetOption: not supported post-init")
+		return C.ADBC_STATUS_NOT_IMPLEMENTED
+	}
+	rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))
 	return C.AdbcStatusCode(rawCode)
 }
 
@@ -340,8 +356,25 @@ func SnowflakeConnectionInit(cnxn *C.struct_AdbcConnection, db *C.struct_AdbcDat
 	if e != nil {
 		return C.AdbcStatusCode(errToAdbcErr(err, e))
 	}
-
 	conn.cnxn = c
+
+	if len(conn.initArgs) > 0 {
+		// C allows SetOption before Init, but Go doesn't accept options in Open, so set them now
+		opts, ok := conn.cnxn.(adbc.PostInitOptions)
+		if !ok {
+			setErr(err, "AdbcConnectionInit: options are not supported")
+			return C.ADBC_STATUS_NOT_IMPLEMENTED
+		}
+
+		for k, v := range conn.initArgs {
+			rawCode := errToAdbcErr(err, opts.SetOption(k, v))
+			if rawCode != adbc.StatusOK {
+				return C.AdbcStatusCode(rawCode)
+			}
+		}
+		conn.initArgs = nil
+	}
+
 	return C.ADBC_STATUS_OK
 }
 
diff --git a/python/adbc_driver_flightsql/tests/conftest.py b/python/adbc_driver_flightsql/tests/conftest.py
index b80f952e..4ca9508d 100644
--- a/python/adbc_driver_flightsql/tests/conftest.py
+++ b/python/adbc_driver_flightsql/tests/conftest.py
@@ -24,23 +24,37 @@ import adbc_driver_flightsql.dbapi
 import adbc_driver_manager
 
 
-@pytest.fixture
-def dremio_uri():
+@pytest.fixture(scope="session")
+def dremio_uri() -> str:
     dremio_uri = os.environ.get("ADBC_DREMIO_FLIGHTSQL_URI")
     if not dremio_uri:
         pytest.skip("Set ADBC_DREMIO_FLIGHTSQL_URI to run tests")
-    yield dremio_uri
+    return dremio_uri
 
 
-@pytest.fixture
-def dremio(dremio_uri):
+@pytest.fixture(scope="session")
+def dremio_user() -> str:
     username = os.environ.get("ADBC_DREMIO_FLIGHTSQL_USER")
+    if not username:
+        pytest.skip("Set ADBC_DREMIO_FLIGHTSQL_USER to run tests")
+    return username
+
+
+@pytest.fixture(scope="session")
+def dremio_pass() -> str:
     password = os.environ.get("ADBC_DREMIO_FLIGHTSQL_PASS")
+    if not password:
+        pytest.skip("Set ADBC_DREMIO_FLIGHTSQL_PASS to run tests")
+    return password
+
+
+@pytest.fixture
+def dremio(dremio_uri, dremio_user, dremio_pass):
     with adbc_driver_flightsql.connect(
         dremio_uri,
         db_kwargs={
-            adbc_driver_manager.DatabaseOptions.USERNAME.value: username,
-            adbc_driver_manager.DatabaseOptions.PASSWORD.value: password,
+            adbc_driver_manager.DatabaseOptions.USERNAME.value: dremio_user,
+            adbc_driver_manager.DatabaseOptions.PASSWORD.value: dremio_pass,
         },
     ) as db:
         with adbc_driver_manager.AdbcConnection(db) as conn:
@@ -48,14 +62,12 @@ def dremio(dremio_uri):
 
 
 @pytest.fixture
-def dremio_dbapi(dremio_uri):
-    username = os.environ.get("ADBC_DREMIO_FLIGHTSQL_USER")
-    password = os.environ.get("ADBC_DREMIO_FLIGHTSQL_PASS")
+def dremio_dbapi(dremio_uri, dremio_user, dremio_pass):
     with adbc_driver_flightsql.dbapi.connect(
         dremio_uri,
         db_kwargs={
-            adbc_driver_manager.DatabaseOptions.USERNAME.value: username,
-            adbc_driver_manager.DatabaseOptions.PASSWORD.value: password,
+            adbc_driver_manager.DatabaseOptions.USERNAME.value: dremio_user,
+            adbc_driver_manager.DatabaseOptions.PASSWORD.value: dremio_pass,
         },
     ) as conn:
         yield conn
diff --git a/python/adbc_driver_flightsql/tests/test_dbapi.py b/python/adbc_driver_flightsql/tests/test_dbapi.py
index 115cd5a4..cf72ab0e 100644
--- a/python/adbc_driver_flightsql/tests/test_dbapi.py
+++ b/python/adbc_driver_flightsql/tests/test_dbapi.py
@@ -17,6 +17,9 @@
 
 import pyarrow
 
+import adbc_driver_flightsql.dbapi
+import adbc_driver_manager
+
 
 def test_query_trivial(dremio_dbapi):
     with dremio_dbapi.cursor() as cur:
@@ -32,3 +35,18 @@ def test_query_partitioned(dremio_dbapi):
 
         cur.adbc_read_partition(partitions[0])
         assert cur.fetchone() == (1,)
+
+
+def test_set_options(dremio_uri, dremio_user, dremio_pass):
+    # Regression test for apache/arrow-adbc#713
+    with adbc_driver_flightsql.dbapi.connect(
+        dremio_uri,
+        db_kwargs={
+            adbc_driver_manager.DatabaseOptions.USERNAME.value: dremio_user,
+            adbc_driver_manager.DatabaseOptions.PASSWORD.value: dremio_pass,
+        },
+        conn_kwargs={
+            "adbc.flight.sql.rpc.call_header.x-foo": "1",
+        },
+    ):
+        pass