You are viewing a plain-text version of this content. The canonical link to the original was not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/12/04 23:25:14 UTC
[airflow] branch main updated: Removed hardcoded connection types. Check if hook is instance of DbApiHook. (#19639)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8f162aa Removed hardcoded connection types. Check if hook is instance of DbApiHook. (#19639)
8f162aa is described below
commit 8f162aa9f6a355165f4557fc28d021825b346357
Author: Dmytro Kazanzhy <dk...@gmail.com>
AuthorDate: Sun Dec 5 01:24:44 2021 +0200
Removed hardcoded connection types. Check if hook is instance of DbApiHook. (#19639)
Co-authored-by: Dmytro Kazanzhy <dk...@demandbase.com>
---
airflow/sensors/sql.py | 25 ++++++-------------------
tests/sensors/test_sql_sensor.py | 24 ++++++++++++------------
2 files changed, 18 insertions(+), 31 deletions(-)
diff --git a/airflow/sensors/sql.py b/airflow/sensors/sql.py
index ad4dfa4..123c891 100644
--- a/airflow/sensors/sql.py
+++ b/airflow/sensors/sql.py
@@ -20,6 +20,7 @@ from typing import Iterable
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
+from airflow.hooks.dbapi import DbApiHook
from airflow.sensors.base import BaseSensorOperator
@@ -83,27 +84,13 @@ class SqlSensor(BaseSensorOperator):
def _get_hook(self):
conn = BaseHook.get_connection(self.conn_id)
-
- allowed_conn_type = {
- 'google_cloud_platform',
- 'jdbc',
- 'mssql',
- 'mysql',
- 'odbc',
- 'oracle',
- 'postgres',
- 'presto',
- 'snowflake',
- 'sqlite',
- 'trino',
- 'vertica',
- }
- if conn.conn_type not in allowed_conn_type:
+ hook = conn.get_hook(hook_params=self.hook_params)
+ if not isinstance(hook, DbApiHook):
raise AirflowException(
- f"Connection type ({conn.conn_type}) is not supported by SqlSensor. "
- + f"Supported connection types: {list(allowed_conn_type)}"
+ f'The connection type is not supported by {self.__class__.__name__}. '
+ f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
)
- return conn.get_hook(hook_params=self.hook_params)
+ return hook
def poke(self, context):
hook = self._get_hook()
diff --git a/tests/sensors/test_sql_sensor.py b/tests/sensors/test_sql_sensor.py
index 92ec1c7..23c31aa 100644
--- a/tests/sensors/test_sql_sensor.py
+++ b/tests/sensors/test_sql_sensor.py
@@ -23,7 +23,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
-from airflow.sensors.sql import SqlSensor
+from airflow.sensors.sql import DbApiHook, SqlSensor
from airflow.utils.timezone import datetime
from tests.providers.apache.hive import TestHiveEnvironment
@@ -94,7 +94,7 @@ class TestSqlSensor(TestHiveEnvironment):
sql="SELECT 1",
)
- mock_hook.get_connection('postgres_default').conn_type = "postgres"
+ mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
@@ -124,7 +124,7 @@ class TestSqlSensor(TestHiveEnvironment):
task_id='sql_sensor_check', conn_id='postgres_default', sql="SELECT 1", fail_on_empty=True
)
- mock_hook.get_connection('postgres_default').conn_type = "postgres"
+ mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
@@ -137,7 +137,7 @@ class TestSqlSensor(TestHiveEnvironment):
task_id='sql_sensor_check', conn_id='postgres_default', sql="SELECT 1", success=lambda x: x in [1]
)
- mock_hook.get_connection('postgres_default').conn_type = "postgres"
+ mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
@@ -155,7 +155,7 @@ class TestSqlSensor(TestHiveEnvironment):
task_id='sql_sensor_check', conn_id='postgres_default', sql="SELECT 1", failure=lambda x: x in [1]
)
- mock_hook.get_connection('postgres_default').conn_type = "postgres"
+ mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
@@ -175,7 +175,7 @@ class TestSqlSensor(TestHiveEnvironment):
success=lambda x: x in [2],
)
- mock_hook.get_connection('postgres_default').conn_type = "postgres"
+ mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
@@ -198,7 +198,7 @@ class TestSqlSensor(TestHiveEnvironment):
success=lambda x: x in [1],
)
- mock_hook.get_connection('postgres_default').conn_type = "postgres"
+ mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
@@ -217,7 +217,7 @@ class TestSqlSensor(TestHiveEnvironment):
failure=[1],
)
- mock_hook.get_connection('postgres_default').conn_type = "postgres"
+ mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = [[1]]
@@ -234,7 +234,7 @@ class TestSqlSensor(TestHiveEnvironment):
success=[1],
)
- mock_hook.get_connection('postgres_default').conn_type = "postgres"
+ mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = [[1]]
@@ -257,11 +257,11 @@ class TestSqlSensor(TestHiveEnvironment):
def test_sql_sensor_hook_params(self):
op = SqlSensor(
task_id='sql_sensor_hook_params',
- conn_id='google_cloud_default',
+ conn_id='postgres_default',
sql="SELECT 1",
hook_params={
- 'delegate_to': 'me',
+ 'schema': 'public',
},
)
hook = op._get_hook()
- assert hook.delegate_to == 'me'
+ assert hook.schema == 'public'