You are viewing a plain-text version of this content. The link to the canonical version was a hyperlink on the original archive page and is not preserved in this copy.
Posted to commits@airflow.apache.org by tu...@apache.org on 2022/02/23 11:00:36 UTC
[airflow] branch main updated: Add Dataproc assets/links (#21756)
This is an automated email from the ASF dual-hosted git repository.
turbaszek pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3b4c26e Add Dataproc assets/links (#21756)
3b4c26e is described below
commit 3b4c26eb3a1c8d4938be80ab7fa0711561e91f8f
Author: Wojciech Januszek <wj...@sigma.ug.edu.pl>
AuthorDate: Wed Feb 23 10:59:48 2022 +0000
Add Dataproc assets/links (#21756)
Co-authored-by: Wojciech Januszek <ja...@google.com>
---
airflow/providers/google/cloud/links/dataproc.py | 112 ++++++++++++++++
.../providers/google/cloud/operators/dataproc.py | 143 ++++++++-------------
airflow/providers/google/provider.yaml | 4 +-
.../google/cloud/operators/test_dataproc.py | 126 +++++++++---------
4 files changed, 227 insertions(+), 158 deletions(-)
diff --git a/airflow/providers/google/cloud/links/dataproc.py b/airflow/providers/google/cloud/links/dataproc.py
new file mode 100644
index 0000000..ffa286f
--- /dev/null
+++ b/airflow/providers/google/cloud/links/dataproc.py
@@ -0,0 +1,112 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains Google Dataproc links."""
+
+from datetime import datetime
+from typing import TYPE_CHECKING
+
+from airflow.models import BaseOperator, BaseOperatorLink, XCom
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+DATAPROC_BASE_LINK = "https://console.cloud.google.com/dataproc"
+DATAPROC_JOB_LOG_LINK = DATAPROC_BASE_LINK + "/jobs/{resource}?region={region}&project={project_id}"
+DATAPROC_CLUSTER_LINK = (
+ DATAPROC_BASE_LINK + "/clusters/{resource}/monitoring?region={region}&project={project_id}"
+)
+DATAPROC_WORKFLOW_TEMPLATE_LINK = (
+ DATAPROC_BASE_LINK + "/workflows/templates/{region}/{resource}?project={project_id}"
+)
+DATAPROC_WORKFLOW_LINK = DATAPROC_BASE_LINK + "/workflows/instances/{region}/{resource}?project={project_id}"
+DATAPROC_BATCH_LINK = DATAPROC_BASE_LINK + "/batches/{region}/{resource}/monitoring?project={project_id}"
+DATAPROC_BATCHES_LINK = DATAPROC_BASE_LINK + "/batches?project={project_id}"
+
+
+class DataprocLink(BaseOperatorLink):
+ """Helper class for constructing Dataproc resource link"""
+
+ name = "Dataproc resource"
+ key = "conf"
+
+ @staticmethod
+ def persist(
+ context: "Context",
+ task_instance,
+ url: str,
+ resource: str,
+ ):
+ task_instance.xcom_push(
+ context=context,
+ key=DataprocLink.key,
+ value={
+ "region": task_instance.region,
+ "project_id": task_instance.project_id,
+ "url": url,
+ "resource": resource,
+ },
+ )
+
+ def get_link(self, operator: BaseOperator, dttm: datetime):
+ conf = XCom.get_one(
+ key=DataprocLink.key, dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
+ )
+ return (
+ conf["url"].format(
+ region=conf["region"], project_id=conf["project_id"], resource=conf["resource"]
+ )
+ if conf
+ else ""
+ )
+
+
+class DataprocListLink(BaseOperatorLink):
+ """Helper class for constructing list of Dataproc resources link"""
+
+ name = "Dataproc resources"
+ key = "list_conf"
+
+ @staticmethod
+ def persist(
+ context: "Context",
+ task_instance,
+ url: str,
+ ):
+ task_instance.xcom_push(
+ context=context,
+ key=DataprocListLink.key,
+ value={
+ "project_id": task_instance.project_id,
+ "url": url,
+ },
+ )
+
+ def get_link(self, operator: BaseOperator, dttm: datetime):
+ list_conf = XCom.get_one(
+ key=DataprocListLink.key,
+ dag_id=operator.dag.dag_id,
+ task_id=operator.task_id,
+ execution_date=dttm,
+ )
+ return (
+ list_conf["url"].format(
+ project_id=list_conf["project_id"],
+ )
+ if list_conf
+ else ""
+ )
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index ce6e96e..a3de849 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -36,62 +36,25 @@ from google.protobuf.duration_pb2 import Duration
from google.protobuf.field_mask_pb2 import FieldMask
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator, BaseOperatorLink, XCom
+from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.cloud.links.dataproc import (
+ DATAPROC_BATCH_LINK,
+ DATAPROC_BATCHES_LINK,
+ DATAPROC_CLUSTER_LINK,
+ DATAPROC_JOB_LOG_LINK,
+ DATAPROC_WORKFLOW_LINK,
+ DATAPROC_WORKFLOW_TEMPLATE_LINK,
+ DataprocLink,
+ DataprocListLink,
+)
from airflow.utils import timezone
if TYPE_CHECKING:
from airflow.utils.context import Context
-DATAPROC_BASE_LINK = "https://console.cloud.google.com/dataproc"
-DATAPROC_JOB_LOG_LINK = DATAPROC_BASE_LINK + "/jobs/{job_id}?region={region}&project={project_id}"
-DATAPROC_CLUSTER_LINK = (
- DATAPROC_BASE_LINK + "/clusters/{cluster_name}/monitoring?region={region}&project={project_id}"
-)
-
-
-class DataprocJobLink(BaseOperatorLink):
- """Helper class for constructing Dataproc Job link"""
-
- name = "Dataproc Job"
-
- def get_link(self, operator, dttm):
- job_conf = XCom.get_one(
- key="job_conf", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
- )
- return (
- DATAPROC_JOB_LOG_LINK.format(
- job_id=job_conf["job_id"],
- region=job_conf["region"],
- project_id=job_conf["project_id"],
- )
- if job_conf
- else ""
- )
-
-
-class DataprocClusterLink(BaseOperatorLink):
- """Helper class for constructing Dataproc Cluster link"""
-
- name = "Dataproc Cluster"
-
- def get_link(self, operator, dttm):
- cluster_conf = XCom.get_one(
- key="cluster_conf", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
- )
- return (
- DATAPROC_CLUSTER_LINK.format(
- cluster_name=cluster_conf["cluster_name"],
- region=cluster_conf["region"],
- project_id=cluster_conf["project_id"],
- )
- if cluster_conf
- else ""
- )
-
-
class ClusterGenerator:
"""
Create a new Dataproc Cluster.
@@ -481,7 +444,7 @@ class DataprocCreateClusterOperator(BaseOperator):
)
template_fields_renderers = {'cluster_config': 'json'}
- operator_extra_links = (DataprocClusterLink(),)
+ operator_extra_links = (DataprocLink(),)
def __init__(
self,
@@ -625,14 +588,8 @@ class DataprocCreateClusterOperator(BaseOperator):
self.log.info('Creating cluster: %s', self.cluster_name)
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
# Save data required to display extra link no matter what the cluster status will be
- self.xcom_push(
- context,
- key="cluster_conf",
- value={
- "cluster_name": self.cluster_name,
- "region": self.region,
- "project_id": self.project_id,
- },
+ DataprocLink.persist(
+ context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
)
try:
# First try to create a new cluster
@@ -700,7 +657,7 @@ class DataprocScaleClusterOperator(BaseOperator):
template_fields: Sequence[str] = ('cluster_name', 'project_id', 'region', 'impersonation_chain')
- operator_extra_links = (DataprocClusterLink(),)
+ operator_extra_links = (DataprocLink(),)
def __init__(
self,
@@ -780,14 +737,8 @@ class DataprocScaleClusterOperator(BaseOperator):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
# Save data required to display extra link no matter what the cluster status will be
- self.xcom_push(
- context,
- key="cluster_conf",
- value={
- "cluster_name": self.cluster_name,
- "region": self.region,
- "project_id": self.project_id,
- },
+ DataprocLink.persist(
+ context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
)
operation = hook.update_cluster(
project_id=self.project_id,
@@ -924,7 +875,7 @@ class DataprocJobBaseOperator(BaseOperator):
job_type = ""
- operator_extra_links = (DataprocJobLink(),)
+ operator_extra_links = (DataprocLink(),)
def __init__(
self,
@@ -1007,10 +958,8 @@ class DataprocJobBaseOperator(BaseOperator):
job_id = job_object.reference.job_id
self.log.info('Job %s submitted successfully.', job_id)
# Save data required for extra links no matter what the job status will be
- self.xcom_push(
- context,
- key='job_conf',
- value={'job_id': job_id, 'region': self.region, 'project_id': self.project_id},
+ DataprocLink.persist(
+ context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=job_id
)
if not self.asynchronous:
@@ -1084,7 +1033,7 @@ class DataprocSubmitPigJobOperator(DataprocJobBaseOperator):
ui_color = '#0273d4'
job_type = 'pig_job'
- operator_extra_links = (DataprocJobLink(),)
+ operator_extra_links = (DataprocLink(),)
def __init__(
self,
@@ -1564,6 +1513,7 @@ class DataprocCreateWorkflowTemplateOperator(BaseOperator):
template_fields: Sequence[str] = ("region", "template")
template_fields_renderers = {"template": "json"}
+ operator_extra_links = (DataprocLink(),)
def __init__(
self,
@@ -1615,6 +1565,12 @@ class DataprocCreateWorkflowTemplateOperator(BaseOperator):
self.log.info("Workflow %s created", workflow.name)
except AlreadyExists:
self.log.info("Workflow with given id already exists")
+ DataprocLink.persist(
+ context=context,
+ task_instance=self,
+ url=DATAPROC_WORKFLOW_TEMPLATE_LINK,
+ resource=self.template["id"],
+ )
class DataprocInstantiateWorkflowTemplateOperator(BaseOperator):
@@ -1657,6 +1613,7 @@ class DataprocInstantiateWorkflowTemplateOperator(BaseOperator):
template_fields: Sequence[str] = ('template_id', 'impersonation_chain', 'request_id', 'parameters')
template_fields_renderers = {"parameters": "json"}
+ operator_extra_links = (DataprocLink(),)
def __init__(
self,
@@ -1703,6 +1660,10 @@ class DataprocInstantiateWorkflowTemplateOperator(BaseOperator):
metadata=self.metadata,
)
operation.result()
+ workflow_id = operation.operation.name.split('/')[-1]
+ DataprocLink.persist(
+ context=context, task_instance=self, url=DATAPROC_WORKFLOW_LINK, resource=workflow_id
+ )
self.log.info('Template instantiated.')
@@ -1746,6 +1707,7 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(BaseOperator):
template_fields: Sequence[str] = ('template', 'impersonation_chain')
template_fields_renderers = {"template": "json"}
+ operator_extra_links = (DataprocLink(),)
def __init__(
self,
@@ -1786,6 +1748,10 @@ class DataprocInstantiateInlineWorkflowTemplateOperator(BaseOperator):
metadata=self.metadata,
)
operation.result()
+ workflow_id = operation.operation.name.split('/')[-1]
+ DataprocLink.persist(
+ context=context, task_instance=self, url=DATAPROC_WORKFLOW_LINK, resource=workflow_id
+ )
self.log.info('Template instantiated.')
@@ -1827,7 +1793,7 @@ class DataprocSubmitJobOperator(BaseOperator):
template_fields: Sequence[str] = ('project_id', 'region', 'job', 'impersonation_chain', 'request_id')
template_fields_renderers = {"job": "json"}
- operator_extra_links = (DataprocJobLink(),)
+ operator_extra_links = (DataprocLink(),)
def __init__(
self,
@@ -1889,15 +1855,7 @@ class DataprocSubmitJobOperator(BaseOperator):
job_id = job_object.reference.job_id
self.log.info('Job %s submitted successfully.', job_id)
# Save data required by extra links no matter what the job status will be
- self.xcom_push(
- context,
- key="job_conf",
- value={
- "job_id": job_id,
- "region": self.region,
- "project_id": self.project_id,
- },
- )
+ DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=job_id)
if not self.asynchronous:
self.log.info('Waiting for job %s to complete', job_id)
@@ -1956,7 +1914,7 @@ class DataprocUpdateClusterOperator(BaseOperator):
"""
template_fields: Sequence[str] = ('impersonation_chain', 'cluster_name')
- operator_extra_links = (DataprocClusterLink(),)
+ operator_extra_links = (DataprocLink(),)
def __init__(
self,
@@ -2004,14 +1962,8 @@ class DataprocUpdateClusterOperator(BaseOperator):
def execute(self, context: 'Context'):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
# Save data required by extra links no matter what the cluster status will be
- self.xcom_push(
- context,
- key="cluster_conf",
- value={
- "cluster_name": self.cluster_name,
- "region": self.region,
- "project_id": self.project_id,
- },
+ DataprocLink.persist(
+ context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
)
self.log.info("Updating %s cluster.", self.cluster_name)
operation = hook.update_cluster(
@@ -2066,6 +2018,7 @@ class DataprocCreateBatchOperator(BaseOperator):
'region',
'impersonation_chain',
)
+ operator_extra_links = (DataprocLink(),)
def __init__(
self,
@@ -2125,6 +2078,8 @@ class DataprocCreateBatchOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
+ batch_id = self.batch_id or result.name.split('/')[-1]
+ DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id)
return Batch.to_dict(result)
def on_kill(self):
@@ -2222,6 +2177,7 @@ class DataprocGetBatchOperator(BaseOperator):
"""
template_fields: Sequence[str] = ("batch_id", "region", "project_id", "impersonation_chain")
+ operator_extra_links = (DataprocLink(),)
def __init__(
self,
@@ -2257,6 +2213,9 @@ class DataprocGetBatchOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
+ DataprocLink.persist(
+ context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=self.batch_id
+ )
return Batch.to_dict(batch)
@@ -2289,6 +2248,7 @@ class DataprocListBatchesOperator(BaseOperator):
"""
template_fields: Sequence[str] = ("region", "project_id", "impersonation_chain")
+ operator_extra_links = (DataprocListLink(),)
def __init__(
self,
@@ -2326,4 +2286,5 @@ class DataprocListBatchesOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
+ DataprocListLink.persist(context=context, task_instance=self, url=DATAPROC_BATCHES_LINK)
return [Batch.to_dict(result) for result in results]
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 2034e81..8af8cae 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -837,8 +837,8 @@ extra-links:
- airflow.providers.google.cloud.operators.datafusion.DataFusionInstanceLink
- airflow.providers.google.cloud.operators.datafusion.DataFusionPipelineLink
- airflow.providers.google.cloud.operators.datafusion.DataFusionPipelinesLink
- - airflow.providers.google.cloud.operators.dataproc.DataprocJobLink
- - airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink
+ - airflow.providers.google.cloud.links.dataproc.DataprocLink
+ - airflow.providers.google.cloud.links.dataproc.DataprocListLink
- airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreDetailedLink
- airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreLink
- airflow.providers.google.cloud.links.vertex_ai.VertexAIModelLink
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index d67ee71..a836d13 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -27,8 +27,9 @@ from google.api_core.retry import Retry
from airflow import AirflowException
from airflow.models import DAG, DagBag
from airflow.providers.google.cloud.operators.dataproc import (
+ DATAPROC_CLUSTER_LINK,
+ DATAPROC_JOB_LOG_LINK,
ClusterGenerator,
- DataprocClusterLink,
DataprocCreateBatchOperator,
DataprocCreateClusterOperator,
DataprocCreateWorkflowTemplateOperator,
@@ -37,7 +38,7 @@ from airflow.providers.google.cloud.operators.dataproc import (
DataprocGetBatchOperator,
DataprocInstantiateInlineWorkflowTemplateOperator,
DataprocInstantiateWorkflowTemplateOperator,
- DataprocJobLink,
+ DataprocLink,
DataprocListBatchesOperator,
DataprocScaleClusterOperator,
DataprocSubmitHadoopJobOperator,
@@ -198,14 +199,16 @@ DATAPROC_CLUSTER_LINK_EXPECTED = (
f"region={GCP_LOCATION}&project={GCP_PROJECT}"
)
DATAPROC_JOB_CONF_EXPECTED = {
- "job_id": TEST_JOB_ID,
+ "resource": TEST_JOB_ID,
"region": GCP_LOCATION,
"project_id": GCP_PROJECT,
+ "url": DATAPROC_JOB_LOG_LINK,
}
DATAPROC_CLUSTER_CONF_EXPECTED = {
- "cluster_name": CLUSTER_NAME,
+ "resource": CLUSTER_NAME,
"region": GCP_LOCATION,
"project_id": GCP_PROJECT,
+ "url": DATAPROC_CLUSTER_LINK,
}
BATCH_ID = "test-batch-id"
BATCH = {
@@ -249,7 +252,7 @@ class DataprocJobTestBase(DataprocTestBase):
def setUpClass(cls):
super().setUpClass()
cls.extra_links_expected_calls = [
- call.ti.xcom_push(execution_date=None, key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED),
+ call.ti.xcom_push(execution_date=None, key='conf', value=DATAPROC_JOB_CONF_EXPECTED),
call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_LOCATION, project_id=GCP_PROJECT),
]
@@ -259,7 +262,7 @@ class DataprocClusterTestBase(DataprocTestBase):
def setUpClass(cls):
super().setUpClass()
cls.extra_links_expected_calls_base = [
- call.ti.xcom_push(execution_date=None, key='cluster_conf', value=DATAPROC_CLUSTER_CONF_EXPECTED)
+ call.ti.xcom_push(execution_date=None, key='conf', value=DATAPROC_CLUSTER_CONF_EXPECTED)
]
@@ -448,7 +451,7 @@ class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result())
self.mock_ti.xcom_push.assert_called_once_with(
- key="cluster_conf",
+ key="conf",
value=DATAPROC_CLUSTER_CONF_EXPECTED,
execution_date=None,
)
@@ -607,28 +610,27 @@ def test_create_cluster_operator_extra_links(dag_maker, create_task_instance_of_
# Assert operator links for serialized DAG
assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
- {"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}
+ {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
]
# Assert operator link types are preserved during deserialization
- assert isinstance(deserialized_task.operator_extra_links[0], DataprocClusterLink)
+ assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
# Assert operator link is empty when no XCom push occurred
- assert ti.task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == ""
+ assert ti.task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == ""
# Assert operator link is empty for deserialized task when no XCom push occurred
- assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == ""
+ assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == ""
- ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+ ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
# Assert operator links are preserved in deserialized tasks after execution
assert (
- deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name)
- == DATAPROC_CLUSTER_LINK_EXPECTED
+ deserialized_task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
)
# Assert operator links after execution
- assert ti.task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
+ assert ti.task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
class TestDataprocClusterScaleOperator(DataprocClusterTestBase):
@@ -675,7 +677,7 @@ class TestDataprocClusterScaleOperator(DataprocClusterTestBase):
self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
self.mock_ti.xcom_push.assert_called_once_with(
- key="cluster_conf",
+ key="conf",
value=DATAPROC_CLUSTER_CONF_EXPECTED,
execution_date=None,
)
@@ -703,28 +705,27 @@ def test_scale_cluster_operator_extra_links(dag_maker, create_task_instance_of_o
# Assert operator links for serialized DAG
assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
- {"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}
+ {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
]
# Assert operator link types are preserved during deserialization
- assert isinstance(deserialized_task.operator_extra_links[0], DataprocClusterLink)
+ assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
# Assert operator link is empty when no XCom push occurred
- assert ti.task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == ""
+ assert ti.task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == ""
# Assert operator link is empty for deserialized task when no XCom push occurred
- assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == ""
+ assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == ""
- ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+ ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
# Assert operator links are preserved in deserialized tasks after execution
assert (
- deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name)
- == DATAPROC_CLUSTER_LINK_EXPECTED
+ deserialized_task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
)
# Assert operator links after execution
- assert ti.task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
+ assert ti.task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
class TestDataprocClusterDeleteOperator(unittest.TestCase):
@@ -759,9 +760,7 @@ class TestDataprocClusterDeleteOperator(unittest.TestCase):
class TestDataprocSubmitJobOperator(DataprocJobTestBase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
- xcom_push_call = call.ti.xcom_push(
- execution_date=None, key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED
- )
+ xcom_push_call = call.ti.xcom_push(execution_date=None, key='conf', value=DATAPROC_JOB_CONF_EXPECTED)
wait_for_job_call = call.hook().wait_for_job(
job_id=TEST_JOB_ID, region=GCP_LOCATION, project_id=GCP_PROJECT, timeout=None
)
@@ -808,7 +807,7 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
)
self.mock_ti.xcom_push.assert_called_once_with(
- key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
+ key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
)
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -848,7 +847,7 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
mock_hook.return_value.wait_for_job.assert_not_called()
self.mock_ti.xcom_push.assert_called_once_with(
- key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
+ key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
)
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -884,9 +883,7 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_location_deprecation_warning(self, mock_hook):
- xcom_push_call = call.ti.xcom_push(
- execution_date=None, key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED
- )
+ xcom_push_call = call.ti.xcom_push(execution_date=None, key='conf', value=DATAPROC_JOB_CONF_EXPECTED)
wait_for_job_call = call.hook().wait_for_job(
job_id=TEST_JOB_ID, region=GCP_LOCATION, project_id=GCP_PROJECT, timeout=None
)
@@ -941,7 +938,7 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
)
self.mock_ti.xcom_push.assert_called_once_with(
- key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
+ key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
)
assert warning_message == str(warnings[0].message)
@@ -982,25 +979,25 @@ def test_submit_job_operator_extra_links(mock_hook, dag_maker, create_task_insta
# Assert operator links for serialized_dag
assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
- {"airflow.providers.google.cloud.operators.dataproc.DataprocJobLink": {}}
+ {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
]
# Assert operator link types are preserved during deserialization
- assert isinstance(deserialized_task.operator_extra_links[0], DataprocJobLink)
+ assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
# Assert operator link is empty when no XCom push occurred
- assert ti.task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == ""
+ assert ti.task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == ""
# Assert operator link is empty for deserialized task when no XCom push occurred
- assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == ""
+ assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == ""
- ti.xcom_push(key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED)
+ ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED)
# Assert operator links are preserved in deserialized tasks
- assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == DATAPROC_JOB_LINK_EXPECTED
+ assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == DATAPROC_JOB_LINK_EXPECTED
# Assert operator links after execution
- assert ti.task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == DATAPROC_JOB_LINK_EXPECTED
+ assert ti.task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == DATAPROC_JOB_LINK_EXPECTED
class TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
@@ -1048,7 +1045,7 @@ class TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
self.mock_ti.xcom_push.assert_called_once_with(
- key="cluster_conf",
+ key="conf",
value=DATAPROC_CLUSTER_CONF_EXPECTED,
execution_date=None,
)
@@ -1105,7 +1102,7 @@ class TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
self.mock_ti.xcom_push.assert_called_once_with(
- key="cluster_conf",
+ key="conf",
value=DATAPROC_CLUSTER_CONF_EXPECTED,
execution_date=None,
)
@@ -1150,28 +1147,27 @@ def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_
# Assert operator links for serialized_dag
assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
- {"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}
+ {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
]
# Assert operator link types are preserved during deserialization
- assert isinstance(deserialized_task.operator_extra_links[0], DataprocClusterLink)
+ assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
# Assert operator link is empty when no XCom push occurred
- assert ti.task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == ""
+ assert ti.task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == ""
# Assert operator link is empty for deserialized task when no XCom push occurred
- assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == ""
+ assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == ""
- ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+ ti.xcom_push(key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
# Assert operator links are preserved in deserialized tasks
assert (
- deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name)
- == DATAPROC_CLUSTER_LINK_EXPECTED
+ deserialized_task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
)
# Assert operator links after execution
- assert ti.task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
+ assert ti.task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
class TestDataprocWorkflowTemplateInstantiateOperator(unittest.TestCase):
@@ -1195,7 +1191,7 @@ class TestDataprocWorkflowTemplateInstantiateOperator(unittest.TestCase):
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
- op.execute(context={})
+ op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.instantiate_workflow_template.assert_called_once_with(
template_name=template_id,
@@ -1227,7 +1223,7 @@ class TestDataprocWorkflowTemplateInstantiateInlineOperator(unittest.TestCase):
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
- op.execute(context={})
+ op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with(
template=template,
@@ -1491,7 +1487,7 @@ class TestDataProcSparkOperator(DataprocJobTestBase):
op.execute(context=self.mock_context)
self.mock_ti.xcom_push.assert_called_once_with(
- key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
+ key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
)
# Test whether xcom push occurs before polling for job
@@ -1520,25 +1516,25 @@ def test_submit_spark_job_operator_extra_links(mock_hook, dag_maker, create_task
# Assert operator links for serialized DAG
assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
- {"airflow.providers.google.cloud.operators.dataproc.DataprocJobLink": {}}
+ {"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
]
# Assert operator link types are preserved during deserialization
- assert isinstance(deserialized_task.operator_extra_links[0], DataprocJobLink)
+ assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink)
# Assert operator link is empty when no XCom push occurred
- assert ti.task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == ""
+ assert ti.task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == ""
# Assert operator link is empty for deserialized task when no XCom push occurred
- assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == ""
+ assert deserialized_task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == ""
- ti.xcom_push(key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED)
+ ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED)
# Assert operator links after task execution
- assert ti.task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name) == DATAPROC_JOB_LINK_EXPECTED
+ assert ti.task.get_extra_links(DEFAULT_DATE, DataprocLink.name) == DATAPROC_JOB_LINK_EXPECTED
# Assert operator links are preserved in deserialized tasks
- link = deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name)
+ link = deserialized_task.get_extra_links(DEFAULT_DATE, DataprocLink.name)
assert link == DATAPROC_JOB_LINK_EXPECTED
@@ -1622,7 +1618,7 @@ class TestDataprocCreateWorkflowTemplateOperator:
metadata=METADATA,
template=WORKFLOW_TEMPLATE,
)
- op.execute(context={})
+ op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_workflow_template.assert_called_once_with(
region=GCP_LOCATION,
@@ -1651,7 +1647,7 @@ class TestDataprocCreateWorkflowTemplateOperator:
metadata=METADATA,
template=WORKFLOW_TEMPLATE,
)
- op.execute(context={})
+ op.execute(context=MagicMock())
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN
)
@@ -1696,7 +1692,7 @@ class TestDataprocCreateBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
- op.execute(context={})
+ op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_batch.assert_called_once_with(
region=GCP_LOCATION,
@@ -1751,7 +1747,7 @@ class TestDataprocGetBatchOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
- op.execute(context={})
+ op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.get_batch.assert_called_once_with(
project_id=GCP_PROJECT,
@@ -1781,7 +1777,7 @@ class TestDataprocListBatchesOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
- op.execute(context={})
+ op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.list_batches.assert_called_once_with(
region=GCP_LOCATION,