You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/12/09 23:04:46 UTC

[airflow] branch main updated: Fix template rendering for Common SQL operators (#28202)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a6cda7cd23 Fix template rendering for Common SQL operators (#28202)
a6cda7cd23 is described below

commit a6cda7cd230ef22f7fe042d6d5e9f78c660c4a75
Author: Jonathan Stott <81...@users.noreply.github.com>
AuthorDate: Fri Dec 9 23:04:38 2022 +0000

    Fix template rendering for Common SQL operators (#28202)
    
    Closes: #28195
    
    This patch fixes every common SQL operator I could find that showed
    the same problem as reported in #28195: statements are generated
    "too early", before the templated variables have been applied.  I believe
    every change is covered by a test that fails without the fix.  Some of
    these tests are quite brittle, because they mock complex nested SQL,
    which is not ideal.
    
    This patch adds `sql` (and `table`) to the operators' templated fields,
    so that templated values in `table`, `partition_clause`, `checks`, etc.
    are rendered before the generated query is run.
---
 airflow/providers/common/sql/operators/sql.py    |  28 +++--
 airflow/providers/common/sql/operators/sql.pyi   |   2 +
 tests/providers/common/sql/operators/test_sql.py | 134 +++++++++++++++++++++++
 3 files changed, 154 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index 09034b104c..cc2e9e4b57 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -324,7 +324,8 @@ class SQLColumnCheckOperator(BaseSQLOperator):
         :ref:`howto/operator:SQLColumnCheckOperator`
     """
 
-    template_fields = ("partition_clause",)
+    template_fields = ("partition_clause", "table", "sql")
+    template_fields_renderers = {"sql": "sql"}
 
     sql_check_template = """
         SELECT '{column}' AS col_name, '{check}' AS check_type, {column}_{check} AS check_result
@@ -550,7 +551,9 @@ class SQLTableCheckOperator(BaseSQLOperator):
         :ref:`howto/operator:SQLTableCheckOperator`
     """
 
-    template_fields = ("partition_clause",)
+    template_fields = ("partition_clause", "table", "sql")
+
+    template_fields_renderers = {"sql": "sql"}
 
     sql_check_template = """
     SELECT '{check_name}' AS check_name, MIN({check_name}) AS check_result
@@ -603,6 +606,8 @@ class SQLTableCheckOperator(BaseSQLOperator):
         self.log.info("All tests have passed")
 
     def _generate_sql_query(self):
+        self.log.info("Partition clause: %s", self.partition_clause)
+
         def _generate_partition_clause(check_name):
             if self.partition_clause and "partition_clause" not in self.checks[check_name]:
                 return f"WHERE {self.partition_clause}"
@@ -953,8 +958,8 @@ class SQLThresholdCheckOperator(BaseSQLOperator):
     ):
         super().__init__(conn_id=conn_id, database=database, **kwargs)
         self.sql = sql
-        self.min_threshold = _convert_to_float_if_possible(min_threshold)
-        self.max_threshold = _convert_to_float_if_possible(max_threshold)
+        self.min_threshold = min_threshold
+        self.max_threshold = max_threshold
 
     def execute(self, context: Context):
         hook = self.get_db_hook()
@@ -962,15 +967,18 @@ class SQLThresholdCheckOperator(BaseSQLOperator):
         if not result:
             self._raise_exception(f"The following query returned zero rows: {self.sql}")
 
-        if isinstance(self.min_threshold, float):
-            lower_bound = self.min_threshold
+        min_threshold = _convert_to_float_if_possible(self.min_threshold)
+        max_threshold = _convert_to_float_if_possible(self.max_threshold)
+
+        if isinstance(min_threshold, float):
+            lower_bound = min_threshold
         else:
-            lower_bound = hook.get_first(self.min_threshold)[0]
+            lower_bound = hook.get_first(min_threshold)[0]
 
-        if isinstance(self.max_threshold, float):
-            upper_bound = self.max_threshold
+        if isinstance(max_threshold, float):
+            upper_bound = max_threshold
         else:
-            upper_bound = hook.get_first(self.max_threshold)[0]
+            upper_bound = hook.get_first(max_threshold)[0]
 
         meta_data = {
             "result": result,
diff --git a/airflow/providers/common/sql/operators/sql.pyi b/airflow/providers/common/sql/operators/sql.pyi
index cbbd8ddcdc..70e24ce240 100644
--- a/airflow/providers/common/sql/operators/sql.pyi
+++ b/airflow/providers/common/sql/operators/sql.pyi
@@ -80,6 +80,7 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
 
 class SQLColumnCheckOperator(BaseSQLOperator):
     template_fields: Incomplete
+    template_fields_renderers: Incomplete
     sql_check_template: str
     column_checks: Incomplete
     table: Incomplete
@@ -102,6 +103,7 @@ class SQLColumnCheckOperator(BaseSQLOperator):
 
 class SQLTableCheckOperator(BaseSQLOperator):
     template_fields: Incomplete
+    template_fields_renderers: Incomplete
     sql_check_template: str
     table: Incomplete
     checks: Incomplete
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index 216ac79280..be9b18c4b3 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -157,6 +157,12 @@ class TestColumnCheckOperator:
         monkeypatch.setattr(MockHook, "get_records", get_records)
         return operator
 
+    def _full_check_sql(self, sql: str) -> str:
+        """
+        Wraps the check fragment in the outer parts of the sql query
+        """
+        return f"SELECT col_name, check_type, check_result FROM ({sql}) AS check_columns"
+
     def test_check_not_in_column_checks(self, monkeypatch):
         with pytest.raises(AirflowException, match="Invalid column check: invalid_check_name."):
             self._construct_operator(monkeypatch, self.invalid_column_mapping, ())
@@ -246,6 +252,16 @@ class TestColumnCheckOperator:
             == self.correct_generate_sql_query_with_partition.lstrip()
         )
 
+    def test_generate_sql_query_with_templated_partitions(self, monkeypatch):
+        checks = self.short_valid_column_mapping["X"]
+        operator = self._construct_operator(monkeypatch, self.short_valid_column_mapping, ())
+        operator.partition_clause = "{{ params.col }} > 1"
+        operator.render_template_fields({"params": {"col": "Y"}})
+        assert (
+            operator._generate_sql_query("X", checks).lstrip()
+            == self.correct_generate_sql_query_with_partition.lstrip()
+        )
+
     def test_generate_sql_query_with_partitions_and_check_partition(self, monkeypatch):
         self.short_valid_column_mapping["X"]["null_check"]["partition_clause"] = "Z < 100"
         checks = self.short_valid_column_mapping["X"]
@@ -267,6 +283,55 @@ class TestColumnCheckOperator:
         )
         del self.short_valid_column_mapping["X"]["distinct_check"]["partition_clause"]
 
