You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by di...@apache.org on 2020/12/11 17:20:23 UTC

[airflow] branch v1-10-stable updated: Add DBApiHook check for 2.0 migration (#12730)

This is an automated email from the ASF dual-hosted git repository.

dimberman pushed a commit to branch v1-10-stable
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v1-10-stable by this push:
     new 2e1f813  Add DBApiHook check for 2.0 migration (#12730)
2e1f813 is described below

commit 2e1f813c35e60d9e13575639bc913d1cbafcd1ff
Author: Daniel Imberman <da...@gmail.com>
AuthorDate: Fri Dec 11 09:19:16 2020 -0800

    Add DBApiHook check for 2.0 migration (#12730)
    
    * Add DBApiHook check for 2.0 migration
    
    Adds a check that ensures that any hook that implements the
    run, get_pandas_df or get_records functions inherits from DbApiHook
    rather than directly from BaseHook
    
    * exception for grpc_hook
    
    * fix plugin
    
    * fix plugin
    
    * fix plugin
    
    * py2 compliance and add full lineage
    
    * black
    
    * fix
---
 airflow/upgrade/rules/db_api_functions.py    | 97 ++++++++++++++++++++++++++++
 tests/upgrade/rules/test_db_api_functions.py | 71 ++++++++++++++++++++
 2 files changed, 168 insertions(+)

diff --git a/airflow/upgrade/rules/db_api_functions.py b/airflow/upgrade/rules/db_api_functions.py
new file mode 100644
index 0000000..1801c36
--- /dev/null
+++ b/airflow/upgrade/rules/db_api_functions.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.upgrade.rules.base_rule import BaseRule
+
+
+def check_get_pandas_df(cls):
+    try:
+        cls.__new__(cls).get_pandas_df("fake SQL")
+        return return_error_string(cls, "get_pandas_df")
+    except NotImplementedError:
+        pass
+    except Exception:
+        return return_error_string(cls, "get_pandas_df")
+
+
+def check_run(cls):
+    try:
+        cls.__new__(cls).run("fake SQL")
+        return return_error_string(cls, "run")
+    except NotImplementedError:
+        pass
+    except Exception:
+        return return_error_string(cls, "run")
+
+
+def check_get_records(cls):
+    try:
+        cls.__new__(cls).get_records("fake SQL")
+        return return_error_string(cls, "get_records")
+    except NotImplementedError:
+        pass
+    except Exception:
+        return return_error_string(cls, "get_records")
+
+
+def return_error_string(cls, method):
+    return (
+        "Class {} incorrectly implements the function {} while inheriting from BaseHook. "
+        "Please make this class inherit from airflow.hooks.db_api_hook.DbApiHook instead".format(
+            cls, method
+        )
+    )
+
+
+def get_all_non_dbapi_children():
+    basehook_children = [
+        child for child in BaseHook.__subclasses__() if child.__name__ != "DbApiHook"
+    ]
+    res = basehook_children[:]
+    while basehook_children:
+        next_generation = []
+        for child in basehook_children:
+            subclasses = child.__subclasses__()
+            if subclasses:
+                next_generation.extend(subclasses)
+        res.extend(next_generation)
+        basehook_children = next_generation
+    return res
+
+
+class DbApiRule(BaseRule):
+    title = "Hooks that run DB functions must inherit from DBApiHook"
+
+    description = (
+        "Hooks that run DB functions must inherit from DBApiHook instead of BaseHook"
+    )
+
+    def check(self):
+        basehook_subclasses = get_all_non_dbapi_children()
+        incorrect_implementations = []
+        for child in basehook_subclasses:
+            pandas_df = check_get_pandas_df(child)
+            if pandas_df:
+                incorrect_implementations.append(pandas_df)
+            run = check_run(child)
+            if run:
+                incorrect_implementations.append(run)
+            get_records = check_get_records(child)
+            if get_records:
+                incorrect_implementations.append(get_records)
+        return incorrect_implementations
diff --git a/tests/upgrade/rules/test_db_api_functions.py b/tests/upgrade/rules/test_db_api_functions.py
new file mode 100644
index 0000000..d73a041
--- /dev/null
+++ b/tests/upgrade/rules/test_db_api_functions.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest import TestCase
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.hooks.dbapi_hook import DbApiHook
+from airflow.upgrade.rules.db_api_functions import DbApiRule
+
+
+class MyHook(BaseHook):
+    """Fixture: a direct BaseHook subclass that wrongly implements DB methods.
+
+    It overrides ``run`` and ``get_pandas_df`` but not ``get_records``, so the
+    test expects exactly two problems to be reported for it.
+    """
+
+    def run(self, sql):
+        pass
+
+    def get_pandas_df(self, sql):
+        pass
+
+    # Not one of the methods DbApiRule inspects.
+    def get_conn(self):
+        pass
+
+
+class GrandChildHook(MyHook):
+    """Fixture: an indirect BaseHook subclass (through MyHook).
+
+    It adds ``get_records`` to the two methods inherited from MyHook, so the
+    test expects three problems for it.
+    """
+
+    # Requires arguments on purpose: the rule builds instances with
+    # ``cls.__new__(cls)`` and never calls ``__init__``.
+    def __init__(self, foo, bar):
+        self.foo = foo
+        self.bar = bar
+
+    def get_records(self, sql):
+        pass
+
+
+class ProperDbApiHook(DbApiHook):
+    def bulk_dump(self, table, tmp_file):
+        pass
+
+    def bulk_load(self, table, tmp_file):
+        pass
+
+    def get_records(self, sql, *kwargs):
+        pass
+
+    def run(self, sql, *kwargs):
+        pass
+
+    def get_pandas_df(self, sql, *kwargs):
+        pass
+
+
+class TestSqlHookCheck(TestCase):
+    def test_fails_on_incorrect_hook(self):
+        db_api_rule_failures = DbApiRule().check()
+        myhook_errors = [d for d in db_api_rule_failures if "MyHook" in d]
+        grandchild_errors = [d for d in db_api_rule_failures if "GrandChild" in d]
+        self.assertEqual(len(myhook_errors), 2)
+        self.assertEqual(len(grandchild_errors), 3)
+        proper_db_api_hook_failures = [
+            failure for failure in db_api_rule_failures if "ProperDbApiHook" in failure
+        ]
+        self.assertEqual(len(proper_db_api_hook_failures), 0)