You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/06/13 18:29:04 UTC

[arrow-adbc] branch main updated: feat(python/adbc_driver_manager): add autocommit, executescript (#778)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 42280c28 feat(python/adbc_driver_manager): add autocommit, executescript (#778)
42280c28 is described below

commit 42280c2832aa125fcadb68024a9e181ce7fa0074
Author: David Li <li...@gmail.com>
AuthorDate: Tue Jun 13 14:28:58 2023 -0400

    feat(python/adbc_driver_manager): add autocommit, executescript (#778)
    
    Fixes #599.
---
 c/driver/sqlite/statement_reader.c                 |  7 +++-
 .../adbc_driver_flightsql/dbapi.py                 |  3 +-
 .../adbc_driver_manager/dbapi.py                   | 42 +++++++++++++++++++---
 .../adbc_driver_postgresql/dbapi.py                | 26 ++++++++++++--
 .../adbc_driver_snowflake/dbapi.py                 |  3 +-
 .../adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py |  4 +--
 python/adbc_driver_sqlite/tests/test_dbapi.py      | 29 ++++++++++++++-
 7 files changed, 100 insertions(+), 14 deletions(-)

diff --git a/c/driver/sqlite/statement_reader.c b/c/driver/sqlite/statement_reader.c
index 2b17364a..d9a0823f 100644
--- a/c/driver/sqlite/statement_reader.c
+++ b/c/driver/sqlite/statement_reader.c
@@ -253,6 +253,8 @@ const char* StatementReaderGetLastError(struct ArrowArrayStream* self) {
 
 void StatementReaderSetError(struct StatementReader* reader) {
   const char* msg = sqlite3_errmsg(reader->db);
+  // Reset here so that we don't get an error again in StatementRelease
+  (void)sqlite3_reset(reader->stmt);
   strncpy(reader->error.message, msg, sizeof(reader->error.message));
   reader->error.message[sizeof(reader->error.message) - 1] = '\0';
 }
@@ -810,7 +812,7 @@ AdbcStatusCode StatementReaderInferOneValue(
     }
     case SQLITE_BLOB:
     default: {
-      return ADBC_STATUS_IO;
+      return ADBC_STATUS_NOT_IMPLEMENTED;
     }
   }
   return ADBC_STATUS_OK;
@@ -870,7 +872,10 @@ AdbcStatusCode AdbcSqliteExportReader(sqlite3* db, sqlite3_stmt* stmt,
         }
         continue;
       } else if (rc == SQLITE_ERROR) {
+        SetError(error, "Failed to step query: %s", sqlite3_errmsg(db));
         status = ADBC_STATUS_IO;
+        // Reset here so that we don't get an error again in StatementRelease
+        (void)sqlite3_reset(stmt);
         break;
       } else if (rc != SQLITE_ROW) {
         status = ADBC_STATUS_INTERNAL;
diff --git a/python/adbc_driver_flightsql/adbc_driver_flightsql/dbapi.py b/python/adbc_driver_flightsql/adbc_driver_flightsql/dbapi.py
index 00a85327..8db494b0 100644
--- a/python/adbc_driver_flightsql/adbc_driver_flightsql/dbapi.py
+++ b/python/adbc_driver_flightsql/adbc_driver_flightsql/dbapi.py
@@ -96,6 +96,7 @@ def connect(
     uri: str,
     db_kwargs: typing.Optional[typing.Dict[str, str]] = None,
     conn_kwargs: typing.Optional[typing.Dict[str, str]] = None,
+    **kwargs,
 ) -> "Connection":
     """
     Connect to a Flight SQL backend via ADBC.
@@ -117,7 +118,7 @@ def connect(
     try:
         db = adbc_driver_flightsql.connect(uri, db_kwargs=db_kwargs)
         conn = adbc_driver_manager.AdbcConnection(db, **(conn_kwargs or {}))
-        return adbc_driver_manager.dbapi.Connection(db, conn)
+        return adbc_driver_manager.dbapi.Connection(db, conn, **kwargs)
     except Exception:
         if conn:
             conn.close()
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 53f41d0c..22fd3eb3 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -160,6 +160,7 @@ def connect(
     entrypoint: str = None,
     db_kwargs: Optional[Dict[str, str]] = None,
     conn_kwargs: Optional[Dict[str, str]] = None,
+    autocommit=False,
 ) -> "Connection":
     """
     Connect to a database via ADBC.
@@ -180,6 +181,10 @@ def connect(
     conn_kwargs
         Key-value parameters to pass to the driver to initialize the
         connection.
+    autocommit
+        Whether to enable autocommit.  For compliance with DB-API,
+        this is disabled by default.  A warning will be emitted if it
+        cannot be disabled.
     """
     db = None
     conn = None
@@ -194,7 +199,7 @@ def connect(
     try:
         db = _lib.AdbcDatabase(**db_kwargs)
         conn = _lib.AdbcConnection(db, **conn_kwargs)
-        return Connection(db, conn, conn_kwargs)
+        return Connection(db, conn, conn_kwargs, autocommit=autocommit)
     except Exception:
         if conn:
             conn.close()
@@ -267,6 +272,8 @@ class Connection(_Closeable):
         db: Union[_lib.AdbcDatabase, _SharedDatabase],
         conn: _lib.AdbcConnection,
         conn_kwargs: Optional[Dict[str, str]] = None,
+        *,
+        autocommit=False,
     ) -> None:
         self._closed = False
         if isinstance(db, _SharedDatabase):
@@ -280,13 +287,20 @@ class Connection(_Closeable):
             self._conn.set_autocommit(False)
         except _lib.NotSupportedError:
             self._commit_supported = False
-            warnings.warn(
-                "Cannot disable autocommit; conn will not be DB-API 2.0 compliant",
-                category=Warning,
-            )
+            if not autocommit:
+                warnings.warn(
+                    "Cannot disable autocommit; conn will not be DB-API 2.0 compliant",
+                    category=Warning,
+                )
+            self._autocommit = True
         else:
+            self._autocommit = False
             self._commit_supported = True
 
+        if autocommit and self._commit_supported:
+            self._conn.set_autocommit(True)
+            self._autocommit = True
+
     def close(self) -> None:
         """
         Close the connection.
@@ -844,6 +858,24 @@ class Cursor(_Closeable):
         """
         return self._stmt
 
+    def executescript(self, operation: str) -> None:
+        """
+        Execute multiple statements.
+
+        If there is a pending transaction, commits first.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        if not self._conn._autocommit:
+            self._conn.commit()
+
+        self._last_query = None
+        self._results = None
+        self._stmt.set_sql_query(operation)
+        self._stmt.execute_update()
+
     def fetchallarrow(self) -> pyarrow.Table:
         """
         Fetch all rows of the result as a PyArrow Table.
diff --git a/python/adbc_driver_postgresql/adbc_driver_postgresql/dbapi.py b/python/adbc_driver_postgresql/adbc_driver_postgresql/dbapi.py
index 47d21d36..88e309cb 100644
--- a/python/adbc_driver_postgresql/adbc_driver_postgresql/dbapi.py
+++ b/python/adbc_driver_postgresql/adbc_driver_postgresql/dbapi.py
@@ -19,6 +19,8 @@
 DBAPI 2.0-compatible facade for the ADBC libpq driver.
 """
 
+import typing
+
 import adbc_driver_manager
 import adbc_driver_manager.dbapi
 import adbc_driver_postgresql
@@ -92,15 +94,33 @@ ROWID = adbc_driver_manager.dbapi.ROWID
 # Functions
 
 
-def connect(uri: str) -> "Connection":
-    """Connect to PostgreSQL via ADBC."""
+def connect(
+    uri: str,
+    db_kwargs: typing.Optional[typing.Dict[str, str]] = None,
+    conn_kwargs: typing.Optional[typing.Dict[str, str]] = None,
+    **kwargs
+) -> "Connection":
+    """
+    Connect to PostgreSQL via ADBC.
+
+    Parameters
+    ----------
+    uri : str
+        The URI to connect to.
+    db_kwargs : dict, optional
+        Initial database connection parameters.
+    conn_kwargs : dict, optional
+        Connection-specific parameters.  (ADBC differentiates between
+        a 'database' object shared between multiple 'connection'
+        objects.)
+    """
     db = None
     conn = None
 
     try:
         db = adbc_driver_postgresql.connect(uri)
         conn = adbc_driver_manager.AdbcConnection(db)
-        return adbc_driver_manager.dbapi.Connection(db, conn)
+        return adbc_driver_manager.dbapi.Connection(db, conn, **kwargs)
     except Exception:
         if conn:
             conn.close()
diff --git a/python/adbc_driver_snowflake/adbc_driver_snowflake/dbapi.py b/python/adbc_driver_snowflake/adbc_driver_snowflake/dbapi.py
index 262bdffb..42b9ec08 100644
--- a/python/adbc_driver_snowflake/adbc_driver_snowflake/dbapi.py
+++ b/python/adbc_driver_snowflake/adbc_driver_snowflake/dbapi.py
@@ -96,6 +96,7 @@ def connect(
     uri: str,
     db_kwargs: typing.Optional[typing.Dict[str, str]] = None,
     conn_kwargs: typing.Optional[typing.Dict[str, str]] = None,
+    **kwargs,
 ) -> "Connection":
     """
     Connect to Snowflake via ADBC.
@@ -117,7 +118,7 @@ def connect(
     try:
         db = adbc_driver_snowflake.connect(uri, db_kwargs=db_kwargs)
         conn = adbc_driver_manager.AdbcConnection(db, **(conn_kwargs or {}))
-        return adbc_driver_manager.dbapi.Connection(db, conn)
+        return adbc_driver_manager.dbapi.Connection(db, conn, **kwargs)
     except Exception:
         if conn:
             conn.close()
diff --git a/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py b/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py
index 3f807248..aa566afa 100644
--- a/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py
+++ b/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py
@@ -92,7 +92,7 @@ ROWID = adbc_driver_manager.dbapi.ROWID
 # Functions
 
 
-def connect(uri: typing.Optional[str] = None) -> "Connection":
+def connect(uri: typing.Optional[str] = None, **kwargs) -> "Connection":
     """Connect to SQLite via ADBC."""
     db = None
     conn = None
@@ -100,7 +100,7 @@ def connect(uri: typing.Optional[str] = None) -> "Connection":
     try:
         db = adbc_driver_sqlite.connect(uri)
         conn = adbc_driver_manager.AdbcConnection(db)
-        return adbc_driver_manager.dbapi.Connection(db, conn)
+        return adbc_driver_manager.dbapi.Connection(db, conn, **kwargs)
     except Exception:
         if conn:
             conn.close()
diff --git a/python/adbc_driver_sqlite/tests/test_dbapi.py b/python/adbc_driver_sqlite/tests/test_dbapi.py
index b4eca8f1..9598485b 100644
--- a/python/adbc_driver_sqlite/tests/test_dbapi.py
+++ b/python/adbc_driver_sqlite/tests/test_dbapi.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from pathlib import Path
+
 import pytest
 
 from adbc_driver_sqlite import dbapi
@@ -26,7 +28,32 @@ def sqlite():
         yield conn
 
 
-def test_query_trivial(sqlite):
+def test_query_trivial(sqlite) -> None:
     with sqlite.cursor() as cur:
         cur.execute("SELECT 1")
         assert cur.fetchone() == (1,)
+
+
+def test_autocommit(tmp_path: Path) -> None:
+    # apache/arrow-adbc#599
+    db = tmp_path / "tmp.sqlite"
+    with dbapi.connect(f"file:{db}") as conn:
+        assert not conn._autocommit
+        with conn.cursor() as cur:
+            with pytest.raises(
+                dbapi.OperationalError,
+                match="cannot change into wal mode from within a transaction",
+            ):
+                cur.execute("PRAGMA journal_mode = WAL")
+
+    # This now works if we enable autocommit
+    with dbapi.connect(f"file:{db}", autocommit=True) as conn:
+        assert conn._autocommit
+        with conn.cursor() as cur:
+            cur.execute("PRAGMA journal_mode = WAL")
+
+    # Or we can use executescript
+    with dbapi.connect(f"file:{db}") as conn:
+        assert not conn._autocommit
+        with conn.cursor() as cur:
+            cur.executescript("PRAGMA journal_mode = WAL")