You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2020/03/05 22:55:16 UTC

[GitHub] [airflow] alexzlue commented on a change in pull request #7353: [AIRFLOW-6685] ThresholdCheckOperator

alexzlue commented on a change in pull request #7353: [AIRFLOW-6685] ThresholdCheckOperator
URL: https://github.com/apache/airflow/pull/7353#discussion_r388614519
 
 

 ##########
 File path: airflow/operators/check_operator.py
 ##########
 @@ -330,3 +330,87 @@ def execute(self, context=None):
 
     def get_db_hook(self):
         return BaseHook.get_hook(conn_id=self.conn_id)
+
+
+class ThresholdCheckOperator(BaseOperator):
+    """
+    Performs a value check using sql code against a minimum threshold
+    and a maximum threshold. Thresholds can be in the form of a numeric
+    value OR a sql statement that results in a numeric value.
+
+    Note that this is an abstract class and get_db_hook
+    needs to be defined. Here, get_db_hook is expected to return a hook
+    that gets a single record from an external source.
+
+    :param sql: the sql to be executed. (templated)
+    :type sql: str
+    :param min_threshold: numerical value or min threshold sql to be executed (templated)
+    :type min_threshold: numeric or str
+    :param max_threshold: numerical value or max threshold sql to be executed (templated)
+    :type max_threshold: numeric or str
+    """
+
+    template_fields = ('sql', 'min_threshold', 'max_threshold')  # type: Iterable[str]
+    template_ext = ('.hql', '.sql',)  # type: Iterable[str]
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        min_threshold: Any,
+        max_threshold: Any,
+        conn_id: Optional[str] = None,
+        *args, **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self.sql = sql
+        self.conn_id = conn_id
+        self.min_threshold = _convert_to_float_if_possible(min_threshold)
+        self.max_threshold = _convert_to_float_if_possible(max_threshold)
+
+    def execute(self, context=None):
+        hook = self.get_db_hook()
+        result = hook.get_first(self.sql)[0][0]
+
+        if isinstance(self.min_threshold, float):
+            lower_bound = self.min_threshold
+        else:
+            lower_bound = hook.get_first(self.min_threshold)[0][0]
+
+        if isinstance(self.max_threshold, float):
+            upper_bound = self.max_threshold
+        else:
+            upper_bound = hook.get_first(self.max_threshold)[0][0]
+
+        meta_data = {
+            "result": result,
+            "task_id": self.task_id,
+            "min_threshold": lower_bound,
+            "max_threshold": upper_bound,
+            "within_threshold": lower_bound <= result <= upper_bound
+        }
+
+        self.push(meta_data)
+        if not meta_data["within_threshold"]:
+            error_msg = (f'Threshold Check: "{meta_data.get("task_id")}" failed.\n'
+                         f'DAG: {self.dag_id}\nTask_id: {meta_data.get("task_id")}\n'
+                         f'Check description: {meta_data.get("description")}\n'
+                         f'SQL: {self.sql}\n'
+                         f'Result: {round(meta_data.get("result"), 2)} is not within thresholds '
+                         f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}'
+                         )
+            raise AirflowException(error_msg)
+
+        self.log.info("Test %s Successful.", self.task_id)
+
+    def push(self, meta_data):
+        """
+        Optional: Send data check info and metadata to an external database.
 
 Review comment:
   When inheriting from this class, push can be overridden.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services