+    @mock.patch.object(SQLColumnCheckOperator, "get_db_hook")
+    def test_generated_sql_respects_templated_partitions(self, mock_get_db_hook):
+        records = [
+            ("X", "null_check", 0),
+            ("X", "distinct_check", 10),
+        ]
+
+        mock_hook = mock.Mock()
+        mock_hook.get_records.return_value = records
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = SQLColumnCheckOperator(
+            task_id="test_task",
+            table="test_table",
+            column_mapping=self.short_valid_column_mapping,
+            partition_clause="{{ params.col }} > 1",
+        )
+        operator.render_template_fields({"params": {"col": "Y"}})
+
+        operator.execute(context=MagicMock())
+
+        mock_get_db_hook.return_value.get_records.assert_called_once_with(
+            self._full_check_sql(self.correct_generate_sql_query_with_partition),
+        )
+
+    @mock.patch.object(SQLColumnCheckOperator, "get_db_hook")
+    def test_generated_sql_respects_templated_table(self, mock_get_db_hook):
+        records = [
+            ("X", "null_check", 0),
+            ("X", "distinct_check", 10),
+        ]
+
+        mock_hook = mock.Mock()
+        mock_hook.get_records.return_value = records
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = SQLColumnCheckOperator(
+            task_id="test_task",
+            table="{{ params.table }}",
+            column_mapping=self.short_valid_column_mapping,
+        )
+        operator.render_template_fields({"params": {"table": "test_table"}})
+
+        operator.execute(context=MagicMock())
+
+        mock_get_db_hook.return_value.get_records.assert_called_once_with(
+            self._full_check_sql(self.correct_generate_sql_query_no_partitions),
+        )
+
 
 class TestTableCheckOperator:
 
