You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by jo...@apache.org on 2022/06/06 23:33:14 UTC

[airflow] branch main updated: Enable dbt Cloud provider to interact with single tenant instances (#24264)

This is an automated email from the ASF dual-hosted git repository.

joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7498fba826 Enable dbt Cloud provider to interact with single tenant instances (#24264)
7498fba826 is described below

commit 7498fba826ec477b02a40a2e23e1c685f148e20f
Author: Elize Papineau <el...@gmail.com>
AuthorDate: Mon Jun 6 16:32:56 2022 -0700

    Enable dbt Cloud provider to interact with single tenant instances (#24264)
    
    * Enable provider to interact with single tenant
    
    * Define single tenant arg on Operator
    
    * Add test for single tenant endpoint
    
    * Enable provider to interact with single tenant
    
    * Define single tenant arg on Operator
    
    * Add test for single tenant endpoint
    
    * Code linting from black
    
    * Code linting from black
    
    * Pass tenant to DbtCloudHook in DbtCloudGetJobRunArtifactOperator class
    
    * Make Tenant a connection-level setting
    
    * Remove tenant arg from Operator
    
    * Make tenant connection-level param that defaults to 'cloud'
    
    * Remove tenant param from sensor
    
    * Remove leftover param string from hook
    
    * Update airflow/providers/dbt/cloud/hooks/dbt.py
    
    Co-authored-by: Josh Fell <48...@users.noreply.github.com>
    
    * Parameterize test_init_hook to test single- and multi-tenant connections
    
    * Integrate test simplification suggestion
    
    * Add connection to TestDbtCloudJobRunSensor
    
    Co-authored-by: Josh Fell <48...@users.noreply.github.com>
---
 airflow/providers/dbt/cloud/hooks/dbt.py           |  9 +++++---
 tests/providers/dbt/cloud/hooks/test_dbt_cloud.py  | 26 ++++++++++++++++++----
 .../providers/dbt/cloud/sensors/test_dbt_cloud.py  |  8 +++++++
 3 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py
index d88c0053d3..31214d7342 100644
--- a/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -144,14 +144,17 @@ class DbtCloudHook(HttpHook):
     def get_ui_field_behaviour() -> Dict[str, Any]:
         """Builds custom field behavior for the dbt Cloud connection form in the Airflow UI."""
         return {
-            "hidden_fields": ["host", "port", "schema", "extra"],
-            "relabeling": {"login": "Account ID", "password": "API Token"},
+            "hidden_fields": ["host", "port", "extra"],
+            "relabeling": {"login": "Account ID", "password": "API Token", "schema": "Tenant"},
+            "placeholders": {"schema": "Defaults to 'cloud'."},
         }
 
     def __init__(self, dbt_cloud_conn_id: str = default_conn_name, *args, **kwargs) -> None:
         super().__init__(auth_type=TokenAuth)
         self.dbt_cloud_conn_id = dbt_cloud_conn_id
-        self.base_url = "https://cloud.getdbt.com/api/v2/accounts/"
+        tenant = self.connection.schema if self.connection.schema else 'cloud'
+
+        self.base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/"
 
     @cached_property
     def connection(self) -> Connection:
diff --git a/tests/providers/dbt/cloud/hooks/test_dbt_cloud.py b/tests/providers/dbt/cloud/hooks/test_dbt_cloud.py
index 0c4ad65456..012138e5aa 100644
--- a/tests/providers/dbt/cloud/hooks/test_dbt_cloud.py
+++ b/tests/providers/dbt/cloud/hooks/test_dbt_cloud.py
@@ -34,14 +34,17 @@ from airflow.utils import db
 
 ACCOUNT_ID_CONN = "account_id_conn"
 NO_ACCOUNT_ID_CONN = "no_account_id_conn"
+SINGLE_TENANT_CONN = "single_tenant_conn"
 DEFAULT_ACCOUNT_ID = 11111
 ACCOUNT_ID = 22222
+SINGLE_TENANT_SCHEMA = "single.tenant"
 TOKEN = "token"
 PROJECT_ID = 33333
 JOB_ID = 4444
 RUN_ID = 5555
 
 BASE_URL = "https://cloud.getdbt.com/api/v2/accounts/"
+SINGLE_TENANT_URL = "https://single.tenant.getdbt.com/api/v2/accounts/"
 
 
 class TestDbtCloudJobRunStatus:
@@ -119,15 +122,30 @@ class TestDbtCloudHook:
             password=TOKEN,
         )
 
+        # Connection with `schema` parameter set
+        schema_conn = Connection(
+            conn_id=SINGLE_TENANT_CONN,
+            conn_type=DbtCloudHook.conn_type,
+            login=DEFAULT_ACCOUNT_ID,
+            password=TOKEN,
+            schema=SINGLE_TENANT_SCHEMA,
+        )
+
         db.merge_conn(account_id_conn)
         db.merge_conn(no_account_id_conn)
+        db.merge_conn(schema_conn)
 
-    def test_init_hook(self):
-        hook = DbtCloudHook()
-        assert hook.dbt_cloud_conn_id == "dbt_cloud_default"
-        assert hook.base_url == BASE_URL
+    @pytest.mark.parametrize(
+        argnames="conn_id, url",
+        argvalues=[(ACCOUNT_ID_CONN, BASE_URL), (SINGLE_TENANT_CONN, SINGLE_TENANT_URL)],
+        ids=["multi-tenant", "single-tenant"],
+    )
+    def test_init_hook(self, conn_id, url):
+        hook = DbtCloudHook(conn_id)
         assert hook.auth_type == TokenAuth
         assert hook.method == "POST"
+        assert hook.dbt_cloud_conn_id == conn_id
+        assert hook.base_url == url
 
     @pytest.mark.parametrize(
         argnames="conn_id, account_id",
diff --git a/tests/providers/dbt/cloud/sensors/test_dbt_cloud.py b/tests/providers/dbt/cloud/sensors/test_dbt_cloud.py
index 6cd5069028..c5d834fd1e 100644
--- a/tests/providers/dbt/cloud/sensors/test_dbt_cloud.py
+++ b/tests/providers/dbt/cloud/sensors/test_dbt_cloud.py
@@ -19,11 +19,14 @@ from unittest.mock import patch
 
 import pytest
 
+from airflow.models.connection import Connection
 from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus
 from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor
+from airflow.utils import db
 
 ACCOUNT_ID = 11111
 RUN_ID = 5555
+TOKEN = "token"
 
 
 class TestDbtCloudJobRunSensor:
@@ -37,6 +40,11 @@ class TestDbtCloudJobRunSensor:
             poke_interval=15,
         )
 
+        # Connection
+        conn = Connection(conn_id="dbt", conn_type=DbtCloudHook.conn_type, login=ACCOUNT_ID, password=TOKEN)
+
+        db.merge_conn(conn)
+
     def test_init(self):
         assert self.sensor.dbt_cloud_conn_id == "dbt"
         assert self.sensor.run_id == RUN_ID