You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@airflow.apache.org by po...@apache.org on 2022/11/14 19:33:02 UTC
[airflow] branch main updated: Use unused SQLCheckOperator.parameters in SQLCheckOperator.execute. (#27599)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3ae98b824d Use unused SQLCheckOperator.parameters in SQLCheckOperator.execute. (#27599)
3ae98b824d is described below
commit 3ae98b824db437b2db928a73ac8b50c0a2f80124
Author: Wil Molina <mo...@gmail.com>
AuthorDate: Mon Nov 14 11:32:50 2022 -0800
Use unused SQLCheckOperator.parameters in SQLCheckOperator.execute. (#27599)
---
airflow/providers/common/sql/operators/sql.py | 12 ++++++++++--
airflow/providers/snowflake/operators/snowflake.py | 2 +-
tests/providers/common/sql/operators/test_sql.py | 7 ++++++-
3 files changed, 17 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index 66244a858d..8dee6ed968 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -608,6 +608,7 @@ class SQLCheckOperator(BaseSQLOperator):
:param sql: the sql to be executed. (templated)
:param conn_id: the connection ID used to connect to the database.
:param database: name of database which overwrite the defined one in connection
+ :param parameters: (optional) the parameters to render the SQL query with.
"""
template_fields: Sequence[str] = ("sql",)
@@ -619,14 +620,21 @@ class SQLCheckOperator(BaseSQLOperator):
ui_color = "#fff7e6"
def __init__(
- self, *, sql: str, conn_id: str | None = None, database: str | None = None, **kwargs
+ self,
+ *,
+ sql: str,
+ conn_id: str | None = None,
+ database: str | None = None,
+ parameters: Iterable | Mapping | None = None,
+ **kwargs,
) -> None:
super().__init__(conn_id=conn_id, database=database, **kwargs)
self.sql = sql
+ self.parameters = parameters
def execute(self, context: Context):
self.log.info("Executing SQL check: %s", self.sql)
- records = self.get_db_hook().get_first(self.sql)
+ records = self.get_db_hook().get_first(self.sql, self.parameters)
self.log.info("Record: %s", records)
if not records:
diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py
index 2546ddfb5e..cf7835ef65 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -179,7 +179,7 @@ class SnowflakeCheckOperator(SQLCheckOperator):
session_parameters: dict | None = None,
**kwargs,
) -> None:
- super().__init__(sql=sql, **kwargs)
+ super().__init__(sql=sql, parameters=parameters, **kwargs)
self.snowflake_conn_id = snowflake_conn_id
self.sql = sql
self.autocommit = autocommit
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index 51f013f7fc..3741a93ed3 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -483,7 +483,7 @@ class TestSQLCheckOperatorDbHook:
class TestCheckOperator(unittest.TestCase):
def setUp(self):
- self._operator = SQLCheckOperator(task_id="test_task", sql="sql")
+ self._operator = SQLCheckOperator(task_id="test_task", sql="sql", parameters="parameters")
@mock.patch.object(SQLCheckOperator, "get_db_hook")
def test_execute_no_records(self, mock_get_db_hook):
@@ -499,6 +499,11 @@ class TestCheckOperator(unittest.TestCase):
with pytest.raises(AirflowException, match=r"Test failed."):
self._operator.execute({})
+ @mock.patch.object(SQLCheckOperator, "get_db_hook")
+ def test_sqlcheckoperator_parameters(self, mock_get_db_hook):
+ self._operator.execute({})
+ mock_get_db_hook.return_value.get_first.assert_called_once_with("sql", "parameters")
+
class TestValueCheckOperator(unittest.TestCase):
def setUp(self):