You are viewing a plain-text version of this content. The canonical link to the original message is not preserved in this plain-text rendering.
Posted to commits@airflow.apache.org by di...@apache.org on 2020/06/20 03:02:43 UTC

[airflow] 01/04: flake8 pass Merging multiple sql operators (#9124)

This is an automated email from the ASF dual-hosted git repository.

dimberman pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 212a4c8c80ca8fdb0e451a54d8665802406fa74a
Author: Daniel Imberman <da...@astronomer.io>
AuthorDate: Fri Jun 19 15:15:50 2020 -0700

    flake8 pass Merging multiple sql operators (#9124)
---
 airflow/operators/check_operator.py      |  8 +--
 airflow/operators/sql.py                 | 93 +++++++++++++++++---------------
 airflow/operators/sql_branch_operator.py |  2 +-
 tests/operators/test_check_operator.py   |  0
 tests/operators/test_sql.py              |  5 +-
 tests/test_core_to_contrib.py            | 15 ++----
 6 files changed, 61 insertions(+), 62 deletions(-)

diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py
index 4810eeb..12ac472 100644
--- a/airflow/operators/check_operator.py
+++ b/airflow/operators/check_operator.py
@@ -38,7 +38,7 @@ class CheckOperator(SQLCheckOperator):
             Please use `airflow.operators.sql.SQLCheckOperator`.""",
             DeprecationWarning, stacklevel=2
         )
-        super().__init__(*args, **kwargs)
+        super(CheckOperator, self).__init__(*args, **kwargs)
 
 
 class IntervalCheckOperator(SQLIntervalCheckOperator):
@@ -53,7 +53,7 @@ class IntervalCheckOperator(SQLIntervalCheckOperator):
             Please use `airflow.operators.sql.SQLIntervalCheckOperator`.""",
             DeprecationWarning, stacklevel=2
         )
-        super().__init__(*args, **kwargs)
+        super(IntervalCheckOperator, self).__init__(*args, **kwargs)
 
 
 class ThresholdCheckOperator(SQLThresholdCheckOperator):
@@ -68,7 +68,7 @@ class ThresholdCheckOperator(SQLThresholdCheckOperator):
             Please use `airflow.operators.sql.SQLThresholdCheckOperator`.""",
             DeprecationWarning, stacklevel=2
         )
-        super().__init__(*args, **kwargs)
+        super(ThresholdCheckOperator, self).__init__(*args, **kwargs)
 
 
 class ValueCheckOperator(SQLValueCheckOperator):
@@ -83,4 +83,4 @@ class ValueCheckOperator(SQLValueCheckOperator):
             Please use `airflow.operators.sql.SQLValueCheckOperator`.""",
             DeprecationWarning, stacklevel=2
         )
-        super().__init__(*args, **kwargs)
+        super(ValueCheckOperator, self).__init__(*args, **kwargs)
diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py
index fd997d9..83cb201 100644
--- a/airflow/operators/sql.py
+++ b/airflow/operators/sql.py
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from distutils.util import strtobool
-from typing import Any, Dict, Iterable, List, Mapping, Optional, SupportsAbs, Union
+from typing import Iterable
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
@@ -82,9 +82,9 @@ class SQLCheckOperator(BaseOperator):
 
     @apply_defaults
     def __init__(
-        self, sql: str, conn_id: Optional[str] = None, *args, **kwargs
-    ) -> None:
-        super().__init__(*args, **kwargs)
+        self, sql, conn_id=None, *args, **kwargs
+    ):
+        super(SQLCheckOperator, self).__init__(*args, **kwargs)
         self.conn_id = conn_id
         self.sql = sql
 
@@ -155,14 +155,14 @@ class SQLValueCheckOperator(BaseOperator):
     @apply_defaults
     def __init__(
         self,
-        sql: str,
-        pass_value: Any,
-        tolerance: Any = None,
-        conn_id: Optional[str] = None,
+        sql,
+        pass_value,
+        tolerance=None,
+        conn_id=None,
         *args,
-        **kwargs,
+        **kwargs
     ):
