You are viewing a plain-text version of this content. The canonical link to it (a hyperlink on the original archive page) is not preserved in this copy.
Posted to commits@airflow.apache.org by di...@apache.org on 2020/06/20 03:02:43 UTC
[airflow] 01/04: flake8 pass Merging multiple sql operators (#9124)
This is an automated email from the ASF dual-hosted git repository.
dimberman pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 212a4c8c80ca8fdb0e451a54d8665802406fa74a
Author: Daniel Imberman <da...@astronomer.io>
AuthorDate: Fri Jun 19 15:15:50 2020 -0700
flake8 pass Merging multiple sql operators (#9124)
---
airflow/operators/check_operator.py | 8 +--
airflow/operators/sql.py | 93 +++++++++++++++++---------------
airflow/operators/sql_branch_operator.py | 2 +-
tests/operators/test_check_operator.py | 0
tests/operators/test_sql.py | 5 +-
tests/test_core_to_contrib.py | 15 ++----
6 files changed, 61 insertions(+), 62 deletions(-)
diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py
index 4810eeb..12ac472 100644
--- a/airflow/operators/check_operator.py
+++ b/airflow/operators/check_operator.py
@@ -38,7 +38,7 @@ class CheckOperator(SQLCheckOperator):
Please use `airflow.operators.sql.SQLCheckOperator`.""",
DeprecationWarning, stacklevel=2
)
- super().__init__(*args, **kwargs)
+ super(CheckOperator, self).__init__(*args, **kwargs)
class IntervalCheckOperator(SQLIntervalCheckOperator):
@@ -53,7 +53,7 @@ class IntervalCheckOperator(SQLIntervalCheckOperator):
Please use `airflow.operators.sql.SQLIntervalCheckOperator`.""",
DeprecationWarning, stacklevel=2
)
- super().__init__(*args, **kwargs)
+ super(IntervalCheckOperator, self).__init__(*args, **kwargs)
class ThresholdCheckOperator(SQLThresholdCheckOperator):
@@ -68,7 +68,7 @@ class ThresholdCheckOperator(SQLThresholdCheckOperator):
Please use `airflow.operators.sql.SQLThresholdCheckOperator`.""",
DeprecationWarning, stacklevel=2
)
- super().__init__(*args, **kwargs)
+ super(ThresholdCheckOperator, self).__init__(*args, **kwargs)
class ValueCheckOperator(SQLValueCheckOperator):
@@ -83,4 +83,4 @@ class ValueCheckOperator(SQLValueCheckOperator):
Please use `airflow.operators.sql.SQLValueCheckOperator`.""",
DeprecationWarning, stacklevel=2
)
- super().__init__(*args, **kwargs)
+ super(ValueCheckOperator, self).__init__(*args, **kwargs)
diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py
index fd997d9..83cb201 100644
--- a/airflow/operators/sql.py
+++ b/airflow/operators/sql.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
from distutils.util import strtobool
-from typing import Any, Dict, Iterable, List, Mapping, Optional, SupportsAbs, Union
+from typing import Iterable
from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
@@ -82,9 +82,9 @@ class SQLCheckOperator(BaseOperator):
@apply_defaults
def __init__(
- self, sql: str, conn_id: Optional[str] = None, *args, **kwargs
- ) -> None:
- super().__init__(*args, **kwargs)
+ self, sql, conn_id=None, *args, **kwargs
+ ):
+ super(SQLCheckOperator, self).__init__(*args, **kwargs)
self.conn_id = conn_id
self.sql = sql
@@ -155,14 +155,14 @@ class SQLValueCheckOperator(BaseOperator):
@apply_defaults
def __init__(
self,
- sql: str,
- pass_value: Any,
- tolerance: Any = None,
- conn_id: Optional[str] = None,
+ sql,
+ pass_value,
+ tolerance=None,
+ conn_id=None,
*args,
- **kwargs,
+ **kwargs
):
- super().__init__(*args, **kwargs)
+ super(SQLValueCheckOperator, self).__init__(*args, **kwargs)
self.sql = sql
self.conn_id = conn_id
self.pass_value = str(pass_value)
@@ -278,17 +278,17 @@ class SQLIntervalCheckOperator(BaseOperator):
@apply_defaults
def __init__(
self,
- table: str,
- metrics_thresholds: Dict[str, int],
- date_filter_column: Optional[str] = "ds",
- days_back: SupportsAbs[int] = -7,
- ratio_formula: Optional[str] = "max_over_min",
- ignore_zero: Optional[bool] = True,
- conn_id: Optional[str] = None,
+ table,
+ metrics_thresholds,
+ date_filter_column="ds",
+ days_back=-7,
+ ratio_formula="max_over_min",
+ ignore_zero=True,
+ conn_id=None,
*args,
- **kwargs,
+ **kwargs
):
- super().__init__(*args, **kwargs)
+ super(SQLIntervalCheckOperator, self).__init__(*args, **kwargs)
if ratio_formula not in self.ratio_formulas:
msg_template = (
"Invalid diff_method: {diff_method}. "
@@ -423,14 +423,14 @@ class SQLThresholdCheckOperator(BaseOperator):
@apply_defaults
def __init__(
self,
- sql: str,
- min_threshold: Any,
- max_threshold: Any,
- conn_id: Optional[str] = None,
+ sql,
+ min_threshold,
+ max_threshold,
+ conn_id=None,
*args,
- **kwargs,
+ **kwargs
):
- super().__init__(*args, **kwargs)
+ super(SQLThresholdCheckOperator, self).__init__(*args, **kwargs)
self.sql = sql
self.conn_id = conn_id
self.min_threshold = _convert_to_float_if_possible(min_threshold)
@@ -461,13 +461,20 @@ class SQLThresholdCheckOperator(BaseOperator):
self.push(meta_data)
if not meta_data["within_threshold"]:
error_msg = (
- f'Threshold Check: "{meta_data.get("task_id")}" failed.\n'
- f'DAG: {self.dag_id}\nTask_id: {meta_data.get("task_id")}\n'
- f'Check description: {meta_data.get("description")}\n'
- f"SQL: {self.sql}\n"
- f'Result: {round(meta_data.get("result"), 2)} is not within thresholds '
- f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}'
- )
+ 'Threshold Check: "{task_id}" failed.\n'
+ 'DAG: {dag_id}\nTask_id: {task_id}\n'
+ 'Check description: {description}\n'
+ "SQL: {sql}\n"
+ 'Result: {round} is not within thresholds '
+ '{min} and {max}'
+ .format(task_id=meta_data.get("task_id"),
+ dag_id=self.dag_id,
+ description=meta_data.get("description"),
+ sql=self.sql,
+ round=round(meta_data.get("result"), 2),
+ min=meta_data.get("min_threshold"),
+ max=meta_data.get("max_threshold"),
+ ))
raise AirflowException(error_msg)
self.log.info("Test %s Successful.", self.task_id)
@@ -478,7 +485,7 @@ class SQLThresholdCheckOperator(BaseOperator):
Default functionality will log metadata.
"""
- info = "\n".join([f"""{key}: {item}""" for key, item in meta_data.items()])
+ info = "\n".join(["{key}: {item}".format(key=key, item=item) for key, item in meta_data.items()])
self.log.info("Log from %s:\n%s", self.dag_id, info)
def get_db_hook(self):
@@ -516,16 +523,16 @@ class BranchSQLOperator(BaseOperator, SkipMixin):
@apply_defaults
def __init__(
self,
- sql: str,
- follow_task_ids_if_true: List[str],
- follow_task_ids_if_false: List[str],
- conn_id: str = "default_conn_id",
- database: Optional[str] = None,
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ sql,
+ follow_task_ids_if_true,
+ follow_task_ids_if_false,
+ conn_id="default_conn_id",
+ database=None,
+ parameters=None,
*args,
- **kwargs,
- ) -> None:
- super().__init__(*args, **kwargs)
+ **kwargs
+ ):
+ super(BranchSQLOperator, self).__init__(*args, **kwargs)
self.conn_id = conn_id
self.sql = sql
self.parameters = parameters
@@ -551,7 +558,7 @@ class BranchSQLOperator(BaseOperator, SkipMixin):
return self._hook
- def execute(self, context: Dict):
+ def execute(self, context):
# get supported hook
self._hook = self._get_hook()
diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py
index cd319aa..b911e34 100644
--- a/airflow/operators/sql_branch_operator.py
+++ b/airflow/operators/sql_branch_operator.py
@@ -32,4 +32,4 @@ class BranchSqlOperator(BranchSQLOperator):
Please use `airflow.operators.sql.BranchSQLOperator`.""",
DeprecationWarning, stacklevel=2
)
- super().__init__(*args, **kwargs)
+ super(BranchSqlOperator, self).__init__(*args, **kwargs)
diff --git a/tests/operators/test_check_operator.py b/tests/operators/test_check_operator.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/operators/test_sql.py b/tests/operators/test_sql.py
index a538f15..e5c1f98 100644
--- a/tests/operators/test_sql.py
+++ b/tests/operators/test_sql.py
@@ -200,8 +200,7 @@ class TestIntervalCheckOperator(unittest.TestCase):
[2, 2, 2, 2], # reference
[1, 1, 1, 1], # current
]
-
- yield from rows
+ return rows
mock_hook.get_first.side_effect = returned_row()
mock_get_db_hook.return_value = mock_hook
@@ -226,7 +225,7 @@ class TestIntervalCheckOperator(unittest.TestCase):
[1, 1, 1, 1], # current
]
- yield from rows
+ return rows
mock_hook.get_first.side_effect = returned_row()
mock_get_db_hook.return_value = mock_hook
diff --git a/tests/test_core_to_contrib.py b/tests/test_core_to_contrib.py
index 0a3e7fb..127905a 100644
--- a/tests/test_core_to_contrib.py
+++ b/tests/test_core_to_contrib.py
@@ -19,12 +19,10 @@
import importlib
import sys
from inspect import isabstract
-from typing import Any
from unittest import TestCase, mock
from parameterized import parameterized
-HOOKS = []
OPERATORS = [
(
@@ -49,24 +47,19 @@ OPERATORS = [
),
]
-SECRETS = []
-SENSORS = []
-
-TRANSFERS = []
-
-ALL = HOOKS + OPERATORS + SECRETS + SENSORS + TRANSFERS
+ALL = OPERATORS
RENAMED_HOOKS = [
(old_class, new_class)
- for old_class, new_class in HOOKS + OPERATORS + SECRETS + SENSORS
+ for old_class, new_class in OPERATORS
if old_class.rpartition(".")[2] != new_class.rpartition(".")[2]
]
class TestMovingCoreToContrib(TestCase):
@staticmethod
- def assert_warning(msg: str, warning: Any):
+ def assert_warning(msg, warning):
error = "Text '{}' not in warnings".format(msg)
assert any(msg in str(w) for w in warning.warnings), error
@@ -100,7 +93,7 @@ class TestMovingCoreToContrib(TestCase):
class_ = getattr(module, class_name)
if isabstract(class_) and not parent:
- class_name = f"Mock({class_.__name__})"
+ class_name = "Mock({class_name})".format(class_name=class_.__name__)
attributes = {
a: mock.MagicMock() for a in class_.__abstractmethods__