You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/10/31 02:19:39 UTC

[airflow] branch main updated: Rename schema to database in PostgresHook (#26744)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 39caf1d5bc Rename schema to database in PostgresHook (#26744)
39caf1d5bc is described below

commit 39caf1d5bc5ec5ff653cf00b25d45e176709b59e
Author: Felix Uellendall <fe...@users.noreply.github.com>
AuthorDate: Mon Oct 31 03:19:31 2022 +0100

    Rename schema to database in PostgresHook (#26744)
    
    * Rename schema to database in PostgresHook
    
    In PostgresHook, the "schema" field was only named that way for compatibility with the underlying DbApiHook, which uses the schema for the SQLAlchemy connection. The Postgres connector library (psycopg2), however, does not allow setting a schema on the connection; only a database can be set. To clarify this, we rename all references in the PostgresHook code and documentation.
---
 airflow/providers/postgres/hooks/postgres.py       | 39 ++++++++++++++++++--
 airflow/providers/postgres/operators/postgres.py   |  2 +
 .../connections/postgres.rst                       |  9 ++++-
 .../operators/postgres_operator_howto_guide.rst    |  2 +
 tests/providers/common/sql/operators/test_sql.py   |  1 -
 tests/providers/common/sql/sensors/test_sql.py     |  4 +-
 tests/providers/postgres/hooks/test_postgres.py    | 43 ++++++++++++----------
 .../providers/postgres/operators/test_postgres.py  |  6 +--
 .../providers/slack/transfers/test_sql_to_slack.py |  4 +-
 9 files changed, 78 insertions(+), 32 deletions(-)

diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py
index 7a2d0e7e9b..26b1d791d1 100644
--- a/airflow/providers/postgres/hooks/postgres.py
+++ b/airflow/providers/postgres/hooks/postgres.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import os
+import warnings
 from contextlib import closing
 from copy import deepcopy
 from typing import Any, Iterable, Union
@@ -67,10 +68,38 @@ class PostgresHook(DbApiHook):
     supports_autocommit = True
 
     def __init__(self, *args, **kwargs) -> None:
+        if "schema" in kwargs:
+            warnings.warn(
+                'The "schema" arg has been renamed to "database" as it contained the database name. '
+                'Please use "database" to set the database name.',
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            kwargs.setdefault("database", kwargs["schema"])
         super().__init__(*args, **kwargs)
         self.connection: Connection | None = kwargs.pop("connection", None)
         self.conn: connection = None
-        self.schema: str | None = kwargs.pop("schema", None)
+        self.database: str | None = kwargs.pop("database", None)
+
+    @property
+    def schema(self) -> str | None:
+        warnings.warn(
+            'The "schema" variable has been renamed to "database" as it contained the database name. '
+            'Please use "database" to get the database name.',
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.database
+
+    @schema.setter
+    def schema(self, value: str | None) -> None:
+        warnings.warn(
+            'The "schema" variable has been renamed to "database" as it contained the database name. '
+            'Please use "database" to set the database name.',
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        self.database = value
 
     def _get_cursor(self, raw_cursor: str) -> CursorType:
         _cursor = raw_cursor.lower()
@@ -95,7 +124,7 @@ class PostgresHook(DbApiHook):
             host=conn.host,
             user=conn.login,
             password=conn.password,
-            dbname=self.schema or conn.schema,
+            dbname=self.database or conn.schema,
             port=conn.port,
         )
         raw_cursor = conn.extra_dejson.get("cursor", False)
@@ -143,7 +172,9 @@ class PostgresHook(DbApiHook):
         Extract the URI from the connection.
         :return: the extracted uri.
         """
-        uri = super().get_uri().replace("postgres://", "postgresql://")
+        conn = self.get_connection(getattr(self, self.conn_name_attr))
+        conn.schema = self.database or conn.schema
+        uri = conn.get_uri().replace("postgres://", "postgresql://")
         return uri
 
     def bulk_load(self, table: str, tmp_file: str) -> None:
@@ -196,7 +227,7 @@ class PostgresHook(DbApiHook):
             # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
             cluster_creds = redshift_client.get_cluster_credentials(
                 DbUser=login,
-                DbName=self.schema or conn.schema,
+                DbName=self.database or conn.schema,
                 ClusterIdentifier=cluster_identifier,
                 AutoCreate=False,
             )
diff --git a/airflow/providers/postgres/operators/postgres.py b/airflow/providers/postgres/operators/postgres.py
index 561d06c167..a9489b6663 100644
--- a/airflow/providers/postgres/operators/postgres.py
+++ b/airflow/providers/postgres/operators/postgres.py
@@ -38,6 +38,8 @@ class PostgresOperator(SQLExecuteQueryOperator):
         (default value: False)
     :param parameters: (optional) the parameters to render the SQL query with.
     :param database: name of database which overwrite defined one in connection
+    :param runtime_parameters: a mapping of runtime params added to the final sql being executed.
+        For example, you could set the schema via `{"search_path": "CUSTOM_SCHEMA"}`.
     """
 
     template_fields: Sequence[str] = ("sql",)
diff --git a/docs/apache-airflow-providers-postgres/connections/postgres.rst b/docs/apache-airflow-providers-postgres/connections/postgres.rst
index f97e99af84..68966dc926 100644
--- a/docs/apache-airflow-providers-postgres/connections/postgres.rst
+++ b/docs/apache-airflow-providers-postgres/connections/postgres.rst
@@ -29,7 +29,14 @@ Host (required)
     The host to connect to.
 
 Schema (optional)
-    Specify the schema name to be used in the database.
+    Specify the name of the database to connect to.
+
+    .. note::
+
+        If you want to define a default database schema:
+
+        * using ``PostgresOperator``, see :ref:`Passing Server Configuration Parameters into PostgresOperator <howto/operators:postgres>`
+        * using ``PostgresHook``, see `search_path <https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH>`_
 
 Login (required)
     Specify the user name to connect.
diff --git a/docs/apache-airflow-providers-postgres/operators/postgres_operator_howto_guide.rst b/docs/apache-airflow-providers-postgres/operators/postgres_operator_howto_guide.rst
index 790d02caec..648a6c75e0 100644
--- a/docs/apache-airflow-providers-postgres/operators/postgres_operator_howto_guide.rst
+++ b/docs/apache-airflow-providers-postgres/operators/postgres_operator_howto_guide.rst
@@ -15,6 +15,8 @@
     specific language governing permissions and limitations
     under the License.
 
+.. _howto/operators:postgres:
+
 How-to Guide for PostgresOperator
 =================================
 
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index 2980326602..51f013f7fc 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -448,7 +448,6 @@ class TestSQLCheckOperatorDbHook:
         if database:
             self._operator.database = database
         assert isinstance(self._operator._hook, PostgresHook)
-        assert self._operator._hook.schema == database
         mock_get_conn.assert_called_once_with(self.conn_id)
 
     def test_not_allowed_conn_type(self, mock_get_conn):
diff --git a/tests/providers/common/sql/sensors/test_sql.py b/tests/providers/common/sql/sensors/test_sql.py
index 6e8f01c3c7..77665f1c84 100644
--- a/tests/providers/common/sql/sensors/test_sql.py
+++ b/tests/providers/common/sql/sensors/test_sql.py
@@ -263,8 +263,8 @@ class TestSqlSensor(TestHiveEnvironment):
             conn_id="postgres_default",
             sql="SELECT 1",
             hook_params={
-                "schema": "public",
+                "log_sql": False,
             },
         )
         hook = op._get_hook()
-        assert hook.schema == "public"
+        assert hook.log_sql == op.hook_params["log_sql"]
diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py
index e415a1c449..a6195d82d2 100644
--- a/tests/providers/postgres/hooks/test_postgres.py
+++ b/tests/providers/postgres/hooks/test_postgres.py
@@ -33,7 +33,7 @@ from airflow.utils.types import NOTSET
 class TestPostgresHookConn:
     @pytest.fixture(autouse=True)
     def setup(self):
-        self.connection = Connection(login="login", password="password", host="host", schema="schema")
+        self.connection = Connection(login="login", password="password", host="host", schema="database")
 
         class UnitTestPostgresHook(PostgresHook):
             conn_name_attr = "test_conn_id"
@@ -47,7 +47,7 @@ class TestPostgresHookConn:
         self.db_hook.test_conn_id = "non_default"
         self.db_hook.get_conn()
         mock_connect.assert_called_once_with(
-            user="login", password="password", host="host", dbname="schema", port=None
+            user="login", password="password", host="host", dbname="database", port=None
         )
         self.db_hook.get_connection.assert_called_once_with("non_default")
 
@@ -55,7 +55,7 @@ class TestPostgresHookConn:
     def test_get_conn(self, mock_connect):
         self.db_hook.get_conn()
         mock_connect.assert_called_once_with(
-            user="login", password="password", host="host", dbname="schema", port=None
+            user="login", password="password", host="host", dbname="database", port=None
         )
 
     @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
@@ -64,7 +64,7 @@ class TestPostgresHookConn:
         self.connection.conn_type = "postgres"
         self.db_hook.get_conn()
         assert mock_connect.call_count == 1
-        assert self.db_hook.get_uri() == "postgresql://login:password@host/schema?client_encoding=utf-8"
+        assert self.db_hook.get_uri() == "postgresql://login:password@host/database?client_encoding=utf-8"
 
     @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
     def test_get_conn_cursor(self, mock_connect):
@@ -75,7 +75,7 @@ class TestPostgresHookConn:
             user="login",
             password="password",
             host="host",
-            dbname="schema",
+            dbname="database",
             port=None,
         )
 
@@ -87,20 +87,20 @@ class TestPostgresHookConn:
 
     @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
     def test_get_conn_from_connection(self, mock_connect):
-        conn = Connection(login="login-conn", password="password-conn", host="host", schema="schema")
+        conn = Connection(login="login-conn", password="password-conn", host="host", schema="database")
         hook = PostgresHook(connection=conn)
         hook.get_conn()
         mock_connect.assert_called_once_with(
-            user="login-conn", password="password-conn", host="host", dbname="schema", port=None
+            user="login-conn", password="password-conn", host="host", dbname="database", port=None
         )
 
     @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
-    def test_get_conn_from_connection_with_schema(self, mock_connect):
-        conn = Connection(login="login-conn", password="password-conn", host="host", schema="schema")
-        hook = PostgresHook(connection=conn, schema="schema-override")
+    def test_get_conn_from_connection_with_database(self, mock_connect):
+        conn = Connection(login="login-conn", password="password-conn", host="host", schema="database")
+        hook = PostgresHook(connection=conn, database="database-override")
         hook.get_conn()
         mock_connect.assert_called_once_with(
-            user="login-conn", password="password-conn", host="host", dbname="schema-override", port=None
+            user="login-conn", password="password-conn", host="host", dbname="database-override", port=None
         )
 
     @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
@@ -146,7 +146,7 @@ class TestPostgresHookConn:
         self.connection.extra = '{"connect_timeout": 3}'
         self.db_hook.get_conn()
         mock_connect.assert_called_once_with(
-            user="login", password="password", host="host", dbname="schema", port=None, connect_timeout=3
+            user="login", password="password", host="host", dbname="database", port=None, connect_timeout=3
         )
 
     @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
@@ -225,32 +225,37 @@ class TestPostgresHookConn:
             port=(port or 5439),
         )
 
-    def test_get_uri_from_connection_without_schema_override(self):
+    def test_get_uri_from_connection_without_database_override(self):
         self.db_hook.get_connection = mock.MagicMock(
             return_value=Connection(
                 conn_type="postgres",
                 host="host",
                 login="login",
                 password="password",
-                schema="schema",
+                schema="database",
                 port=1,
             )
         )
-        assert "postgresql://login:password@host:1/schema" == self.db_hook.get_uri()
+        assert "postgresql://login:password@host:1/database" == self.db_hook.get_uri()
 
-    def test_get_uri_from_connection_with_schema_override(self):
-        hook = PostgresHook(schema="schema-override")
+    def test_get_uri_from_connection_with_database_override(self):
+        hook = PostgresHook(database="database-override")
         hook.get_connection = mock.MagicMock(
             return_value=Connection(
                 conn_type="postgres",
                 host="host",
                 login="login",
                 password="password",
-                schema="schema",
+                schema="database",
                 port=1,
             )
         )
-        assert "postgresql://login:password@host:1/schema-override" == hook.get_uri()
+        assert "postgresql://login:password@host:1/database-override" == hook.get_uri()
+
+    def test_schema_kwarg_database_kwarg_compatibility(self):
+        with pytest.warns(DeprecationWarning, match='"schema" arg has been renamed'):
+            hook = PostgresHook(schema="database-override")
+        assert hook.database == "database-override"
 
 
 class TestPostgresHook(unittest.TestCase):
diff --git a/tests/providers/postgres/operators/test_postgres.py b/tests/providers/postgres/operators/test_postgres.py
index f6ce234560..394cfc2618 100644
--- a/tests/providers/postgres/operators/test_postgres.py
+++ b/tests/providers/postgres/operators/test_postgres.py
@@ -79,14 +79,14 @@ class TestPostgres(unittest.TestCase):
         op = PostgresOperator(task_id="postgres_operator_test_vacuum", sql=sql, dag=self.dag, autocommit=True)
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
-    def test_overwrite_schema(self):
+    def test_overwrite_database(self):
         """
-        Verifies option to overwrite connection schema
+        Verifies option to overwrite connection database
         """
 
         sql = "SELECT 1;"
         op = PostgresOperator(
-            task_id="postgres_operator_test_schema_overwrite",
+            task_id="postgres_operator_test_database_overwrite",
             sql=sql,
             dag=self.dag,
             autocommit=True,
diff --git a/tests/providers/slack/transfers/test_sql_to_slack.py b/tests/providers/slack/transfers/test_sql_to_slack.py
index 8c19273d46..307469460b 100644
--- a/tests/providers/slack/transfers/test_sql_to_slack.py
+++ b/tests/providers/slack/transfers/test_sql_to_slack.py
@@ -186,11 +186,11 @@ class TestSqlToSlackOperator:
             sql="SELECT 1",
             slack_message="message: {{ ds }}, {{ xxxx }}",
             sql_hook_params={
-                "schema": "public",
+                "log_sql": False,
             },
         )
         hook = op._get_hook()
-        assert hook.schema == "public"
+        assert hook.log_sql == op.sql_hook_params["log_sql"]
 
     @mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection")
     def test_hook_params_snowflake(self, mock_get_conn):