-        super().__init__(*args, **kwargs)
+        super(SQLValueCheckOperator, self).__init__(*args, **kwargs)
         self.sql = sql
         self.conn_id = conn_id
         self.pass_value = str(pass_value)
@@ -278,17 +278,17 @@ class SQLIntervalCheckOperator(BaseOperator):
     @apply_defaults
     def __init__(
         self,
-        table: str,
-        metrics_thresholds: Dict[str, int],
-        date_filter_column: Optional[str] = "ds",
-        days_back: SupportsAbs[int] = -7,
-        ratio_formula: Optional[str] = "max_over_min",
-        ignore_zero: Optional[bool] = True,
-        conn_id: Optional[str] = None,
+        table,
+        metrics_thresholds,
+        date_filter_column="ds",
+        days_back=-7,
+        ratio_formula="max_over_min",
+        ignore_zero=True,
+        conn_id=None,
         *args,
-        **kwargs,
+        **kwargs
     ):
-        super().__init__(*args, **kwargs)
+        super(SQLIntervalCheckOperator, self).__init__(*args, **kwargs)
         if ratio_formula not in self.ratio_formulas:
             msg_template = (
                 "Invalid diff_method: {diff_method}. "
@@ -423,14 +423,14 @@ class SQLThresholdCheckOperator(BaseOperator):
     @apply_defaults
     def __init__(
         self,
-        sql: str,
-        min_threshold: Any,
-        max_threshold: Any,
-        conn_id: Optional[str] = None,
+        sql,
+        min_threshold,
+        max_threshold,
+        conn_id=None,
         *args,
-        **kwargs,
+        **kwargs
     ):
-        super().__init__(*args, **kwargs)
+        super(SQLThresholdCheckOperator, self).__init__(*args, **kwargs)
         self.sql = sql
         self.conn_id = conn_id
         self.min_threshold = _convert_to_float_if_possible(min_threshold)
@@ -461,13 +461,20 @@ class SQLThresholdCheckOperator(BaseOperator):
         self.push(meta_data)
         if not meta_data["within_threshold"]:
             error_msg = (
-                f'Threshold Check: "{meta_data.get("task_id")}" failed.\n'
-                f'DAG: {self.dag_id}\nTask_id: {meta_data.get("task_id")}\n'
-                f'Check description: {meta_data.get("description")}\n'
-                f"SQL: {self.sql}\n"
-                f'Result: {round(meta_data.get("result"), 2)} is not within thresholds '
-                f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}'
-            )
+                'Threshold Check: "{task_id}" failed.\n'
+                'DAG: {dag_id}\nTask_id: {task_id}\n'
+                'Check description: {description}\n'
+                "SQL: {sql}\n"
+                'Result: {round} is not within thresholds '
+                '{min} and {max}'
+                .format(task_id=meta_data.get("task_id"),
+                        dag_id=self.dag_id,
+                        description=meta_data.get("description"),
+                        sql=self.sql,
+                        round=round(meta_data.get("result"), 2),
+                        min=meta_data.get("min_threshold"),
+                        max=meta_data.get("max_threshold"),
+                        ))
             raise AirflowException(error_msg)
 
         self.log.info("Test %s Successful.", self.task_id)
