You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by di...@apache.org on 2020/12/01 04:05:14 UTC
[airflow] branch check-sql-hook updated: Add DBApiHook check for 2.0 migration
This is an automated email from the ASF dual-hosted git repository.
dimberman pushed a commit to branch check-sql-hook
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/check-sql-hook by this push:
new 21820ab Add DBApiHook check for 2.0 migration
21820ab is described below
commit 21820abe4b5ce76229c3fe68d7d7e30fd2f8c805
Author: Daniel Imberman <da...@gmail.com>
AuthorDate: Mon Nov 30 20:03:13 2020 -0800
Add DBApiHook check for 2.0 migration
Adds an upgrade check that flags any hook that implements the
run, get_pandas_df or get_records functions while inheriting directly
from BaseHook (base_hook) instead of DbApiHook
---
airflow/upgrade/rules/db_api_functions.py | 85 ++++++++++++++++++++++++++++
tests/upgrade/rules/test_db_api_functions.py | 62 ++++++++++++++++++++
2 files changed, 147 insertions(+)
diff --git a/airflow/upgrade/rules/db_api_functions.py b/airflow/upgrade/rules/db_api_functions.py
new file mode 100644
index 0000000..47d8702
--- /dev/null
+++ b/airflow/upgrade/rules/db_api_functions.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.upgrade.rules.base_rule import BaseRule
+from airflow.exceptions import AirflowException
+
+
def _check_method(cls, method):
    """Report ``cls`` if it provides its own implementation of ``method``.

    The method is looked up on the class and called unbound, with ``None``
    standing in for ``self`` and a dummy SQL string:

    * If the class has no such attribute, nothing is reported.
    * If the call raises ``NotImplementedError``, the method is treated as
      not implemented and nothing is reported.
    * Any other outcome (a normal return, or any other exception) is treated
      as a real implementation and reported. A genuine implementation will
      typically fail when called without an instance, e.g. with an
      ``AttributeError`` as soon as it touches ``self``. Such a failure still
      proves the method is implemented, and it must not abort the whole
      upgrade check.

    NOTE(review): this executes hook code with a fake argument; an
    implementation doing I/O before it touches ``self`` would perform it.

    :param cls: hook class to inspect
    :param method: name of the method to probe
    :return: an error message from ``return_error_string``, or ``None``
    """
    func = getattr(cls, method, None)
    if func is None:
        return None
    try:
        func(None, "fake SQL")
    except NotImplementedError:
        return None
    except Exception:  # pylint: disable=broad-except
        # Deliberately broad: any failure other than NotImplementedError
        # still means the class supplies its own implementation.
        pass
    return return_error_string(cls, method)


def check_get_pandas_df(cls):
    """Return an error message if ``cls`` implements ``get_pandas_df``, else ``None``."""
    return _check_method(cls, "get_pandas_df")


def check_run(cls):
    """Return an error message if ``cls`` implements ``run``, else ``None``."""
    return _check_method(cls, "run")


def check_get_records(cls):
    """Return an error message if ``cls`` implements ``get_records``, else ``None``."""
    return _check_method(cls, "get_records")


def return_error_string(cls, method):
    """Build the upgrade-check message for ``cls`` implementing ``method``.

    :param cls: offending hook class
    :param method: name of the DB function the class implements
    :return: human-readable error message
    """
    return (
        "Class {} incorrectly implements the function {} while inheriting from BaseHook. "
        "Please make this class inherit from airflow.hooks.dbapi_hook.DbApiHook instead".format(
            cls, method
        )
    )
+
+
class DbApiRule(BaseRule):
    """Upgrade rule flagging hooks that implement DB functions on ``BaseHook``.

    Hooks implementing ``run``, ``get_records`` or ``get_pandas_df`` should
    inherit from ``DbApiHook`` rather than directly from ``BaseHook``.
    """

    # Human-readable rule metadata; presumably displayed by the upgrade-check
    # report via BaseRule -- confirm against airflow/upgrade/rules/base_rule.py.
    title = "Hooks that run DB functions must inherit from DBApiHook"

    description = (
        "Hooks that run DB functions must inherit from DBApiHook instead of BaseHook"
    )

    def check(self):
        """Return one error message per offending (hook class, method) pair.

        Only direct subclasses of ``BaseHook`` are inspected; classes whose
        module path contains ``airflow.hooks`` are skipped.
        """
        checkers = (check_get_pandas_df, check_run, check_get_records)
        incorrect_implementations = []
        for hook_cls in BaseHook.__subclasses__():
            if "airflow.hooks" in hook_cls.__module__:
                continue
            for checker in checkers:
                error = checker(hook_cls)
                if error:
                    incorrect_implementations.append(error)
        return incorrect_implementations
diff --git a/tests/upgrade/rules/test_db_api_functions.py b/tests/upgrade/rules/test_db_api_functions.py
new file mode 100644
index 0000000..4a061e7
--- /dev/null
+++ b/tests/upgrade/rules/test_db_api_functions.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest import TestCase
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.hooks.dbapi_hook import DbApiHook
+from airflow.upgrade.rules.db_api_functions import DbApiRule
+
+
class MyHook(BaseHook):
    """Test fixture: a hook defining DB functions directly on BaseHook.

    The test below expects DbApiRule to report three failures for it.
    """

    def get_conn(self):
        return None

    def get_records(self, sql):
        return None

    def run(self, sql):
        return None

    def get_pandas_df(self, sql):
        return None
+
+
class ProperDbApiHook(DbApiHook):
    """Test fixture: a hook implementing the DB functions on DbApiHook.

    The test below expects DbApiRule to report no failures for it.
    """

    def bulk_dump(self, table, tmp_file):
        pass

    def bulk_load(self, table, tmp_file):
        pass

    # ``*args, **kwargs`` accept any extra positional or keyword arguments.
    # The previous ``*kwargs`` only captured positional arguments despite its
    # name, so keyword arguments raised TypeError.
    def get_records(self, sql, *args, **kwargs):
        pass

    def run(self, sql, *args, **kwargs):
        pass

    def get_pandas_df(self, sql, *args, **kwargs):
        pass
+
+
class TestSqlHookCheck(TestCase):
    def test_fails_on_incorrect_hook(self):
        """DbApiRule reports every DB function of MyHook and nothing for ProperDbApiHook."""
        db_api_rule_failures = DbApiRule().check()

        # Count only this module's MyHook: DbApiRule reports any direct
        # BaseHook subclass loaded in the session whose module lacks
        # "airflow.hooks", so asserting on the total count is brittle.
        my_hook_path = "{}.{}".format(MyHook.__module__, MyHook.__name__)
        my_hook_failures = [
            failure for failure in db_api_rule_failures if my_hook_path in failure
        ]
        self.assertEqual(len(my_hook_failures), 3)
        for method in ("get_pandas_df", "run", "get_records"):
            self.assertTrue(
                any(
                    "the function {} ".format(method) in failure
                    for failure in my_hook_failures
                ),
                "no failure reported for MyHook.{}".format(method),
            )

        proper_db_api_hook_failures = [
            failure for failure in db_api_rule_failures if "ProperDbApiHook" in failure
        ]
        self.assertEqual(len(proper_db_api_hook_failures), 0)