You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by fo...@apache.org on 2018/03/26 19:15:29 UTC

incubator-airflow git commit: [AIRFLOW-2228] Enhancements in ValueCheckOperator

Repository: incubator-airflow
Updated Branches:
  refs/heads/master b65dc43d2 -> acc9a3617


[AIRFLOW-2228] Enhancements in ValueCheckOperator

Allow ValueCheckOperator to accept a tolerance of 1 (i.e. 100%).
Make pass_value a template field, so that its value can be
determined at runtime.
Include the tolerance value in the Airflow exception message.
This gives an idea of the allowed range for the resulting
records.

Closes #3149 from sakshi2894/AIRFLOW-2228


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/acc9a361
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/acc9a361
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/acc9a361

Branch: refs/heads/master
Commit: acc9a3617e7d4520bb95191c80f4d3d2e64f622d
Parents: b65dc43
Author: Sakshi Bansal <sa...@qubole.com>
Authored: Mon Mar 26 21:15:22 2018 +0200
Committer: Fokko Driesprong <fo...@godatadriven.com>
Committed: Mon Mar 26 21:15:22 2018 +0200

----------------------------------------------------------------------
 airflow/operators/check_operator.py    | 30 ++++++----
 tests/operators/test_check_operator.py | 90 +++++++++++++++++++++++++++++
 2 files changed, 109 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/acc9a361/airflow/operators/check_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py
index ff82539..9994671 100644
--- a/airflow/operators/check_operator.py
+++ b/airflow/operators/check_operator.py
@@ -115,7 +115,7 @@ class ValueCheckOperator(BaseOperator):
     __mapper_args__ = {
         'polymorphic_identity': 'ValueCheckOperator'
     }
-    template_fields = ('sql',)
+    template_fields = ('sql', 'pass_value',)
     template_ext = ('.hql', '.sql',)
     ui_color = '#fff7e6'
 
@@ -127,10 +127,9 @@ class ValueCheckOperator(BaseOperator):
         super(ValueCheckOperator, self).__init__(*args, **kwargs)
         self.sql = sql
         self.conn_id = conn_id
-        self.pass_value = _convert_to_float_if_possible(pass_value)
+        self.pass_value = str(pass_value)
         tol = _convert_to_float_if_possible(tolerance)
         self.tol = tol if isinstance(tol, float) else None
-        self.is_numeric_value_check = isinstance(self.pass_value, float)
         self.has_tolerance = self.tol is not None
 
     def execute(self, context=None):
@@ -138,23 +137,32 @@ class ValueCheckOperator(BaseOperator):
         records = self.get_db_hook().get_first(self.sql)
         if not records:
             raise AirflowException("The query returned None")
-        test_results = []
-        except_temp = ("Test failed.\nPass value:{self.pass_value}\n"
+
+        pass_value_conv = _convert_to_float_if_possible(self.pass_value)
+        is_numeric_value_check = isinstance(pass_value_conv, float)
+
+        tolerance_pct_str = None
+        if (self.tol is not None):
+            tolerance_pct_str = str(self.tol * 100) + '%'
+
+        except_temp = ("Test failed.\nPass value:{pass_value_conv}\n"
+                       "Tolerance:{tolerance_pct_str}\n"
                        "Query:\n{self.sql}\nResults:\n{records!s}")
-        if not self.is_numeric_value_check:
-            tests = [str(r) == self.pass_value for r in records]
-        elif self.is_numeric_value_check:
+        if not is_numeric_value_check:
+            tests = [str(r) == pass_value_conv for r in records]
+        elif is_numeric_value_check:
             try:
                 num_rec = [float(r) for r in records]
             except (ValueError, TypeError) as e:
                 cvestr = "Converting a result to float failed.\n"
-                raise AirflowException(cvestr+except_temp.format(**locals()))
+                raise AirflowException(cvestr + except_temp.format(**locals()))
             if self.has_tolerance:
                 tests = [
-                    r / (1 + self.tol) <= self.pass_value <= r / (1 - self.tol)
+                    pass_value_conv * (1 - self.tol) <=
+                    r <= pass_value_conv * (1 + self.tol)
                     for r in num_rec]
             else:
-                tests = [r == self.pass_value for r in num_rec]
+                tests = [r == pass_value_conv for r in num_rec]
         if not all(tests):
             raise AirflowException(except_temp.format(**locals()))
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/acc9a361/tests/operators/test_check_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/test_check_operator.py b/tests/operators/test_check_operator.py
new file mode 100644
index 0000000..903d547
--- /dev/null
+++ b/tests/operators/test_check_operator.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from datetime import datetime
+from airflow.models import DAG
+from airflow.exceptions import AirflowException
+from airflow.operators.check_operator import ValueCheckOperator
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
+class ValueCheckOperatorTest(unittest.TestCase):
+
+    def setUp(self):
+        self.task_id = 'test_task'
+        self.conn_id = 'default_conn'
+
+    def __construct_operator(self, sql, pass_value, tolerance=None):
+
+        dag = DAG('test_dag', start_date=datetime(2017, 1, 1))
+
+        return ValueCheckOperator(
+            dag=dag,
+            task_id=self.task_id,
+            conn_id=self.conn_id,
+            sql=sql,
+            pass_value=pass_value,
+            tolerance=tolerance)
+
+    def test_pass_value_template_string(self):
+        pass_value_str = "2018-03-22"
+        operator = self.__construct_operator('select date from tab1;', "{{ ds }}")
+        result = operator.render_template('pass_value', operator.pass_value,
+                                          {'ds': pass_value_str})
+
+        self.assertEqual(operator.task_id, self.task_id)
+        self.assertEqual(result, pass_value_str)
+
+    def test_pass_value_template_string_float(self):
+        pass_value_float = 4.0
+        operator = self.__construct_operator('select date from tab1;', pass_value_float)
+        result = operator.render_template('pass_value', operator.pass_value, {})
+
+        self.assertEqual(operator.task_id, self.task_id)
+        self.assertEqual(result, str(pass_value_float))
+
+    @mock.patch.object(ValueCheckOperator, 'get_db_hook')
+    def test_execute_pass(self, mock_get_db_hook):
+
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [10]
+        mock_get_db_hook.return_value = mock_hook
+
+        sql = 'select value from tab1 limit 1;'
+
+        operator = self.__construct_operator(sql, 5, 1)
+
+        operator.execute(None)
+
+        mock_hook.get_first.assert_called_with(sql)
+
+    @mock.patch.object(ValueCheckOperator, 'get_db_hook')
+    def test_execute_fail(self, mock_get_db_hook):
+
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [11]
+        mock_get_db_hook.return_value = mock_hook
+
+        operator = self.__construct_operator('select value from tab1 limit 1;', 5, 1)
+
+        with self.assertRaisesRegexp(AirflowException, 'Tolerance:100.0%'):
+            operator.execute()