@@ -478,7 +485,7 @@ class SQLThresholdCheckOperator(BaseOperator):
         Default functionality will log metadata.
         """
 
-        info = "\n".join([f"""{key}: {item}""" for key, item in meta_data.items()])
+        info = "\n".join(["{key}: {item}".format(key=key, item=item) for key, item in meta_data.items()])
         self.log.info("Log from %s:\n%s", self.dag_id, info)
 
     def get_db_hook(self):
@@ -516,16 +523,16 @@ class BranchSQLOperator(BaseOperator, SkipMixin):
     @apply_defaults
     def __init__(
         self,
-        sql: str,
-        follow_task_ids_if_true: List[str],
-        follow_task_ids_if_false: List[str],
-        conn_id: str = "default_conn_id",
-        database: Optional[str] = None,
-        parameters: Optional[Union[Mapping, Iterable]] = None,
+        sql,
+        follow_task_ids_if_true,
+        follow_task_ids_if_false,
+        conn_id="default_conn_id",
+        database=None,
+        parameters=None,
         *args,
-        **kwargs,
-    ) -> None:
-        super().__init__(*args, **kwargs)
+        **kwargs
+    ):
+        super(BranchSQLOperator, self).__init__(*args, **kwargs)
         self.conn_id = conn_id
         self.sql = sql
         self.parameters = parameters
@@ -551,7 +558,7 @@ class BranchSQLOperator(BaseOperator, SkipMixin):
 
         return self._hook
 
-    def execute(self, context: Dict):
+    def execute(self, context):
         # get supported hook
         self._hook = self._get_hook()
 
diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py
index cd319aa..b911e34 100644
--- a/airflow/operators/sql_branch_operator.py
+++ b/airflow/operators/sql_branch_operator.py
@@ -32,4 +32,4 @@ class BranchSqlOperator(BranchSQLOperator):
             Please use `airflow.operators.sql.BranchSQLOperator`.""",
             DeprecationWarning, stacklevel=2
         )
-        super().__init__(*args, **kwargs)
+        super(BranchSqlOperator, self).__init__(*args, **kwargs)
diff --git a/tests/operators/test_check_operator.py b/tests/operators/test_check_operator.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/operators/test_sql.py b/tests/operators/test_sql.py
index a538f15..e5c1f98 100644
--- a/tests/operators/test_sql.py
+++ b/tests/operators/test_sql.py
@@ -200,8 +200,7 @@ class TestIntervalCheckOperator(unittest.TestCase):
                 [2, 2, 2, 2],  # reference
                 [1, 1, 1, 1],  # current
             ]
-
-            yield from rows
+            return rows
 
         mock_hook.get_first.side_effect = returned_row()
         mock_get_db_hook.return_value = mock_hook
@@ -226,7 +225,7 @@ class TestIntervalCheckOperator(unittest.TestCase):
                 [1, 1, 1, 1],  # current
             ]
 
-            yield from rows
+            return rows
 
         mock_hook.get_first.side_effect = returned_row()
         mock_get_db_hook.return_value = mock_hook
diff --git a/tests/test_core_to_contrib.py b/tests/test_core_to_contrib.py
index 0a3e7fb..127905a 100644
--- a/tests/test_core_to_contrib.py
+++ b/tests/test_core_to_contrib.py
@@ -19,12 +19,10 @@
 import importlib
 import sys
 from inspect import isabstract
-from typing import Any
 from unittest import TestCase, mock
 
 from parameterized import parameterized
 
-HOOKS = []
 
 OPERATORS = [
     (
@@ -49,24 +47,19 @@ OPERATORS = [
     ),
 ]
 
-SECRETS = []
 
-SENSORS = []
-
-TRANSFERS = []
-
-ALL = HOOKS + OPERATORS + SECRETS + SENSORS + TRANSFERS
+ALL = OPERATORS
 
 RENAMED_HOOKS = [
     (old_class, new_class)
-    for old_class, new_class in HOOKS + OPERATORS + SECRETS + SENSORS
+    for old_class, new_class in OPERATORS
     if old_class.rpartition(".")[2] != new_class.rpartition(".")[2]
 ]
 
 
 class TestMovingCoreToContrib(TestCase):
     @staticmethod
-    def assert_warning(msg: str, warning: Any):
+    def assert_warning(msg, warning):
         error = "Text '{}' not in warnings".format(msg)
         assert any(msg in str(w) for w in warning.warnings), error
 
@@ -100,7 +93,7 @@ class TestMovingCoreToContrib(TestCase):
         class_ = getattr(module, class_name)
 
         if isabstract(class_) and not parent:
-            class_name = f"Mock({class_.__name__})"
+            class_name = "Mock({class_name})".format(class_name=class_.__name__)
 
             attributes = {
                 a: mock.MagicMock() for a in class_.__abstractmethods__