You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/11/14 19:33:02 UTC

[airflow] branch main updated: Use unused SQLCheckOperator.parameters in SQLCheckOperator.execute. (#27599)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ae98b824d Use unused SQLCheckOperator.parameters in SQLCheckOperator.execute. (#27599)
3ae98b824d is described below

commit 3ae98b824db437b2db928a73ac8b50c0a2f80124
Author: Wil Molina <mo...@gmail.com>
AuthorDate: Mon Nov 14 11:32:50 2022 -0800

    Use unused SQLCheckOperator.parameters in SQLCheckOperator.execute. (#27599)
---
 airflow/providers/common/sql/operators/sql.py      | 12 ++++++++++--
 airflow/providers/snowflake/operators/snowflake.py |  2 +-
 tests/providers/common/sql/operators/test_sql.py   |  7 ++++++-
 3 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index 66244a858d..8dee6ed968 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -608,6 +608,7 @@ class SQLCheckOperator(BaseSQLOperator):
     :param sql: the sql to be executed. (templated)
     :param conn_id: the connection ID used to connect to the database.
     :param database: name of database which overwrite the defined one in connection
+    :param parameters: (optional) the parameters to render the SQL query with.
     """
 
     template_fields: Sequence[str] = ("sql",)
@@ -619,14 +620,21 @@ class SQLCheckOperator(BaseSQLOperator):
     ui_color = "#fff7e6"
 
     def __init__(
-        self, *, sql: str, conn_id: str | None = None, database: str | None = None, **kwargs
+        self,
+        *,
+        sql: str,
+        conn_id: str | None = None,
+        database: str | None = None,
+        parameters: Iterable | Mapping | None = None,
+        **kwargs,
     ) -> None:
         super().__init__(conn_id=conn_id, database=database, **kwargs)
         self.sql = sql
+        self.parameters = parameters
 
     def execute(self, context: Context):
         self.log.info("Executing SQL check: %s", self.sql)
-        records = self.get_db_hook().get_first(self.sql)
+        records = self.get_db_hook().get_first(self.sql, self.parameters)
 
         self.log.info("Record: %s", records)
         if not records:
diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py
index 2546ddfb5e..cf7835ef65 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -179,7 +179,7 @@ class SnowflakeCheckOperator(SQLCheckOperator):
         session_parameters: dict | None = None,
         **kwargs,
     ) -> None:
-        super().__init__(sql=sql, **kwargs)
+        super().__init__(sql=sql, parameters=parameters, **kwargs)
         self.snowflake_conn_id = snowflake_conn_id
         self.sql = sql
         self.autocommit = autocommit
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index 51f013f7fc..3741a93ed3 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -483,7 +483,7 @@ class TestSQLCheckOperatorDbHook:
 
 class TestCheckOperator(unittest.TestCase):
     def setUp(self):
-        self._operator = SQLCheckOperator(task_id="test_task", sql="sql")
+        self._operator = SQLCheckOperator(task_id="test_task", sql="sql", parameters="parameters")
 
     @mock.patch.object(SQLCheckOperator, "get_db_hook")
     def test_execute_no_records(self, mock_get_db_hook):
@@ -499,6 +499,11 @@ class TestCheckOperator(unittest.TestCase):
         with pytest.raises(AirflowException, match=r"Test failed."):
             self._operator.execute({})
 
+    @mock.patch.object(SQLCheckOperator, "get_db_hook")
+    def test_sqlcheckoperator_parameters(self, mock_get_db_hook):
+        self._operator.execute({})
+        mock_get_db_hook.return_value.get_first.assert_called_once_with("sql", "parameters")
+
 
 class TestValueCheckOperator(unittest.TestCase):
     def setUp(self):