You are viewing a plain text version of this content. The canonical link to the original was not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/12/04 23:25:14 UTC

[airflow] branch main updated: Removed hardcoded connection types. Check if hook is instance of DbApiHook. (#19639)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8f162aa  Removed hardcoded connection types. Check if hook is instance of DbApiHook. (#19639)
8f162aa is described below

commit 8f162aa9f6a355165f4557fc28d021825b346357
Author: Dmytro Kazanzhy <dk...@gmail.com>
AuthorDate: Sun Dec 5 01:24:44 2021 +0200

    Removed hardcoded connection types. Check if hook is instance of DbApiHook. (#19639)
    
    Co-authored-by: Dmytro Kazanzhy <dk...@demandbase.com>
---
 airflow/sensors/sql.py           | 25 ++++++-------------------
 tests/sensors/test_sql_sensor.py | 24 ++++++++++++------------
 2 files changed, 18 insertions(+), 31 deletions(-)

diff --git a/airflow/sensors/sql.py b/airflow/sensors/sql.py
index ad4dfa4..123c891 100644
--- a/airflow/sensors/sql.py
+++ b/airflow/sensors/sql.py
@@ -20,6 +20,7 @@ from typing import Iterable
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
+from airflow.hooks.dbapi import DbApiHook
 from airflow.sensors.base import BaseSensorOperator
 
 
@@ -83,27 +84,13 @@ class SqlSensor(BaseSensorOperator):
 
     def _get_hook(self):
         conn = BaseHook.get_connection(self.conn_id)
-
-        allowed_conn_type = {
-            'google_cloud_platform',
-            'jdbc',
-            'mssql',
-            'mysql',
-            'odbc',
-            'oracle',
-            'postgres',
-            'presto',
-            'snowflake',
-            'sqlite',
-            'trino',
-            'vertica',
-        }
-        if conn.conn_type not in allowed_conn_type:
+        hook = conn.get_hook(hook_params=self.hook_params)
+        if not isinstance(hook, DbApiHook):
             raise AirflowException(
-                f"Connection type ({conn.conn_type}) is not supported by SqlSensor. "
-                + f"Supported connection types: {list(allowed_conn_type)}"
+                f'The connection type is not supported by {self.__class__.__name__}. '
+                f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
             )
-        return conn.get_hook(hook_params=self.hook_params)
+        return hook
 
     def poke(self, context):
         hook = self._get_hook()
diff --git a/tests/sensors/test_sql_sensor.py b/tests/sensors/test_sql_sensor.py
index 92ec1c7..23c31aa 100644
--- a/tests/sensors/test_sql_sensor.py
+++ b/tests/sensors/test_sql_sensor.py
@@ -23,7 +23,7 @@ import pytest
 
 from airflow.exceptions import AirflowException
 from airflow.models.dag import DAG
-from airflow.sensors.sql import SqlSensor
+from airflow.sensors.sql import DbApiHook, SqlSensor
 from airflow.utils.timezone import datetime
 from tests.providers.apache.hive import TestHiveEnvironment
 
@@ -94,7 +94,7 @@ class TestSqlSensor(TestHiveEnvironment):
             sql="SELECT 1",
         )
 
-        mock_hook.get_connection('postgres_default').conn_type = "postgres"
+        mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
         mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
 
         mock_get_records.return_value = []
@@ -124,7 +124,7 @@ class TestSqlSensor(TestHiveEnvironment):
             task_id='sql_sensor_check', conn_id='postgres_default', sql="SELECT 1", fail_on_empty=True
         )
 
-        mock_hook.get_connection('postgres_default').conn_type = "postgres"
+        mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
         mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
 
         mock_get_records.return_value = []
@@ -137,7 +137,7 @@ class TestSqlSensor(TestHiveEnvironment):
             task_id='sql_sensor_check', conn_id='postgres_default', sql="SELECT 1", success=lambda x: x in [1]
         )
 
-        mock_hook.get_connection('postgres_default').conn_type = "postgres"
+        mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
         mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
 
         mock_get_records.return_value = []
@@ -155,7 +155,7 @@ class TestSqlSensor(TestHiveEnvironment):
             task_id='sql_sensor_check', conn_id='postgres_default', sql="SELECT 1", failure=lambda x: x in [1]
         )
 
-        mock_hook.get_connection('postgres_default').conn_type = "postgres"
+        mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
         mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
 
         mock_get_records.return_value = []
@@ -175,7 +175,7 @@ class TestSqlSensor(TestHiveEnvironment):
             success=lambda x: x in [2],
         )
 
-        mock_hook.get_connection('postgres_default').conn_type = "postgres"
+        mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
         mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
 
         mock_get_records.return_value = []
@@ -198,7 +198,7 @@ class TestSqlSensor(TestHiveEnvironment):
             success=lambda x: x in [1],
         )
 
-        mock_hook.get_connection('postgres_default').conn_type = "postgres"
+        mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
         mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
 
         mock_get_records.return_value = []
@@ -217,7 +217,7 @@ class TestSqlSensor(TestHiveEnvironment):
             failure=[1],
         )
 
-        mock_hook.get_connection('postgres_default').conn_type = "postgres"
+        mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
         mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
 
         mock_get_records.return_value = [[1]]
@@ -234,7 +234,7 @@ class TestSqlSensor(TestHiveEnvironment):
             success=[1],
         )
 
-        mock_hook.get_connection('postgres_default').conn_type = "postgres"
+        mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)
         mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
 
         mock_get_records.return_value = [[1]]
@@ -257,11 +257,11 @@ class TestSqlSensor(TestHiveEnvironment):
     def test_sql_sensor_hook_params(self):
         op = SqlSensor(
             task_id='sql_sensor_hook_params',
-            conn_id='google_cloud_default',
+            conn_id='postgres_default',
             sql="SELECT 1",
             hook_params={
-                'delegate_to': 'me',
+                'schema': 'public',
             },
         )
         hook = op._get_hook()
-        assert hook.delegate_to == 'me'
+        assert hook.schema == 'public'