You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by di...@apache.org on 2020/12/01 04:05:14 UTC

[airflow] branch check-sql-hook updated: Add DBApiHook check for 2.0 migration

This is an automated email from the ASF dual-hosted git repository.

dimberman pushed a commit to branch check-sql-hook
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/check-sql-hook by this push:
     new 21820ab  Add DBApiHook check for 2.0 migration
21820ab is described below

commit 21820abe4b5ce76229c3fe68d7d7e30fd2f8c805
Author: Daniel Imberman <da...@gmail.com>
AuthorDate: Mon Nov 30 20:03:13 2020 -0800

    Add DBApiHook check for 2.0 migration
    
    Adds a check that ensures that any hook that implements the
    run, get_pandas_df or get_records functions does not inherit
    directly from BaseHook
---
 airflow/upgrade/rules/db_api_functions.py    | 85 ++++++++++++++++++++++++++++
 tests/upgrade/rules/test_db_api_functions.py | 62 ++++++++++++++++++++
 2 files changed, 147 insertions(+)

diff --git a/airflow/upgrade/rules/db_api_functions.py b/airflow/upgrade/rules/db_api_functions.py
new file mode 100644
index 0000000..47d8702
--- /dev/null
+++ b/airflow/upgrade/rules/db_api_functions.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.upgrade.rules.base_rule import BaseRule
+from airflow.exceptions import AirflowException
+
+
+def check_get_pandas_df(cls):
+    try:
+        cls.get_pandas_df(None, "fake SQL")
+        return return_error_string(cls, "get_pandas_df")
+    except NotImplementedError:
+        pass
+    except Exception as e:
+        raise AirflowException(
+            "the following hook incorrectly implements %s. error: %s", cls, e
+        )
+
+
+def check_run(cls):
+    try:
+        cls.run(None, "fake SQL")
+        return return_error_string(cls, "run")
+    except NotImplementedError:
+        pass
+    except Exception as e:
+        raise AirflowException(
+            "the following hook incorrectly implements run %s. error: %s", cls, e
+        )
+
+
+def check_get_records(cls):
+    try:
+        cls.get_records(None, "fake SQL")
+        return return_error_string(cls, "get_records")
+    except NotImplementedError:
+        pass
+    except Exception as e:
+        raise AirflowException(
+            "the following hook incorrectly implements run %s. error: %s", cls, e
+        )
+
+
+def return_error_string(cls, method):
+    return (
+        "Class {} incorrectly implements the function {} while inheriting from BaseHook. "
+        "Please make this class inherit from airflow.hooks.db_api_hook.DbApiHook instead".format(
+            cls, method
+        )
+    )
+
+
+class DbApiRule(BaseRule):
+    def check(self):
+        subclasses = BaseHook.__subclasses__()
+        incorrect_implementations = []
+        for s in subclasses:
+            if "airflow.hooks" in s.__module__:
+                pass
+            else:
+                pandas_df = check_get_pandas_df(s)
+                if pandas_df:
+                    incorrect_implementations.append(pandas_df)
+                run = check_run(s)
+                if run:
+                    incorrect_implementations.append(run)
+                get_records = check_get_records(s)
+                if get_records:
+                    incorrect_implementations.append(get_records)
+        return incorrect_implementations
diff --git a/tests/upgrade/rules/test_db_api_functions.py b/tests/upgrade/rules/test_db_api_functions.py
new file mode 100644
index 0000000..4a061e7
--- /dev/null
+++ b/tests/upgrade/rules/test_db_api_functions.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest import TestCase
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.hooks.dbapi_hook import DbApiHook
+from airflow.upgrade.rules.db_api_functions import DbApiRule
+
+
+class MyHook(BaseHook):
+    def get_records(self, sql):
+        pass
+
+    def run(self, sql):
+        pass
+
+    def get_pandas_df(self, sql):
+        pass
+
+    def get_conn(self):
+        pass
+
+
+class ProperDbApiHook(DbApiHook):
+
+    def bulk_dump(self, table, tmp_file):
+        pass
+
+    def bulk_load(self, table, tmp_file):
+        pass
+
+    def get_records(self, sql, *kwargs):
+        pass
+
+    def run(self, sql, *kwargs):
+        pass
+
+    def get_pandas_df(self, sql, *kwargs):
+        pass
+
+
+class TestSqlHookCheck(TestCase):
+    def test_fails_on_incorrect_hook(self):
+        db_api_rule_failures = DbApiRule().check()
+        self.assertEqual(len(db_api_rule_failures), 3)
+        proper_db_api_hook_failures = \
+            [failure for failure in db_api_rule_failures if "ProperDbApiHook" in failure]
+        self.assertEqual(len(proper_db_api_hook_failures), 0)