You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this rendering.)
Posted to commits@airflow.apache.org by fo...@apache.org on 2018/03/26 19:15:29 UTC
incubator-airflow git commit: [AIRFLOW-2228] Enhancements in ValueCheckOperator
Repository: incubator-airflow
Updated Branches:
refs/heads/master b65dc43d2 -> acc9a3617
[AIRFLOW-2228] Enhancements in ValueCheckOperator
Allow ValueCheckOperator to accept a tolerance of
1 (i.e. 100%); the previous check divided by
(1 - tolerance), which fails for a tolerance of 1.
Make pass_value a template field,
so that its value can be determined at runtime.
Include the tolerance in the AirflowException message,
so that it shows the allowed range for the returned records.
Closes #3149 from sakshi2894/AIRFLOW-2228
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/acc9a361
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/acc9a361
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/acc9a361
Branch: refs/heads/master
Commit: acc9a3617e7d4520bb95191c80f4d3d2e64f622d
Parents: b65dc43
Author: Sakshi Bansal <sa...@qubole.com>
Authored: Mon Mar 26 21:15:22 2018 +0200
Committer: Fokko Driesprong <fo...@godatadriven.com>
Committed: Mon Mar 26 21:15:22 2018 +0200
----------------------------------------------------------------------
airflow/operators/check_operator.py | 30 ++++++----
tests/operators/test_check_operator.py | 90 +++++++++++++++++++++++++++++
2 files changed, 109 insertions(+), 11 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/acc9a361/airflow/operators/check_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py
index ff82539..9994671 100644
--- a/airflow/operators/check_operator.py
+++ b/airflow/operators/check_operator.py
@@ -115,7 +115,7 @@ class ValueCheckOperator(BaseOperator):
__mapper_args__ = {
'polymorphic_identity': 'ValueCheckOperator'
}
- template_fields = ('sql',)
+ template_fields = ('sql', 'pass_value',)
template_ext = ('.hql', '.sql',)
ui_color = '#fff7e6'
@@ -127,10 +127,9 @@ class ValueCheckOperator(BaseOperator):
super(ValueCheckOperator, self).__init__(*args, **kwargs)
self.sql = sql
self.conn_id = conn_id
- self.pass_value = _convert_to_float_if_possible(pass_value)
+ self.pass_value = str(pass_value)
tol = _convert_to_float_if_possible(tolerance)
self.tol = tol if isinstance(tol, float) else None
- self.is_numeric_value_check = isinstance(self.pass_value, float)
self.has_tolerance = self.tol is not None
def execute(self, context=None):
@@ -138,23 +137,32 @@ class ValueCheckOperator(BaseOperator):
records = self.get_db_hook().get_first(self.sql)
if not records:
raise AirflowException("The query returned None")
- test_results = []
- except_temp = ("Test failed.\nPass value:{self.pass_value}\n"
+
+ pass_value_conv = _convert_to_float_if_possible(self.pass_value)
+ is_numeric_value_check = isinstance(pass_value_conv, float)
+
+ tolerance_pct_str = None
+ if (self.tol is not None):
+ tolerance_pct_str = str(self.tol * 100) + '%'
+
+ except_temp = ("Test failed.\nPass value:{pass_value_conv}\n"
+ "Tolerance:{tolerance_pct_str}\n"
"Query:\n{self.sql}\nResults:\n{records!s}")
- if not self.is_numeric_value_check:
- tests = [str(r) == self.pass_value for r in records]
- elif self.is_numeric_value_check:
+ if not is_numeric_value_check:
+ tests = [str(r) == pass_value_conv for r in records]
+ elif is_numeric_value_check:
try:
num_rec = [float(r) for r in records]
except (ValueError, TypeError) as e:
cvestr = "Converting a result to float failed.\n"
- raise AirflowException(cvestr+except_temp.format(**locals()))
+ raise AirflowException(cvestr + except_temp.format(**locals()))
if self.has_tolerance:
tests = [
- r / (1 + self.tol) <= self.pass_value <= r / (1 - self.tol)
+ pass_value_conv * (1 - self.tol) <=
+ r <= pass_value_conv * (1 + self.tol)
for r in num_rec]
else:
- tests = [r == self.pass_value for r in num_rec]
+ tests = [r == pass_value_conv for r in num_rec]
if not all(tests):
raise AirflowException(except_temp.format(**locals()))
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/acc9a361/tests/operators/test_check_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/test_check_operator.py b/tests/operators/test_check_operator.py
new file mode 100644
index 0000000..903d547
--- /dev/null
+++ b/tests/operators/test_check_operator.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from datetime import datetime
+from airflow.models import DAG
+from airflow.exceptions import AirflowException
+from airflow.operators.check_operator import ValueCheckOperator
+
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+
+
+class ValueCheckOperatorTest(unittest.TestCase):
+    """Tests for ValueCheckOperator templating and value checks."""
+
+    def setUp(self):
+        self.task_id = 'test_task'
+        self.conn_id = 'default_conn'
+
+    def _construct_operator(self, sql, pass_value, tolerance=None):
+        """Build a ValueCheckOperator attached to a throwaway DAG."""
+        dag = DAG('test_dag', start_date=datetime(2017, 1, 1))
+        return ValueCheckOperator(
+            dag=dag,
+            task_id=self.task_id,
+            conn_id=self.conn_id,
+            sql=sql,
+            pass_value=pass_value,
+            tolerance=tolerance)
+
+    def test_pass_value_template_string(self):
+        pass_value_str = "2018-03-22"
+        op = self._construct_operator('select date from tab1;', "{{ ds }}")
+        result = op.render_template('pass_value', op.pass_value,
+                                    {'ds': pass_value_str})
+        self.assertEqual(op.task_id, self.task_id)
+        self.assertEqual(result, pass_value_str)
+
+    def test_pass_value_template_string_float(self):
+        # A numeric pass_value is stored as its str() so it can be templated.
+        pass_value_float = 4.0
+        op = self._construct_operator('select date from tab1;',
+                                      pass_value_float)
+        result = op.render_template('pass_value', op.pass_value, {})
+        self.assertEqual(op.task_id, self.task_id)
+        self.assertEqual(result, str(pass_value_float))
+
+    # get_db_hook is patched inside each test (not by decorator) so that a
+    # missing ``mock`` package skips these tests rather than breaking import.
+    @unittest.skipIf(mock is None, 'mock package not present')
+    def test_execute_pass(self):
+        sql = 'select value from tab1 limit 1;'
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [10]
+        # 10 lies inside 5 +/- 100%, i.e. the closed range [0, 10].
+        op = self._construct_operator(sql, 5, 1)
+        with mock.patch.object(ValueCheckOperator, 'get_db_hook',
+                               return_value=mock_hook):
+            op.execute(None)
+        mock_hook.get_first.assert_called_with(sql)
+
+    @unittest.skipIf(mock is None, 'mock package not present')
+    def test_execute_fail(self):
+        sql = 'select value from tab1 limit 1;'
+        mock_hook = mock.Mock()
+        mock_hook.get_first.return_value = [11]
+        op = self._construct_operator(sql, 5, 1)
+        with mock.patch.object(ValueCheckOperator, 'get_db_hook',
+                               return_value=mock_hook):
+            with self.assertRaises(AirflowException) as ctx:
+                op.execute()
+        # assertIn avoids assertRaisesRegexp (removed in Python 3.12).
+        self.assertIn('Tolerance:100.0%', str(ctx.exception))