You are viewing a plain-text version of this content. The canonical link for it (a hyperlink in the original page) is not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2023/08/03 09:55:12 UTC
[airflow] branch main updated: Refactor of links in Dataproc. (#31895)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1ea7ae809b Refactor of links in Dataproc. (#31895)
1ea7ae809b is described below
commit 1ea7ae809bf0b8d1c8edf97e4d456b3753a1feca
Author: Beata Kossakowska <10...@users.noreply.github.com>
AuthorDate: Thu Aug 3 11:55:03 2023 +0200
Refactor of links in Dataproc. (#31895)
Co-authored-by: Beata Kossakowska <bk...@google.com>
---
airflow/providers/google/cloud/links/dataproc.py | 165 ++++++++++++++++++++-
.../providers/google/cloud/operators/dataproc.py | 157 +++++++++++++-------
airflow/providers/google/provider.yaml | 6 +
.../google/cloud/operators/test_dataproc.py | 140 ++++++++++-------
4 files changed, 351 insertions(+), 117 deletions(-)
diff --git a/airflow/providers/google/cloud/links/dataproc.py b/airflow/providers/google/cloud/links/dataproc.py
index d560d2a5ee..16c1493e1e 100644
--- a/airflow/providers/google/cloud/links/dataproc.py
+++ b/airflow/providers/google/cloud/links/dataproc.py
@@ -18,10 +18,12 @@
"""This module contains Google Dataproc links."""
from __future__ import annotations
+import warnings
from typing import TYPE_CHECKING
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import BaseOperatorLink, XCom
-from airflow.providers.google.cloud.links.base import BASE_LINK
+from airflow.providers.google.cloud.links.base import BASE_LINK, BaseGoogleLink
if TYPE_CHECKING:
from airflow.models import BaseOperator
@@ -29,21 +31,38 @@ if TYPE_CHECKING:
from airflow.utils.context import Context
DATAPROC_BASE_LINK = BASE_LINK + "/dataproc"
-DATAPROC_JOB_LOG_LINK = DATAPROC_BASE_LINK + "/jobs/{resource}?region={region}&project={project_id}"
+DATAPROC_JOB_LINK = DATAPROC_BASE_LINK + "/jobs/{job_id}?region={region}&project={project_id}"
+
DATAPROC_CLUSTER_LINK = (
- DATAPROC_BASE_LINK + "/clusters/{resource}/monitoring?region={region}&project={project_id}"
+ DATAPROC_BASE_LINK + "/clusters/{cluster_id}/monitoring?region={region}&project={project_id}"
)
DATAPROC_WORKFLOW_TEMPLATE_LINK = (
- DATAPROC_BASE_LINK + "/workflows/templates/{region}/{resource}?project={project_id}"
+ DATAPROC_BASE_LINK + "/workflows/templates/{region}/{workflow_template_id}?project={project_id}"
)
-DATAPROC_WORKFLOW_LINK = DATAPROC_BASE_LINK + "/workflows/instances/{region}/{resource}?project={project_id}"
-DATAPROC_BATCH_LINK = DATAPROC_BASE_LINK + "/batches/{region}/{resource}/monitoring?project={project_id}"
+DATAPROC_WORKFLOW_LINK = (
+ DATAPROC_BASE_LINK + "/workflows/instances/{region}/{workflow_id}?project={project_id}"
+)
+
+DATAPROC_BATCH_LINK = DATAPROC_BASE_LINK + "/batches/{region}/{batch_id}/monitoring?project={project_id}"
DATAPROC_BATCHES_LINK = DATAPROC_BASE_LINK + "/batches?project={project_id}"
+DATAPROC_JOB_LINK_DEPRECATED = DATAPROC_BASE_LINK + "/jobs/{resource}?region={region}&project={project_id}"
+DATAPROC_CLUSTER_LINK_DEPRECATED = (
+ DATAPROC_BASE_LINK + "/clusters/{resource}/monitoring?region={region}&project={project_id}"
+)
class DataprocLink(BaseOperatorLink):
- """Helper class for constructing Dataproc resource link."""
+ """
+ Helper class for constructing Dataproc resource link.
+ .. warning::
+ This link is deprecated.
+ """
+
+ warnings.warn(
+ "This DataprocLink is deprecated.",
+ AirflowProviderDeprecationWarning,
+ )
name = "Dataproc resource"
key = "conf"
@@ -82,8 +101,14 @@ class DataprocLink(BaseOperatorLink):
class DataprocListLink(BaseOperatorLink):
- """Helper class for constructing list of Dataproc resources link."""
+ """
+ Helper class for constructing list of Dataproc resources link.
+
+ .. warning::
+ This link is deprecated.
+ """
+ warnings.warn("This DataprocListLink is deprecated.", AirflowProviderDeprecationWarning)
name = "Dataproc resources"
key = "list_conf"
@@ -116,3 +141,127 @@ class DataprocListLink(BaseOperatorLink):
if list_conf
else ""
)
+
+
+class DataprocClusterLink(BaseGoogleLink):
+ """Helper class for constructing Dataproc Cluster Link."""
+
+ name = "Dataproc Cluster"
+ key = "dataproc_cluster"
+ format_str = DATAPROC_CLUSTER_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ operator: BaseOperator,
+ cluster_id: str,
+ region: str,
+ project_id: str,
+ ):
+ operator.xcom_push(
+ context,
+ key=DataprocClusterLink.key,
+ value={"cluster_id": cluster_id, "region": region, "project_id": project_id},
+ )
+
+
+class DataprocJobLink(BaseGoogleLink):
+ """Helper class for constructing Dataproc Job Link."""
+
+ name = "Dataproc Job"
+ key = "dataproc_job"
+ format_str = DATAPROC_JOB_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ operator: BaseOperator,
+ job_id: str,
+ region: str,
+ project_id: str,
+ ):
+ operator.xcom_push(
+ context,
+ key=DataprocJobLink.key,
+ value={"job_id": job_id, "region": region, "project_id": project_id},
+ )
+
+
+class DataprocWorkflowLink(BaseGoogleLink):
+ """Helper class for constructing Dataproc Workflow Link."""
+
+ name = "Dataproc Workflow"
+ key = "dataproc_workflow"
+ format_str = DATAPROC_WORKFLOW_LINK
+
+ @staticmethod
+ def persist(context: Context, operator: BaseOperator, workflow_id: str, project_id: str, region: str):
+ operator.xcom_push(
+ context,
+ key=DataprocWorkflowLink.key,
+ value={"workflow_id": workflow_id, "region": region, "project_id": project_id},
+ )
+
+
+class DataprocWorkflowTemplateLink(BaseGoogleLink):
+ """Helper class for constructing Dataproc Workflow Template Link."""
+
+ name = "Dataproc Workflow Template"
+ key = "dataproc_workflow_template"
+ format_str = DATAPROC_WORKFLOW_TEMPLATE_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ operator: BaseOperator,
+ workflow_template_id: str,
+ project_id: str,
+ region: str,
+ ):
+ operator.xcom_push(
+ context,
+ key=DataprocWorkflowTemplateLink.key,
+ value={"workflow_template_id": workflow_template_id, "region": region, "project_id": project_id},
+ )
+
+
+class DataprocBatchLink(BaseGoogleLink):
+ """Helper class for constructing Dataproc Batch Link."""
+
+ name = "Dataproc Batch"
+ key = "dataproc_batch"
+ format_str = DATAPROC_BATCH_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ operator: BaseOperator,
+ batch_id: str,
+ project_id: str,
+ region: str,
+ ):
+ operator.xcom_push(
+ context,
+ key=DataprocBatchLink.key,
+ value={"batch_id": batch_id, "region": region, "project_id": project_id},
+ )
+
+
+class DataprocBatchesListLink(BaseGoogleLink):
+ """Helper class for constructing Dataproc Batches List Link."""
+
+ name = "Dataproc Batches List"
+ key = "dataproc_batches_list"
+ format_str = DATAPROC_BATCHES_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ operator: BaseOperator,
+ project_id: str,
+ ):
+ operator.xcom_push(
+ context,
+ key=DataprocBatchesListLink.key,
+ value={"project_id": project_id},
+ )
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 423ebf988f..09e76ae206 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -43,13 +43,15 @@ from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProc
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.links.dataproc import (
DATAPROC_BATCH_LINK,
- DATAPROC_BATCHES_LINK,
- DATAPROC_CLUSTER_LINK,
- DATAPROC_JOB_LOG_LINK,
- DATAPROC_WORKFLOW_LINK,
- DATAPROC_WORKFLOW_TEMPLATE_LINK,
+ DATAPROC_CLUSTER_LINK_DEPRECATED,
+ DATAPROC_JOB_LINK_DEPRECATED,
+ DataprocBatchesListLink,
+ DataprocBatchLink,
+ DataprocClusterLink,
+ DataprocJobLink,
DataprocLink,
- DataprocListLink,
+ DataprocWorkflowLink,
+ DataprocWorkflowTemplateLink,
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.dataproc import (
@@ -189,6 +191,7 @@ class ClusterGenerator:
enable_component_gateway: bool | None = False,
**kwargs,
) -> None:
+
self.project_id = project_id
self.num_masters = num_masters
self.num_workers = num_workers
@@ -488,7 +491,7 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
)
template_fields_renderers = {"cluster_config": "json", "virtual_cluster_config": "json"}
- operator_extra_links = (DataprocLink(),)
+ operator_extra_links = (DataprocClusterLink(),)
def __init__(
self,
@@ -629,9 +632,15 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
self.log.info("Creating cluster: %s", self.cluster_name)
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
# Save data required to display extra link no matter what the cluster status will be
- DataprocLink.persist(
- context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
- )
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ DataprocClusterLink.persist(
+ context=context,
+ operator=self,
+ cluster_id=self.cluster_name,
+ project_id=project_id,
+ region=self.region,
+ )
try:
# First try to create a new cluster
operation = self._create_cluster(hook)
@@ -814,7 +823,10 @@ class DataprocScaleClusterOperator(GoogleCloudBaseOperator):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
# Save data required to display extra link no matter what the cluster status will be
DataprocLink.persist(
- context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
+ context=context,
+ task_instance=self,
+ url=DATAPROC_CLUSTER_LINK_DEPRECATED,
+ resource=self.cluster_name,
)
operation = hook.update_cluster(
project_id=self.project_id,
@@ -1070,7 +1082,7 @@ class DataprocJobBaseOperator(GoogleCloudBaseOperator):
self.log.info("Job %s submitted successfully.", job_id)
# Save data required for extra links no matter what the job status will be
DataprocLink.persist(
- context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=job_id
+ context=context, task_instance=self, url=DATAPROC_JOB_LINK_DEPRECATED, resource=job_id
)
if self.deferrable:
@@ -1669,7 +1681,7 @@ class DataprocCreateWorkflowTemplateOperator(GoogleCloudBaseOperator):
template_fields: Sequence[str] = ("region", "template")
template_fields_renderers = {"template": "json"}
- operator_extra_links = (DataprocLink(),)
+ operator_extra_links = (DataprocWorkflowTemplateLink(),)
def __init__(
self,
@@ -1709,12 +1721,15 @@ class DataprocCreateWorkflowTemplateOperator(GoogleCloudBaseOperator):
self.log.info("Workflow %s created", workflow.name)
except AlreadyExists:
self.log.info("Workflow with given id already exists")
- DataprocLink.persist(
- context=context,
- task_instance=self,
- url=DATAPROC_WORKFLOW_TEMPLATE_LINK,
- resource=self.template["id"],
- )
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ DataprocWorkflowTemplateLink.persist(
+ context=context,
+ operator=self,
+ workflow_template_id=self.template["id"],
+ region=self.region,
+ project_id=project_id,
+ )
class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
@@ -1759,7 +1774,7 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
template_fields: Sequence[str] = ("template_id", "impersonation_chain", "request_id", "parameters")
template_fields_renderers = {"parameters": "json"}
- operator_extra_links = (DataprocLink(),)
+ operator_extra_links = (DataprocWorkflowLink(),)
def __init__(
self,
@@ -1811,9 +1826,15 @@ class DataprocInstantiateWorkflowTemplateOperator(GoogleCloudBaseOperator):
metadata=self.metadata,
)
self.workflow_id = operation.operation.name.split("/")[-1]
- DataprocLink.persist(
- context=context, task_instance=self, url=DATAPROC_WORKFLOW_LINK, resource=self.workflow_id
- )
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ DataprocWorkflowLink.persist(
+ context=context,
+ operator=self,
+ workflow_id=self.workflow_id,
+ region=self.region,
+ project_id=project_id,
+ )
self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
if not self.deferrable:
hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
@@ -1889,7 +1910,7 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
template_fields: Sequence[str] = ("template", "impersonation_chain")
template_fields_renderers = {"template": "json"}
- operator_extra_links = (DataprocLink(),)
+ operator_extra_links = (DataprocWorkflowLink(),)
def __init__(
self,
@@ -1926,9 +1947,10 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
def execute(self, context: Context):
self.log.info("Instantiating Inline Template")
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+ project_id = self.project_id or hook.project_id
operation = hook.instantiate_inline_workflow_template(
template=self.template,
- project_id=self.project_id or hook.project_id,
+ project_id=project_id,
region=self.region,
request_id=self.request_id,
retry=self.retry,
@@ -1936,9 +1958,14 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(GoogleCloudBaseOperator)
metadata=self.metadata,
)
self.workflow_id = operation.operation.name.split("/")[-1]
- DataprocLink.persist(
- context=context, task_instance=self, url=DATAPROC_WORKFLOW_LINK, resource=self.workflow_id
- )
+ if project_id:
+ DataprocWorkflowLink.persist(
+ context=context,
+ operator=self,
+ workflow_id=self.workflow_id,
+ region=self.region,
+ project_id=project_id,
+ )
if not self.deferrable:
self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
operation.result()
@@ -2010,7 +2037,7 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
template_fields: Sequence[str] = ("project_id", "region", "job", "impersonation_chain", "request_id")
template_fields_renderers = {"job": "json"}
- operator_extra_links = (DataprocLink(),)
+ operator_extra_links = (DataprocJobLink(),)
def __init__(
self,
@@ -2066,9 +2093,15 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):
new_job_id: str = job_object.reference.job_id
self.log.info("Job %s submitted successfully.", new_job_id)
# Save data required by extra links no matter what the job status will be
- DataprocLink.persist(
- context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=new_job_id
- )
+ project_id = self.project_id or self.hook.project_id
+ if project_id:
+ DataprocJobLink.persist(
+ context=context,
+ operator=self,
+ job_id=new_job_id,
+ region=self.region,
+ project_id=project_id,
+ )
self.job_id = new_job_id
if self.deferrable:
@@ -2168,7 +2201,7 @@ class DataprocUpdateClusterOperator(GoogleCloudBaseOperator):
"project_id",
"impersonation_chain",
)
- operator_extra_links = (DataprocLink(),)
+ operator_extra_links = (DataprocClusterLink(),)
def __init__(
self,
@@ -2210,9 +2243,15 @@ class DataprocUpdateClusterOperator(GoogleCloudBaseOperator):
def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
# Save data required by extra links no matter what the cluster status will be
- DataprocLink.persist(
- context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
- )
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ DataprocClusterLink.persist(
+ context=context,
+ operator=self,
+ cluster_id=self.cluster_name,
+ project_id=project_id,
+ region=self.region,
+ )
self.log.info("Updating %s cluster.", self.cluster_name)
operation = hook.update_cluster(
project_id=self.project_id,
@@ -2299,7 +2338,7 @@ class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
"region",
"impersonation_chain",
)
- operator_extra_links = (DataprocLink(),)
+ operator_extra_links = (DataprocBatchLink(),)
def __init__(
self,
@@ -2344,7 +2383,7 @@ class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
# batch_id might not be set and will be generated
if self.batch_id:
link = DATAPROC_BATCH_LINK.format(
- region=self.region, project_id=self.project_id, resource=self.batch_id
+ region=self.region, project_id=self.project_id, batch_id=self.batch_id
)
self.log.info("Creating batch %s", self.batch_id)
self.log.info("Once started, the batch job will be available at %s", link)
@@ -2423,7 +2462,17 @@ class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
wait_check_interval=self.polling_interval_seconds,
)
batch_id = self.batch_id or result.name.split("/")[-1]
+
self.handle_batch_status(context, result.state, batch_id)
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ DataprocBatchLink.persist(
+ context=context,
+ operator=self,
+ project_id=project_id,
+ region=self.region,
+ batch_id=batch_id,
+ )
return Batch.to_dict(result)
def execute_complete(self, context, event=None) -> None:
@@ -2446,24 +2495,14 @@ class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
# The existing batch may be a number of states other than 'SUCCEEDED'\
# wait_for_operation doesn't fail if the job is cancelled, so we will check for it here which also
# finds a cancelling|canceled|unspecified job from wait_for_batch or the deferred trigger
- link = DATAPROC_BATCH_LINK.format(region=self.region, project_id=self.project_id, resource=batch_id)
+ link = DATAPROC_BATCH_LINK.format(region=self.region, project_id=self.project_id, batch_id=batch_id)
if state == Batch.State.FAILED:
- DataprocLink.persist(
- context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id
- )
raise AirflowException("Batch job %s failed. Driver Logs: %s", batch_id, link)
if state in (Batch.State.CANCELLED, Batch.State.CANCELLING):
- DataprocLink.persist(
- context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id
- )
raise AirflowException("Batch job %s was cancelled. Driver logs: %s", batch_id, link)
if state == Batch.State.STATE_UNSPECIFIED:
- DataprocLink.persist(
- context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id
- )
raise AirflowException("Batch job %s unspecified. Driver logs: %s", batch_id, link)
self.log.info("Batch job %s completed. Driver logs: %s", batch_id, link)
- DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id)
class DataprocDeleteBatchOperator(GoogleCloudBaseOperator):
@@ -2554,7 +2593,7 @@ class DataprocGetBatchOperator(GoogleCloudBaseOperator):
"""
template_fields: Sequence[str] = ("batch_id", "region", "project_id", "impersonation_chain")
- operator_extra_links = (DataprocLink(),)
+ operator_extra_links = (DataprocBatchLink(),)
def __init__(
self,
@@ -2590,9 +2629,15 @@ class DataprocGetBatchOperator(GoogleCloudBaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- DataprocLink.persist(
- context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=self.batch_id
- )
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ DataprocBatchLink.persist(
+ context=context,
+ operator=self,
+ project_id=project_id,
+ region=self.region,
+ batch_id=self.batch_id,
+ )
return Batch.to_dict(batch)
@@ -2624,7 +2669,7 @@ class DataprocListBatchesOperator(GoogleCloudBaseOperator):
"""
template_fields: Sequence[str] = ("region", "project_id", "impersonation_chain")
- operator_extra_links = (DataprocListLink(),)
+ operator_extra_links = (DataprocBatchesListLink(),)
def __init__(
self,
@@ -2668,7 +2713,9 @@ class DataprocListBatchesOperator(GoogleCloudBaseOperator):
filter=self.filter,
order_by=self.order_by,
)
- DataprocListLink.persist(context=context, task_instance=self, url=DATAPROC_BATCHES_LINK)
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ DataprocBatchesListLink.persist(context=context, operator=self, project_id=project_id)
return [Batch.to_dict(result) for result in results]
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 1b96ca972a..f43b088720 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -1063,6 +1063,12 @@ extra-links:
- airflow.providers.google.cloud.links.datacatalog.DataCatalogTagTemplateLink
- airflow.providers.google.cloud.links.dataproc.DataprocLink
- airflow.providers.google.cloud.links.dataproc.DataprocListLink
+ - airflow.providers.google.cloud.links.dataproc.DataprocClusterLink
+ - airflow.providers.google.cloud.links.dataproc.DataprocJobLink
+ - airflow.providers.google.cloud.links.dataproc.DataprocWorkflowLink
+ - airflow.providers.google.cloud.links.dataproc.DataprocWorkflowTemplateLink
+ - airflow.providers.google.cloud.links.dataproc.DataprocBatchLink
+ - airflow.providers.google.cloud.links.dataproc.DataprocBatchesListLink
- airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreDetailedLink
- airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreLink
- airflow.providers.google.cloud.links.dataprep.DataprepFlowLink
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 38e4ffeef5..6ddaec8be8 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -32,10 +32,14 @@ from airflow.exceptions import (
TaskDeferred,
)
from airflow.models import DAG, DagBag
+from airflow.providers.google.cloud.links.dataproc import (
+ DATAPROC_CLUSTER_LINK_DEPRECATED,
+ DATAPROC_JOB_LINK_DEPRECATED,
+ DataprocClusterLink,
+ DataprocJobLink,
+ DataprocWorkflowLink,
+)
from airflow.providers.google.cloud.operators.dataproc import (
- DATAPROC_CLUSTER_LINK,
- DATAPROC_JOB_LOG_LINK,
- DATAPROC_WORKFLOW_LINK,
ClusterGenerator,
DataprocCreateBatchOperator,
DataprocCreateClusterOperator,
@@ -241,19 +245,28 @@ DATAPROC_JOB_CONF_EXPECTED = {
"resource": TEST_JOB_ID,
"region": GCP_REGION,
"project_id": GCP_PROJECT,
- "url": DATAPROC_JOB_LOG_LINK,
+ "url": DATAPROC_JOB_LINK_DEPRECATED,
+}
+DATAPROC_JOB_EXPECTED = {
+ "job_id": TEST_JOB_ID,
+ "region": GCP_REGION,
+ "project_id": GCP_PROJECT,
}
DATAPROC_CLUSTER_CONF_EXPECTED = {
"resource": CLUSTER_NAME,
"region": GCP_REGION,
"project_id": GCP_PROJECT,
- "url": DATAPROC_CLUSTER_LINK,
+ "url": DATAPROC_CLUSTER_LINK_DEPRECATED,
}
-DATAPROC_WORKFLOW_CONF_EXPECTED = {
- "resource": TEST_WORKFLOW_ID,
+DATAPROC_CLUSTER_EXPECTED = {
+ "cluster_id": CLUSTER_NAME,
+ "region": GCP_REGION,
+ "project_id": GCP_PROJECT,
+}
+DATAPROC_WORKFLOW_EXPECTED = {
+ "workflow_id": TEST_WORKFLOW_ID,
"region": GCP_REGION,
"project_id": GCP_PROJECT,
- "url": DATAPROC_WORKFLOW_LINK,
}
BATCH_ID = "test-batch-id"
BATCH = {
@@ -306,7 +319,7 @@ class DataprocClusterTestBase(DataprocTestBase):
def setup_class(cls):
super().setup_class()
cls.extra_links_expected_calls_base = [
- call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+ call.ti.xcom_push(execution_date=None, key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED)
]
@@ -488,8 +501,8 @@ class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
self.mock_ti.xcom_push.assert_called_once_with(
- key="conf",
- value=DATAPROC_CLUSTER_CONF_EXPECTED,
+ key="dataproc_cluster",
+ value=DATAPROC_CLUSTER_EXPECTED,
execution_date=None,
)
@@ -537,8 +550,8 @@ class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
self.mock_ti.xcom_push.assert_called_once_with(
- key="conf",
- value=DATAPROC_CLUSTER_CONF_EXPECTED,
+ key="dataproc_cluster",
+ value=DATAPROC_CLUSTER_EXPECTED,
execution_date=None,
)
@@ -742,28 +755,35 @@ def test_create_cluster_operator_extra_links(dag_maker, create_task_instance_of_
# Assert operator links for serialized DAG
assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
- {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
+ {"airflow.providers.google.cloud.links.dataproc.DataprocClusterLink": {}}
]
# Assert operator link types are preserved during deserialization
- assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
+ assert isinstance(deserialized_task.operator_extra_links[0], DataprocClusterLink)
# Assert operator link is empty when no XCom push occurred
- assert ti.task.get_extra_links(ti, DataprocLink.name) == ""
+ assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == ""
# Assert operator link is empty for deserialized task when no XCom push occurred
- assert deserialized_task.get_extra_links(ti, DataprocLink.name) == ""
+ assert deserialized_task.get_extra_links(ti, DataprocClusterLink.name) == ""
- ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+ ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED)
# Assert operator links are preserved in deserialized tasks after execution
- assert deserialized_task.get_extra_links(ti, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
+ assert deserialized_task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
# Assert operator links after execution
- assert ti.task.get_extra_links(ti, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
+ assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
class TestDataprocClusterScaleOperator(DataprocClusterTestBase):
+ @classmethod
+ def setup_class(cls):
+ super().setup_class()
+ cls.extra_links_expected_calls_base = [
+ call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+ ]
+
def test_deprecation_warning(self):
with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
DataprocScaleClusterOperator(task_id=TASK_ID, cluster_name=CLUSTER_NAME, project_id=GCP_PROJECT)
@@ -847,7 +867,10 @@ def test_scale_cluster_operator_extra_links(dag_maker, create_task_instance_of_o
# Assert operator link is empty for deserialized task when no XCom push occurred
assert deserialized_task.get_extra_links(ti, DataprocLink.name) == ""
- ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+ ti.xcom_push(
+ key="conf",
+ value=DATAPROC_CLUSTER_CONF_EXPECTED,
+ )
# Assert operator links are preserved in deserialized tasks after execution
assert deserialized_task.get_extra_links(ti, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
@@ -929,7 +952,9 @@ class TestDataprocClusterDeleteOperator:
class TestDataprocSubmitJobOperator(DataprocJobTestBase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
- xcom_push_call = call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_JOB_CONF_EXPECTED)
+ xcom_push_call = call.ti.xcom_push(
+ execution_date=None, key="dataproc_job", value=DATAPROC_JOB_EXPECTED
+ )
wait_for_job_call = call.hook().wait_for_job(
job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT, timeout=None
)
@@ -976,7 +1001,7 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
)
self.mock_ti.xcom_push.assert_called_once_with(
- key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
+ key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None
)
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1016,7 +1041,7 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
mock_hook.return_value.wait_for_job.assert_not_called()
self.mock_ti.xcom_push.assert_called_once_with(
- key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
+ key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None
)
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1185,25 +1210,25 @@ def test_submit_job_operator_extra_links(mock_hook, dag_maker, create_task_insta
# Assert operator links for serialized_dag
assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
- {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
+ {"airflow.providers.google.cloud.links.dataproc.DataprocJobLink": {}}
]
# Assert operator link types are preserved during deserialization
- assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
+ assert isinstance(deserialized_task.operator_extra_links[0], DataprocJobLink)
# Assert operator link is empty when no XCom push occurred
- assert ti.task.get_extra_links(ti, DataprocLink.name) == ""
+ assert ti.task.get_extra_links(ti, DataprocJobLink.name) == ""
# Assert operator link is empty for deserialized task when no XCom push occurred
- assert deserialized_task.get_extra_links(ti, DataprocLink.name) == ""
+ assert deserialized_task.get_extra_links(ti, DataprocJobLink.name) == ""
- ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED)
+ ti.xcom_push(key="dataproc_job", value=DATAPROC_JOB_EXPECTED)
# Assert operator links are preserved in deserialized tasks
- assert deserialized_task.get_extra_links(ti, DataprocLink.name) == DATAPROC_JOB_LINK_EXPECTED
+ assert deserialized_task.get_extra_links(ti, DataprocJobLink.name) == DATAPROC_JOB_LINK_EXPECTED
# Assert operator links after execution
- assert ti.task.get_extra_links(ti, DataprocLink.name) == DATAPROC_JOB_LINK_EXPECTED
+ assert ti.task.get_extra_links(ti, DataprocJobLink.name) == DATAPROC_JOB_LINK_EXPECTED
class TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
@@ -1251,8 +1276,8 @@ class TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
self.mock_ti.xcom_push.assert_called_once_with(
- key="conf",
- value=DATAPROC_CLUSTER_CONF_EXPECTED,
+ key="dataproc_cluster",
+ value=DATAPROC_CLUSTER_EXPECTED,
execution_date=None,
)
@@ -1342,25 +1367,25 @@ def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_
# Assert operator links for serialized_dag
assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
- {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
+ {"airflow.providers.google.cloud.links.dataproc.DataprocClusterLink": {}}
]
# Assert operator link types are preserved during deserialization
- assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
+ assert isinstance(deserialized_task.operator_extra_links[0], DataprocClusterLink)
# Assert operator link is empty when no XCom push occurred
- assert ti.task.get_extra_links(ti, DataprocLink.name) == ""
+ assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == ""
# Assert operator link is empty for deserialized task when no XCom push occurred
- assert deserialized_task.get_extra_links(ti, DataprocLink.name) == ""
+ assert deserialized_task.get_extra_links(ti, DataprocClusterLink.name) == ""
- ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+ ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED)
# Assert operator links are preserved in deserialized tasks
- assert deserialized_task.get_extra_links(ti, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
+ assert deserialized_task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
# Assert operator links after execution
- assert ti.task.get_extra_links(ti, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
+ assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
class TestDataprocWorkflowTemplateInstantiateOperator:
@@ -1448,25 +1473,25 @@ def test_instantiate_workflow_operator_extra_links(mock_hook, dag_maker, create_
# Assert operator links for serialized_dag
assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
- {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
+ {"airflow.providers.google.cloud.links.dataproc.DataprocWorkflowLink": {}}
]
# Assert operator link types are preserved during deserialization
- assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
+ assert isinstance(deserialized_task.operator_extra_links[0], DataprocWorkflowLink)
# Assert operator link is empty when no XCom push occurred
- assert ti.task.get_extra_links(ti, DataprocLink.name) == ""
+ assert ti.task.get_extra_links(ti, DataprocWorkflowLink.name) == ""
# Assert operator link is empty for deserialized task when no XCom push occurred
- assert deserialized_task.get_extra_links(ti, DataprocLink.name) == ""
+ assert deserialized_task.get_extra_links(ti, DataprocWorkflowLink.name) == ""
- ti.xcom_push(key="conf", value=DATAPROC_WORKFLOW_CONF_EXPECTED)
+ ti.xcom_push(key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED)
# Assert operator links are preserved in deserialized tasks
- assert deserialized_task.get_extra_links(ti, DataprocLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
+ assert deserialized_task.get_extra_links(ti, DataprocWorkflowLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
# Assert operator links after execution
- assert ti.task.get_extra_links(ti, DataprocLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
+ assert ti.task.get_extra_links(ti, DataprocWorkflowLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
class TestDataprocWorkflowTemplateInstantiateInlineOperator:
@@ -1548,25 +1573,25 @@ def test_instantiate_inline_workflow_operator_extra_links(
# Assert operator links for serialized_dag
assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
- {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
+ {"airflow.providers.google.cloud.links.dataproc.DataprocWorkflowLink": {}}
]
# Assert operator link types are preserved during deserialization
- assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
+ assert isinstance(deserialized_task.operator_extra_links[0], DataprocWorkflowLink)
# Assert operator link is empty when no XCom push occurred
- assert ti.task.get_extra_links(ti, DataprocLink.name) == ""
+ assert ti.task.get_extra_links(ti, DataprocWorkflowLink.name) == ""
# Assert operator link is empty for deserialized task when no XCom push occurred
- assert deserialized_task.get_extra_links(ti, DataprocLink.name) == ""
+ assert deserialized_task.get_extra_links(ti, DataprocWorkflowLink.name) == ""
- ti.xcom_push(key="conf", value=DATAPROC_WORKFLOW_CONF_EXPECTED)
+ ti.xcom_push(key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED)
# Assert operator links are preserved in deserialized tasks
- assert deserialized_task.get_extra_links(ti, DataprocLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
+ assert deserialized_task.get_extra_links(ti, DataprocWorkflowLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
# Assert operator links after execution
- assert ti.task.get_extra_links(ti, DataprocLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
+ assert ti.task.get_extra_links(ti, DataprocWorkflowLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED
class TestDataProcHiveOperator:
@@ -1789,6 +1814,13 @@ class TestDataProcSparkSqlOperator:
class TestDataProcSparkOperator(DataprocJobTestBase):
+ @classmethod
+ def setup_class(cls):
+ cls.extra_links_expected_calls = [
+ call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_JOB_CONF_EXPECTED),
+ call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT),
+ ]
+
main_class = "org.apache.spark.examples.SparkPi"
jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"]
job_name = "simple"