@@ -363,6 +428,48 @@ class TestTableCheckOperator:
         finally:
             hook.run(["DROP TABLE employees"])
 
+    @pytest.mark.parametrize(
+        ["conn_id"],
+        [
+            pytest.param("postgres_default", marks=[pytest.mark.backend("postgres")]),
+            pytest.param("mysql_default", marks=[pytest.mark.backend("mysql")]),
+        ],
+    )
+    def test_sql_check_partition_clause_templating(self, conn_id):
+        """
+        Checks that the generated sql respects a templated partition clause
+        """
+        operator = SQLTableCheckOperator(
+            task_id="test_task",
+            table="employees",
+            checks={"row_count_check": {"check_statement": "COUNT(*) = 5"}},
+            conn_id=conn_id,
+            partition_clause="employment_year = {{ params.year }}",
+        )
+
+        hook = operator.get_db_hook()
+        hook.run(
+            [
+                """
+                CREATE TABLE IF NOT EXISTS employees (
+                    employee_name VARCHAR(50) NOT NULL,
+                    employment_year INT NOT NULL
+                );
+                """,
+                "INSERT INTO employees VALUES ('Adam', 2021)",
+                "INSERT INTO employees VALUES ('Chris', 2021)",
+                "INSERT INTO employees VALUES ('Frank', 2021)",
+                "INSERT INTO employees VALUES ('Fritz', 2021)",
+                "INSERT INTO employees VALUES ('Magda', 2022)",
+                "INSERT INTO employees VALUES ('Phil', 2021)",
+            ]
+        )
+        try:
+            operator.render_template_fields({"params": {"year": 2021}})
+            operator.execute({})
+        finally:
+            hook.run(["DROP TABLE employees"])
+
     def test_pass_all_checks_check(self, monkeypatch):
         records = [("row_count_check", 1), ("column_sum_check", "y")]
         operator = self._construct_operator(monkeypatch, self.checks, records)
@@ -388,6 +495,22 @@ class TestTableCheckOperator:
             operator._generate_sql_query().lstrip() == self.correct_generate_sql_query_with_partition.lstrip()
         )
 
+    def test_generate_sql_query_with_templated_partitions(self, monkeypatch):
+        operator = self._construct_operator(monkeypatch, self.checks, ())
+        operator.partition_clause = "{{ params.col }} > 10"
+        operator.render_template_fields({"params": {"col": "col_a"}})
+        assert (
+            operator._generate_sql_query().lstrip() == self.correct_generate_sql_query_with_partition.lstrip()
+        )
+
+    def test_generate_sql_query_with_templated_table(self, monkeypatch):
+        operator = self._construct_operator(monkeypatch, self.checks, ())
+        operator.table = "{{ params.table }}"
+        operator.render_template_fields({"params": {"table": "test_table"}})
+        assert (
+            operator._generate_sql_query().lstrip() == self.correct_generate_sql_query_no_partitions.lstrip()
+        )
+
     def test_generate_sql_query_with_partitions_and_check_partition(self, monkeypatch):
         self.checks["row_count_check"]["partition_clause"] = "id = 100"
         operator = self._construct_operator(monkeypatch, self.checks, ())
@@ -703,6 +826,17 @@ class TestThresholdCheckOperator:
 
         operator.execute(context=MagicMock())
 
+    @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
+    def test_pass_min_value_max_value_templated(self, mock_get_db_hook):
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = (10,)
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self._construct_operator("Select avg(val) from table1 limit 1", "{{ params.min }}", 100)
+        operator.render_template_fields({"params": {"min": 1}})
+        operator.execute(context=MagicMock())
+        mock_hook.get_first.assert_called_once_with("Select avg(val) from table1 limit 1")
+
     @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
     def test_fail_min_value_max_value(self, mock_get_db_hook):
         mock_hook = mock.Mock()