You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by di...@apache.org on 2020/12/11 17:20:23 UTC
[airflow] branch v1-10-stable updated: Add DBApiHook check for 2.0
migration (#12730)
This is an automated email from the ASF dual-hosted git repository.
dimberman pushed a commit to branch v1-10-stable
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v1-10-stable by this push:
new 2e1f813 Add DBApiHook check for 2.0 migration (#12730)
2e1f813 is described below
commit 2e1f813c35e60d9e13575639bc913d1cbafcd1ff
Author: Daniel Imberman <da...@gmail.com>
AuthorDate: Fri Dec 11 09:19:16 2020 -0800
Add DBApiHook check for 2.0 migration (#12730)
* Add DBApiHook check for 2.0 migration
Adds a check that ensures that any hook that implements the
run, get_pandas_df or get_records functions inherits from
DbApiHook rather than directly from BaseHook
* exception for grpc_hook
* fix plugin
* fix plugin
* fix plugin
* py2 compliance and add full lineage
* black
* fix
---
airflow/upgrade/rules/db_api_functions.py | 97 ++++++++++++++++++++++++++++
tests/upgrade/rules/test_db_api_functions.py | 71 ++++++++++++++++++++
2 files changed, 168 insertions(+)
diff --git a/airflow/upgrade/rules/db_api_functions.py b/airflow/upgrade/rules/db_api_functions.py
new file mode 100644
index 0000000..1801c36
--- /dev/null
+++ b/airflow/upgrade/rules/db_api_functions.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.upgrade.rules.base_rule import BaseRule
+
+
def check_method_override(cls, method_name, base_class=None):
    """Report whether ``cls`` provides its own ``method_name`` instead of ``base_class``'s.

    The override is detected statically: ``cls.__mro__`` is walked to find the
    first class whose own namespace defines ``method_name``. The method itself
    is never called. Executing arbitrary hook code on an uninitialised instance
    with a placeholder argument can trigger real side effects (network calls,
    subprocesses). It also misreports overrides that raise NotImplementedError
    somewhere inside their body.

    :param cls: hook class to inspect.
    :param method_name: name of the method to look up, e.g. ``"run"``.
    :param base_class: class whose own definition counts as "not overridden";
        defaults to ``BaseHook``.
    :return: the message from ``return_error_string`` when a class other than
        ``base_class`` provides the method, otherwise ``None``.
    """
    if base_class is None:
        base_class = BaseHook
    for klass in cls.__mro__:
        if method_name in vars(klass):
            if klass is base_class:
                return None
            return return_error_string(cls, method_name)
    # Nothing in the hierarchy defines the method, so nothing overrides it.
    return None


def check_get_pandas_df(cls):
    """Return an error message if ``cls`` overrides ``get_pandas_df``, else ``None``."""
    return check_method_override(cls, "get_pandas_df")


def check_run(cls):
    """Return an error message if ``cls`` overrides ``run``, else ``None``."""
    return check_method_override(cls, "run")


def check_get_records(cls):
    """Return an error message if ``cls`` overrides ``get_records``, else ``None``."""
    return check_method_override(cls, "get_records")


def return_error_string(cls, method):
    """Build the upgrade-check message for a hook class that implements ``method``.

    :param cls: offending hook class (interpolated via ``str()``).
    :param method: name of the DB-API method the class implements.
    :return: human-readable error message.
    """
    return (
        "Class {} incorrectly implements the function {} while inheriting from BaseHook. "
        "Please make this class inherit from airflow.hooks.dbapi_hook.DbApiHook instead".format(
            cls, method
        )
    )
+
+
def get_all_non_dbapi_children(root=None):
    """Return every transitive subclass of ``root`` that is not a DbApiHook.

    A class counts as a DbApiHook when any class in its MRO is named
    ``DbApiHook``. This also excludes hooks that pick up DbApiHook through
    multiple inheritance next to another BaseHook subclass; those are
    reachable from that other parent and would otherwise be falsely reported.
    Each class is listed once, even with diamond inheritance, in
    breadth-first order.

    :param root: class whose descendants are collected; defaults to ``BaseHook``.
    :return: list of classes.
    """
    if root is None:
        root = BaseHook

    def is_dbapi(klass):
        # Name-based, like the original direct-child filter, so this module
        # does not need to import DbApiHook.
        return any(base.__name__ == "DbApiHook" for base in klass.__mro__)

    result = []
    seen = set()
    generation = root.__subclasses__()
    while generation:
        next_generation = []
        for child in generation:
            if child in seen or is_dbapi(child):
                continue
            seen.add(child)
            result.append(child)
            next_generation.extend(child.__subclasses__())
        generation = next_generation
    return result
+
+
class DbApiRule(BaseRule):
    """Upgrade rule: hooks that implement DB-API style methods must use DbApiHook.

    Every non-DbApiHook descendant of BaseHook is checked for its own
    ``get_pandas_df``, ``run`` and ``get_records``.
    """

    title = "Hooks that run DB functions must inherit from DBApiHook"

    description = (
        "Hooks that run DB functions must inherit from DBApiHook instead of BaseHook"
    )

    def check(self):
        """Collect one message per offending (hook class, method) pair.

        :return: list of message strings, empty when nothing is flagged.
        """
        checkers = (check_get_pandas_df, check_run, check_get_records)
        messages = []
        for hook_cls in get_all_non_dbapi_children():
            for checker in checkers:
                message = checker(hook_cls)
                if message:
                    messages.append(message)
        return messages
diff --git a/tests/upgrade/rules/test_db_api_functions.py b/tests/upgrade/rules/test_db_api_functions.py
new file mode 100644
index 0000000..d73a041
--- /dev/null
+++ b/tests/upgrade/rules/test_db_api_functions.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest import TestCase
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.hooks.dbapi_hook import DbApiHook
+from airflow.upgrade.rules.db_api_functions import DbApiRule
+
+
class MyHook(BaseHook):
    """Fixture: overrides ``run`` and ``get_pandas_df`` directly on BaseHook.

    It does not override ``get_records``; the test expects two reports for it.
    """

    def get_conn(self):
        """Stub connection factory; returns None."""

    def get_pandas_df(self, sql):
        """Stub override; returns None."""

    def run(self, sql):
        """Stub override; returns None."""
+
+
class GrandChildHook(MyHook):
    """Fixture: inherits MyHook's overrides and adds its own ``get_records``.

    Its ``__init__`` requires arguments, so the rule must not depend on
    constructing hooks through ``__init__``.
    """

    def __init__(self, foo, bar):
        self.foo, self.bar = foo, bar

    def get_records(self, sql):
        """Stub override; returns None."""
+
+
class ProperDbApiHook(DbApiHook):
    """Fixture: implements DB-API methods the supported way, via DbApiHook.

    The rule must not report this class at all.
    """

    def bulk_dump(self, table, tmp_file):
        pass

    def bulk_load(self, table, tmp_file):
        pass

    # The stubs below accept arbitrary extra positional and keyword arguments,
    # so any call signature a caller uses is tolerated.
    def get_records(self, sql, *args, **kwargs):
        pass

    def run(self, sql, *args, **kwargs):
        pass

    def get_pandas_df(self, sql, *args, **kwargs):
        pass
+
+
class TestSqlHookCheck(TestCase):
    def test_fails_on_incorrect_hook(self):
        """Hooks overriding DB-API methods on BaseHook are reported, DbApiHooks are not."""
        db_api_rule_failures = DbApiRule().check()

        def failures_for(hook_cls):
            # str(hook_cls) is exactly what the rule interpolates into its
            # message, so this avoids accidental substring matches on names.
            return [d for d in db_api_rule_failures if str(hook_cls) in d]

        def flagged_methods(errors):
            return sorted(
                method
                for method in ("get_pandas_df", "get_records", "run")
                if any("the function {} ".format(method) in e for e in errors)
            )

        myhook_errors = failures_for(MyHook)
        grandchild_errors = failures_for(GrandChildHook)
        self.assertEqual(len(myhook_errors), 2)
        self.assertEqual(flagged_methods(myhook_errors), ["get_pandas_df", "run"])
        self.assertEqual(len(grandchild_errors), 3)
        self.assertEqual(
            flagged_methods(grandchild_errors), ["get_pandas_df", "get_records", "run"]
        )
        self.assertEqual(failures_for(ProperDbApiHook), [])