You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/06/13 18:29:04 UTC
[arrow-adbc] branch main updated: feat(python/adbc_driver_manager): add autocommit, executescript (#778)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 42280c28 feat(python/adbc_driver_manager): add autocommit, executescript (#778)
42280c28 is described below
commit 42280c2832aa125fcadb68024a9e181ce7fa0074
Author: David Li <li...@gmail.com>
AuthorDate: Tue Jun 13 14:28:58 2023 -0400
feat(python/adbc_driver_manager): add autocommit, executescript (#778)
Fixes #599.
---
c/driver/sqlite/statement_reader.c | 7 +++-
.../adbc_driver_flightsql/dbapi.py | 3 +-
.../adbc_driver_manager/dbapi.py | 42 +++++++++++++++++++---
.../adbc_driver_postgresql/dbapi.py | 26 ++++++++++++--
.../adbc_driver_snowflake/dbapi.py | 3 +-
.../adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py | 4 +--
python/adbc_driver_sqlite/tests/test_dbapi.py | 29 ++++++++++++++-
7 files changed, 100 insertions(+), 14 deletions(-)
diff --git a/c/driver/sqlite/statement_reader.c b/c/driver/sqlite/statement_reader.c
index 2b17364a..d9a0823f 100644
--- a/c/driver/sqlite/statement_reader.c
+++ b/c/driver/sqlite/statement_reader.c
@@ -253,6 +253,8 @@ const char* StatementReaderGetLastError(struct ArrowArrayStream* self) {
void StatementReaderSetError(struct StatementReader* reader) {
const char* msg = sqlite3_errmsg(reader->db);
+ // Reset here so that we don't get an error again in StatementRelease
+ (void)sqlite3_reset(reader->stmt);
strncpy(reader->error.message, msg, sizeof(reader->error.message));
reader->error.message[sizeof(reader->error.message) - 1] = '\0';
}
@@ -810,7 +812,7 @@ AdbcStatusCode StatementReaderInferOneValue(
}
case SQLITE_BLOB:
default: {
- return ADBC_STATUS_IO;
+ return ADBC_STATUS_NOT_IMPLEMENTED;
}
}
return ADBC_STATUS_OK;
@@ -870,7 +872,10 @@ AdbcStatusCode AdbcSqliteExportReader(sqlite3* db, sqlite3_stmt* stmt,
}
continue;
} else if (rc == SQLITE_ERROR) {
+ SetError(error, "Failed to step query: %s", sqlite3_errmsg(db));
status = ADBC_STATUS_IO;
+ // Reset here so that we don't get an error again in StatementRelease
+ (void)sqlite3_reset(stmt);
break;
} else if (rc != SQLITE_ROW) {
status = ADBC_STATUS_INTERNAL;
diff --git a/python/adbc_driver_flightsql/adbc_driver_flightsql/dbapi.py b/python/adbc_driver_flightsql/adbc_driver_flightsql/dbapi.py
index 00a85327..8db494b0 100644
--- a/python/adbc_driver_flightsql/adbc_driver_flightsql/dbapi.py
+++ b/python/adbc_driver_flightsql/adbc_driver_flightsql/dbapi.py
@@ -96,6 +96,7 @@ def connect(
uri: str,
db_kwargs: typing.Optional[typing.Dict[str, str]] = None,
conn_kwargs: typing.Optional[typing.Dict[str, str]] = None,
+ **kwargs,
) -> "Connection":
"""
Connect to a Flight SQL backend via ADBC.
@@ -117,7 +118,7 @@ def connect(
try:
db = adbc_driver_flightsql.connect(uri, db_kwargs=db_kwargs)
conn = adbc_driver_manager.AdbcConnection(db, **(conn_kwargs or {}))
- return adbc_driver_manager.dbapi.Connection(db, conn)
+ return adbc_driver_manager.dbapi.Connection(db, conn, **kwargs)
except Exception:
if conn:
conn.close()
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 53f41d0c..22fd3eb3 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -160,6 +160,7 @@ def connect(
entrypoint: str = None,
db_kwargs: Optional[Dict[str, str]] = None,
conn_kwargs: Optional[Dict[str, str]] = None,
+ autocommit=False,
) -> "Connection":
"""
Connect to a database via ADBC.
@@ -180,6 +181,10 @@ def connect(
conn_kwargs
Key-value parameters to pass to the driver to initialize the
connection.
+ autocommit
+ Whether to enable autocommit. For compliance with DB-API,
+ this is disabled by default. A warning will be emitted if it
+ cannot be disabled.
"""
db = None
conn = None
@@ -194,7 +199,7 @@ def connect(
try:
db = _lib.AdbcDatabase(**db_kwargs)
conn = _lib.AdbcConnection(db, **conn_kwargs)
- return Connection(db, conn, conn_kwargs)
+ return Connection(db, conn, conn_kwargs, autocommit=autocommit)
except Exception:
if conn:
conn.close()
@@ -267,6 +272,8 @@ class Connection(_Closeable):
db: Union[_lib.AdbcDatabase, _SharedDatabase],
conn: _lib.AdbcConnection,
conn_kwargs: Optional[Dict[str, str]] = None,
+ *,
+ autocommit=False,
) -> None:
self._closed = False
if isinstance(db, _SharedDatabase):
@@ -280,13 +287,20 @@ class Connection(_Closeable):
self._conn.set_autocommit(False)
except _lib.NotSupportedError:
self._commit_supported = False
- warnings.warn(
- "Cannot disable autocommit; conn will not be DB-API 2.0 compliant",
- category=Warning,
- )
+ if not autocommit:
+ warnings.warn(
+ "Cannot disable autocommit; conn will not be DB-API 2.0 compliant",
+ category=Warning,
+ )
+ self._autocommit = True
else:
+ self._autocommit = False
self._commit_supported = True
+ if autocommit and self._commit_supported:
+ self._conn.set_autocommit(True)
+ self._autocommit = True
+
def close(self) -> None:
"""
Close the connection.
@@ -844,6 +858,24 @@ class Cursor(_Closeable):
"""
return self._stmt
+ def executescript(self, operation: str) -> None:
+ """
+ Execute multiple statements.
+
+ If there is a pending transaction, commits first.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ if not self._conn._autocommit:
+ self._conn.commit()
+
+ self._last_query = None
+ self._results = None
+ self._stmt.set_sql_query(operation)
+ self._stmt.execute_update()
+
def fetchallarrow(self) -> pyarrow.Table:
"""
Fetch all rows of the result as a PyArrow Table.
diff --git a/python/adbc_driver_postgresql/adbc_driver_postgresql/dbapi.py b/python/adbc_driver_postgresql/adbc_driver_postgresql/dbapi.py
index 47d21d36..88e309cb 100644
--- a/python/adbc_driver_postgresql/adbc_driver_postgresql/dbapi.py
+++ b/python/adbc_driver_postgresql/adbc_driver_postgresql/dbapi.py
@@ -19,6 +19,8 @@
DBAPI 2.0-compatible facade for the ADBC libpq driver.
"""
+import typing
+
import adbc_driver_manager
import adbc_driver_manager.dbapi
import adbc_driver_postgresql
@@ -92,15 +94,33 @@ ROWID = adbc_driver_manager.dbapi.ROWID
# Functions
-def connect(uri: str) -> "Connection":
- """Connect to PostgreSQL via ADBC."""
+def connect(
+ uri: str,
+ db_kwargs: typing.Optional[typing.Dict[str, str]] = None,
+ conn_kwargs: typing.Optional[typing.Dict[str, str]] = None,
+ **kwargs
+) -> "Connection":
+ """
+ Connect to PostgreSQL via ADBC.
+
+ Parameters
+ ----------
+ uri : str
+ The URI to connect to.
+ db_kwargs : dict, optional
+ Initial database connection parameters.
+ conn_kwargs : dict, optional
+ Connection-specific parameters. (ADBC differentiates between
+ a 'database' object shared between multiple 'connection'
+ objects.)
+ """
db = None
conn = None
try:
db = adbc_driver_postgresql.connect(uri)
conn = adbc_driver_manager.AdbcConnection(db)
- return adbc_driver_manager.dbapi.Connection(db, conn)
+ return adbc_driver_manager.dbapi.Connection(db, conn, **kwargs)
except Exception:
if conn:
conn.close()
diff --git a/python/adbc_driver_snowflake/adbc_driver_snowflake/dbapi.py b/python/adbc_driver_snowflake/adbc_driver_snowflake/dbapi.py
index 262bdffb..42b9ec08 100644
--- a/python/adbc_driver_snowflake/adbc_driver_snowflake/dbapi.py
+++ b/python/adbc_driver_snowflake/adbc_driver_snowflake/dbapi.py
@@ -96,6 +96,7 @@ def connect(
uri: str,
db_kwargs: typing.Optional[typing.Dict[str, str]] = None,
conn_kwargs: typing.Optional[typing.Dict[str, str]] = None,
+ **kwargs,
) -> "Connection":
"""
Connect to Snowflake via ADBC.
@@ -117,7 +118,7 @@ def connect(
try:
db = adbc_driver_snowflake.connect(uri, db_kwargs=db_kwargs)
conn = adbc_driver_manager.AdbcConnection(db, **(conn_kwargs or {}))
- return adbc_driver_manager.dbapi.Connection(db, conn)
+ return adbc_driver_manager.dbapi.Connection(db, conn, **kwargs)
except Exception:
if conn:
conn.close()
diff --git a/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py b/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py
index 3f807248..aa566afa 100644
--- a/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py
+++ b/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py
@@ -92,7 +92,7 @@ ROWID = adbc_driver_manager.dbapi.ROWID
# Functions
-def connect(uri: typing.Optional[str] = None) -> "Connection":
+def connect(uri: typing.Optional[str] = None, **kwargs) -> "Connection":
"""Connect to SQLite via ADBC."""
db = None
conn = None
@@ -100,7 +100,7 @@ def connect(uri: typing.Optional[str] = None) -> "Connection":
try:
db = adbc_driver_sqlite.connect(uri)
conn = adbc_driver_manager.AdbcConnection(db)
- return adbc_driver_manager.dbapi.Connection(db, conn)
+ return adbc_driver_manager.dbapi.Connection(db, conn, **kwargs)
except Exception:
if conn:
conn.close()
diff --git a/python/adbc_driver_sqlite/tests/test_dbapi.py b/python/adbc_driver_sqlite/tests/test_dbapi.py
index b4eca8f1..9598485b 100644
--- a/python/adbc_driver_sqlite/tests/test_dbapi.py
+++ b/python/adbc_driver_sqlite/tests/test_dbapi.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
+from pathlib import Path
+
import pytest
from adbc_driver_sqlite import dbapi
@@ -26,7 +28,32 @@ def sqlite():
yield conn
-def test_query_trivial(sqlite):
+def test_query_trivial(sqlite) -> None:
with sqlite.cursor() as cur:
cur.execute("SELECT 1")
assert cur.fetchone() == (1,)
+
+
+def test_autocommit(tmp_path: Path) -> None:
+ # apache/arrow-adbc#599
+ db = tmp_path / "tmp.sqlite"
+ with dbapi.connect(f"file:{db}") as conn:
+ assert not conn._autocommit
+ with conn.cursor() as cur:
+ with pytest.raises(
+ dbapi.OperationalError,
+ match="cannot change into wal mode from within a transaction",
+ ):
+ cur.execute("PRAGMA journal_mode = WAL")
+
+ # This now works if we enable autocommit
+ with dbapi.connect(f"file:{db}", autocommit=True) as conn:
+ assert conn._autocommit
+ with conn.cursor() as cur:
+ cur.execute("PRAGMA journal_mode = WAL")
+
+ # Or we can use executescript
+ with dbapi.connect(f"file:{db}") as conn:
+ assert not conn._autocommit
+ with conn.cursor() as cur:
+ cur.executescript("PRAGMA journal_mode = WAL")