You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/05/07 09:21:44 UTC
[airflow] branch master updated: Add extra links for google
dataproc (#10343)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new b8c0fde Add extra links for google dataproc (#10343)
b8c0fde is described below
commit b8c0fde38a7df9d00185bf53e9f303b98fd064dc
Author: Santhosh Kumar <88...@users.noreply.github.com>
AuthorDate: Fri May 7 14:51:19 2021 +0530
Add extra links for google dataproc (#10343)
---
.../providers/google/cloud/operators/dataproc.py | 104 ++++-
airflow/providers/google/provider.yaml | 2 +
airflow/serialization/serialized_objects.py | 9 +
.../run_install_and_test_provider_packages.sh | 2 +-
tests/core/test_providers_manager.py | 2 +
.../google/cloud/operators/test_dataproc.py | 518 ++++++++++++++++++---
6 files changed, 582 insertions(+), 55 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index c8ee5e3..d8df03a 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -35,12 +35,57 @@ from google.protobuf.duration_pb2 import Duration
from google.protobuf.field_mask_pb2 import FieldMask
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
+from airflow.models import BaseOperator, BaseOperatorLink
+from airflow.models.taskinstance import TaskInstance
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.utils import timezone
from airflow.utils.decorators import apply_defaults
+DATAPROC_BASE_LINK = "https://console.cloud.google.com/dataproc"
+DATAPROC_JOB_LOG_LINK = DATAPROC_BASE_LINK + "/jobs/{job_id}?region={region}&project={project_id}"
+DATAPROC_CLUSTER_LINK = (
+ DATAPROC_BASE_LINK + "/clusters/{cluster_name}/monitoring?region={region}&project={project_id}"
+)
+
+
+class DataprocJobLink(BaseOperatorLink):
+ """Helper class for constructing Dataproc Job link"""
+
+ name = "Dataproc Job"
+
+ def get_link(self, operator, dttm):
+ ti = TaskInstance(task=operator, execution_date=dttm)
+ job_conf = ti.xcom_pull(task_ids=operator.task_id, key="job_conf")
+ return (
+ DATAPROC_JOB_LOG_LINK.format(
+ job_id=job_conf["job_id"],
+ region=job_conf["region"],
+ project_id=job_conf["project_id"],
+ )
+ if job_conf
+ else ""
+ )
+
+
+class DataprocClusterLink(BaseOperatorLink):
+ """Helper class for constructing Dataproc Cluster link"""
+
+ name = "Dataproc Cluster"
+
+ def get_link(self, operator, dttm):
+ ti = TaskInstance(task=operator, execution_date=dttm)
+ cluster_conf = ti.xcom_pull(task_ids=operator.task_id, key="cluster_conf")
+ return (
+ DATAPROC_CLUSTER_LINK.format(
+ cluster_name=cluster_conf["cluster_name"],
+ region=cluster_conf["region"],
+ project_id=cluster_conf["project_id"],
+ )
+ if cluster_conf
+ else ""
+ )
+
# pylint: disable=too-many-instance-attributes
class ClusterGenerator:
@@ -478,6 +523,8 @@ class DataprocCreateClusterOperator(BaseOperator):
)
template_fields_renderers = {'cluster_config': 'json'}
+ operator_extra_links = (DataprocClusterLink(),)
+
@apply_defaults
def __init__( # pylint: disable=too-many-arguments
self,
@@ -620,6 +667,16 @@ class DataprocCreateClusterOperator(BaseOperator):
def execute(self, context) -> dict:
self.log.info('Creating cluster: %s', self.cluster_name)
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+ # Save data required to display extra link no matter what the cluster status will be
+ self.xcom_push(
+ context,
+ key="cluster_conf",
+ value={
+ "cluster_name": self.cluster_name,
+ "region": self.region,
+ "project_id": self.project_id,
+ },
+ )
try:
# First try to create a new cluster
cluster = self._create_cluster(hook)
@@ -694,6 +751,8 @@ class DataprocScaleClusterOperator(BaseOperator):
template_fields = ['cluster_name', 'project_id', 'region', 'impersonation_chain']
+ operator_extra_links = (DataprocClusterLink(),)
+
@apply_defaults
def __init__(
self,
@@ -773,6 +832,16 @@ class DataprocScaleClusterOperator(BaseOperator):
update_mask = ["config.worker_config.num_instances", "config.secondary_worker_config.num_instances"]
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+ # Save data required to display extra link no matter what the cluster status will be
+ self.xcom_push(
+ context,
+ key="cluster_conf",
+ value={
+ "cluster_name": self.cluster_name,
+ "region": self.region,
+ "project_id": self.project_id,
+ },
+ )
operation = hook.update_cluster(
project_id=self.project_id,
location=self.region,
@@ -931,6 +1000,8 @@ class DataprocJobBaseOperator(BaseOperator):
job_type = ""
+ operator_extra_links = (DataprocJobLink(),)
+
@apply_defaults
def __init__(
self,
@@ -1005,6 +1076,12 @@ class DataprocJobBaseOperator(BaseOperator):
)
job_id = job_object.reference.job_id
self.log.info('Job %s submitted successfully.', job_id)
+ # Save data required for extra links no matter what the job status will be
+ self.xcom_push(
+ context,
+ key='job_conf',
+ value={'job_id': job_id, 'region': self.region, 'project_id': self.project_id},
+ )
if not self.asynchronous:
self.log.info('Waiting for job %s to complete', job_id)
@@ -1082,6 +1159,8 @@ class DataprocSubmitPigJobOperator(DataprocJobBaseOperator):
ui_color = '#0273d4'
job_type = 'pig_job'
+ operator_extra_links = (DataprocJobLink(),)
+
@apply_defaults
def __init__(
self,
@@ -1871,6 +1950,8 @@ class DataprocSubmitJobOperator(BaseOperator):
template_fields = ('project_id', 'location', 'job', 'impersonation_chain', 'request_id')
template_fields_renderers = {"job": "json"}
+ operator_extra_links = (DataprocJobLink(),)
+
@apply_defaults
def __init__(
self,
@@ -1919,6 +2000,16 @@ class DataprocSubmitJobOperator(BaseOperator):
)
job_id = job_object.reference.job_id
self.log.info('Job %s submitted successfully.', job_id)
+ # Save data required by extra links no matter what the job status will be
+ self.xcom_push(
+ context,
+ key="job_conf",
+ value={
+ "job_id": job_id,
+ "region": self.location,
+ "project_id": self.project_id,
+ },
+ )
if not self.asynchronous:
self.log.info('Waiting for job %s to complete', job_id)
@@ -1988,6 +2079,7 @@ class DataprocUpdateClusterOperator(BaseOperator):
"""
template_fields = ('impersonation_chain', 'cluster_name')
+ operator_extra_links = (DataprocClusterLink(),)
@apply_defaults
def __init__( # pylint: disable=too-many-arguments
@@ -2023,6 +2115,16 @@ class DataprocUpdateClusterOperator(BaseOperator):
def execute(self, context: Dict):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+ # Save data required by extra links no matter what the cluster status will be
+ self.xcom_push(
+ context,
+ key="cluster_conf",
+ value={
+ "cluster_name": self.cluster_name,
+ "region": self.location,
+ "project_id": self.project_id,
+ },
+ )
self.log.info("Updating %s cluster.", self.cluster_name)
operation = hook.update_cluster(
project_id=self.project_id,
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 644f093..87ba4ae 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -743,6 +743,8 @@ extra-links:
- airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink
- airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink
- airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink
+ - airflow.providers.google.cloud.operators.dataproc.DataprocJobLink
+ - airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink
additional-extras:
apache.beam: apache-beam[gcp]
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 8a6fdc8..645844a 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -72,6 +72,15 @@ _OPERATOR_EXTRA_LINKS: Set[str] = {
"airflow.sensors.external_task_sensor.ExternalTaskSensorLink",
}
+BUILTIN_OPERATOR_EXTRA_LINKS: List[str] = [
+ "airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink",
+ "airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink",
+ "airflow.providers.google.cloud.operators.dataproc.DataprocJobLink",
+ "airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink",
+ "airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink",
+ "airflow.providers.qubole.operators.qubole.QDSLink",
+]
+
@cache
def get_operator_extra_links():
diff --git a/scripts/in_container/run_install_and_test_provider_packages.sh b/scripts/in_container/run_install_and_test_provider_packages.sh
index 6dd1908..9dd8540 100755
--- a/scripts/in_container/run_install_and_test_provider_packages.sh
+++ b/scripts/in_container/run_install_and_test_provider_packages.sh
@@ -137,7 +137,7 @@ function discover_all_extra_links() {
group_start "Listing available extra links via 'airflow providers links'"
COLUMNS=180 airflow providers links
- local expected_number_of_extra_links=4
+ local expected_number_of_extra_links=6
local actual_number_of_extra_links
actual_number_of_extra_links=$(airflow providers links --output table | grep -c ^airflow.providers | xargs)
if [[ ${actual_number_of_extra_links} != "${expected_number_of_extra_links}" ]]; then
diff --git a/tests/core/test_providers_manager.py b/tests/core/test_providers_manager.py
index 57b581b..e6f26ce 100644
--- a/tests/core/test_providers_manager.py
+++ b/tests/core/test_providers_manager.py
@@ -227,6 +227,8 @@ CONNECTIONS_WITH_FIELD_BEHAVIOURS = [
EXTRA_LINKS = [
'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink',
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink',
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocJobLink',
'airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink',
'airflow.providers.qubole.operators.qubole.QDSLink',
]
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index fb2ceef..9a0ef21 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -19,19 +19,23 @@ import inspect
import unittest
from datetime import datetime
from unittest import mock
+from unittest.mock import MagicMock, Mock, call
import pytest
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.retry import Retry
from airflow import AirflowException
+from airflow.models import DAG, DagBag, TaskInstance
from airflow.providers.google.cloud.operators.dataproc import (
ClusterGenerator,
+ DataprocClusterLink,
DataprocCreateClusterOperator,
DataprocCreateWorkflowTemplateOperator,
DataprocDeleteClusterOperator,
DataprocInstantiateInlineWorkflowTemplateOperator,
DataprocInstantiateWorkflowTemplateOperator,
+ DataprocJobLink,
DataprocScaleClusterOperator,
DataprocSubmitHadoopJobOperator,
DataprocSubmitHiveJobOperator,
@@ -42,7 +46,9 @@ from airflow.providers.google.cloud.operators.dataproc import (
DataprocSubmitSparkSqlJobOperator,
DataprocUpdateClusterOperator,
)
+from airflow.serialization.serialized_objects import SerializedDAG
from airflow.version import version as airflow_version
+from tests.test_utils.db import clear_db_runs, clear_db_xcom
cluster_params = inspect.signature(ClusterGenerator.__init__).parameters
@@ -171,12 +177,77 @@ WORKFLOW_TEMPLATE = {
},
"jobs": [{"step_id": "pig_job_1", "pig_job": {}}],
}
+TEST_DAG_ID = 'test-dataproc-operators'
+DEFAULT_DATE = datetime(2020, 1, 1)
+TEST_JOB_ID = "test-job"
+
+DATAPROC_JOB_LINK_EXPECTED = (
+ f"https://console.cloud.google.com/dataproc/jobs/{TEST_JOB_ID}?"
+ f"region={GCP_LOCATION}&project={GCP_PROJECT}"
+)
+DATAPROC_CLUSTER_LINK_EXPECTED = (
+ f"https://console.cloud.google.com/dataproc/clusters/{CLUSTER_NAME}/monitoring?"
+ f"region={GCP_LOCATION}&project={GCP_PROJECT}"
+)
+DATAPROC_JOB_CONF_EXPECTED = {
+ "job_id": TEST_JOB_ID,
+ "region": GCP_LOCATION,
+ "project_id": GCP_PROJECT,
+}
+DATAPROC_CLUSTER_CONF_EXPECTED = {
+ "cluster_name": CLUSTER_NAME,
+ "region": GCP_LOCATION,
+ "project_id": GCP_PROJECT,
+}
def assert_warning(msg: str, warnings):
assert any(msg in str(w) for w in warnings)
+class DataprocTestBase(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.dagbag = DagBag(dag_folder="/dev/null", include_examples=False)
+ cls.dag = DAG(TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE})
+
+ def setUp(self):
+ self.mock_ti = MagicMock()
+ self.mock_context = {"ti": self.mock_ti}
+ self.extra_links_manager_mock = Mock()
+ self.extra_links_manager_mock.attach_mock(self.mock_ti, 'ti')
+
+ def tearDown(self):
+ self.mock_ti = MagicMock()
+ self.mock_context = {"ti": self.mock_ti}
+ self.extra_links_manager_mock = Mock()
+ self.extra_links_manager_mock.attach_mock(self.mock_ti, 'ti')
+
+ @classmethod
+ def tearDownClass(cls):
+ clear_db_runs()
+ clear_db_xcom()
+
+
+class DataprocJobTestBase(DataprocTestBase):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls.extra_links_expected_calls = [
+ call.ti.xcom_push(execution_date=None, key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED),
+ call.hook().wait_for_job(job_id=TEST_JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT),
+ ]
+
+
+class DataprocClusterTestBase(DataprocTestBase):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls.extra_links_expected_calls_base = [
+ call.ti.xcom_push(execution_date=None, key='cluster_conf', value=DATAPROC_CLUSTER_CONF_EXPECTED)
+ ]
+
+
class TestsClusterGenerator(unittest.TestCase):
def test_image_version(self):
with pytest.raises(ValueError) as ctx:
@@ -290,7 +361,7 @@ class TestsClusterGenerator(unittest.TestCase):
assert CONFIG_WITH_CUSTOM_IMAGE_FAMILY == cluster
-class TestDataprocClusterCreateOperator(unittest.TestCase):
+class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
def test_deprecation_warning(self):
with pytest.warns(DeprecationWarning) as warnings:
op = DataprocCreateClusterOperator(
@@ -321,6 +392,23 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook, to_dict_mock):
+ self.extra_links_manager_mock.attach_mock(mock_hook, 'hook')
+ mock_hook.return_value.create_cluster.result.return_value = None
+ create_cluster_args = {
+ 'region': GCP_LOCATION,
+ 'project_id': GCP_PROJECT,
+ 'cluster_name': CLUSTER_NAME,
+ 'request_id': REQUEST_ID,
+ 'retry': RETRY,
+ 'timeout': TIMEOUT,
+ 'metadata': METADATA,
+ 'cluster_config': CONFIG,
+ 'labels': LABELS,
+ }
+ expected_calls = self.extra_links_expected_calls_base + [
+ call.hook().create_cluster(**create_cluster_args),
+ ]
+
op = DataprocCreateClusterOperator(
task_id=TASK_ID,
region=GCP_LOCATION,
@@ -335,20 +423,19 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
metadata=METADATA,
impersonation_chain=IMPERSONATION_CHAIN,
)
- op.execute(context={})
+ op.execute(context=self.mock_context)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
- mock_hook.return_value.create_cluster.assert_called_once_with(
- region=GCP_LOCATION,
- project_id=GCP_PROJECT,
- cluster_config=CONFIG,
- labels=LABELS,
- cluster_name=CLUSTER_NAME,
- request_id=REQUEST_ID,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
- )
+ mock_hook.return_value.create_cluster.assert_called_once_with(**create_cluster_args)
+
+ # Test whether xcom push occurs before create cluster is called
+ self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
+
to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result())
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="cluster_conf",
+ value=DATAPROC_CLUSTER_CONF_EXPECTED,
+ execution_date=None,
+ )
@mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -369,7 +456,7 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
request_id=REQUEST_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
- op.execute(context={})
+ op.execute(context=self.mock_context)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_cluster.assert_called_once_with(
region=GCP_LOCATION,
@@ -411,7 +498,7 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
use_if_exists=False,
)
with pytest.raises(AlreadyExists):
- op.execute(context={})
+ op.execute(context=self.mock_context)
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_if_cluster_exists_in_error_state(self, mock_hook):
@@ -435,7 +522,7 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
request_id=REQUEST_ID,
)
with pytest.raises(AirflowException):
- op.execute(context={})
+ op.execute(context=self.mock_context)
mock_hook.return_value.diagnose_cluster.assert_called_once_with(
region=GCP_LOCATION, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME
@@ -474,7 +561,7 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
gcp_conn_id=GCP_CONN_ID,
)
with pytest.raises(AirflowException):
- op.execute(context={})
+ op.execute(context=self.mock_context)
calls = [mock.call(mock_hook.return_value), mock.call(mock_hook.return_value)]
mock_get_cluster.assert_has_calls(calls)
@@ -483,8 +570,60 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
region=GCP_LOCATION, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME
)
+ def test_operator_extra_links(self):
+ op = DataprocCreateClusterOperator(
+ task_id=TASK_ID,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_name=CLUSTER_NAME,
+ delete_on_error=True,
+ gcp_conn_id=GCP_CONN_ID,
+ dag=self.dag,
+ )
+
+ serialized_dag = SerializedDAG.to_dict(self.dag)
+ deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+ deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+ # Assert operator links for serialized DAG
+ self.assertEqual(
+ serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
+ [{"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}],
+ )
+
+ # Assert operator link types are preserved during deserialization
+ self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink)
+
+ ti = TaskInstance(task=op, execution_date=DEFAULT_DATE)
+
+ # Assert operator link is empty when no XCom push occurred
+ self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "")
+
+ # Assert operator link is empty for deserialized task when no XCom push occurred
+ self.assertEqual(
+ deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+ "",
+ )
+
+ ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+
+ # Assert operator links are preserved in deserialized tasks after execution
+ self.assertEqual(
+ deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+ DATAPROC_CLUSTER_LINK_EXPECTED,
+ )
+
+ # Assert operator links after execution
+ self.assertEqual(
+ op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+ DATAPROC_CLUSTER_LINK_EXPECTED,
+ )
+
+ # Check negative case
+ self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "")
+
-class TestDataprocClusterScaleOperator(unittest.TestCase):
+class TestDataprocClusterScaleOperator(DataprocClusterTestBase):
def test_deprecation_warning(self):
with pytest.warns(DeprecationWarning) as warnings:
DataprocScaleClusterOperator(task_id=TASK_ID, cluster_name=CLUSTER_NAME, project_id=GCP_PROJECT)
@@ -492,9 +631,22 @@ class TestDataprocClusterScaleOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
+ self.extra_links_manager_mock.attach_mock(mock_hook, 'hook')
+ mock_hook.return_value.update_cluster.result.return_value = None
cluster_update = {
"config": {"worker_config": {"num_instances": 3}, "secondary_worker_config": {"num_instances": 4}}
}
+ update_cluster_args = {
+ 'project_id': GCP_PROJECT,
+ 'location': GCP_LOCATION,
+ 'cluster_name': CLUSTER_NAME,
+ 'cluster': cluster_update,
+ 'graceful_decommission_timeout': {"seconds": 600},
+ 'update_mask': UPDATE_MASK,
+ }
+ expected_calls = self.extra_links_expected_calls_base + [
+ call.hook().update_cluster(**update_cluster_args)
+ ]
op = DataprocScaleClusterOperator(
task_id=TASK_ID,
@@ -507,18 +659,73 @@ class TestDataprocClusterScaleOperator(unittest.TestCase):
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
- op.execute(context={})
-
+ op.execute(context=self.mock_context)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
- mock_hook.return_value.update_cluster.assert_called_once_with(
- project_id=GCP_PROJECT,
- location=GCP_LOCATION,
+ mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args)
+
+ # Test whether xcom push occurs before cluster is updated
+ self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
+
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="cluster_conf",
+ value=DATAPROC_CLUSTER_CONF_EXPECTED,
+ execution_date=None,
+ )
+
+ def test_operator_extra_links(self):
+ op = DataprocScaleClusterOperator(
+ task_id=TASK_ID,
cluster_name=CLUSTER_NAME,
- cluster=cluster_update,
- graceful_decommission_timeout={"seconds": 600},
- update_mask=UPDATE_MASK,
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ num_workers=3,
+ num_preemptible_workers=2,
+ graceful_decommission_timeout="2m",
+ gcp_conn_id=GCP_CONN_ID,
+ dag=self.dag,
)
+ serialized_dag = SerializedDAG.to_dict(self.dag)
+ deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+ deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+ # Assert operator links for serialized DAG
+ self.assertEqual(
+ serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
+ [{"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}],
+ )
+
+ # Assert operator link types are preserved during deserialization
+ self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink)
+
+ ti = TaskInstance(task=op, execution_date=DEFAULT_DATE)
+
+ # Assert operator link is empty when no XCom push occurred
+ self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "")
+
+ # Assert operator link is empty for deserialized task when no XCom push occurred
+ self.assertEqual(
+ deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+ "",
+ )
+
+ ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+
+ # Assert operator links are preserved in deserialized tasks after execution
+ self.assertEqual(
+ deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+ DATAPROC_CLUSTER_LINK_EXPECTED,
+ )
+
+ # Assert operator links after execution
+ self.assertEqual(
+ op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+ DATAPROC_CLUSTER_LINK_EXPECTED,
+ )
+
+ # Check negative case
+ self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "")
+
class TestDataprocClusterDeleteOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -549,13 +756,20 @@ class TestDataprocClusterDeleteOperator(unittest.TestCase):
)
-class TestDataprocSubmitJobOperator(unittest.TestCase):
+class TestDataprocSubmitJobOperator(DataprocJobTestBase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
+ xcom_push_call = call.ti.xcom_push(
+ execution_date=None, key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED
+ )
+ wait_for_job_call = call.hook().wait_for_job(
+ job_id=TEST_JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT, timeout=None
+ )
+
job = {}
- job_id = "job_id"
mock_hook.return_value.wait_for_job.return_value = None
- mock_hook.return_value.submit_job.return_value.reference.job_id = job_id
+ mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID
+ self.extra_links_manager_mock.attach_mock(mock_hook, 'hook')
op = DataprocSubmitJobOperator(
task_id=TASK_ID,
@@ -569,9 +783,17 @@ class TestDataprocSubmitJobOperator(unittest.TestCase):
request_id=REQUEST_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
- op.execute(context={})
+ op.execute(context=self.mock_context)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+
+ # Test whether xcom push occurs before polling for job
+ self.assertLess(
+ self.extra_links_manager_mock.mock_calls.index(xcom_push_call),
+ self.extra_links_manager_mock.mock_calls.index(wait_for_job_call),
+ msg='Xcom push for Job Link has to be done before polling for job status',
+ )
+
mock_hook.return_value.submit_job.assert_called_once_with(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
@@ -582,15 +804,18 @@ class TestDataprocSubmitJobOperator(unittest.TestCase):
metadata=METADATA,
)
mock_hook.return_value.wait_for_job.assert_called_once_with(
- job_id=job_id, project_id=GCP_PROJECT, location=GCP_LOCATION, timeout=None
+ job_id=TEST_JOB_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, timeout=None
+ )
+
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
)
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_async(self, mock_hook):
job = {}
- job_id = "job_id"
mock_hook.return_value.wait_for_job.return_value = None
- mock_hook.return_value.submit_job.return_value.reference.job_id = job_id
+ mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID
op = DataprocSubmitJobOperator(
task_id=TASK_ID,
@@ -605,7 +830,7 @@ class TestDataprocSubmitJobOperator(unittest.TestCase):
request_id=REQUEST_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
- op.execute(context={})
+ op.execute(context=self.mock_context)
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
@@ -622,6 +847,10 @@ class TestDataprocSubmitJobOperator(unittest.TestCase):
)
mock_hook.return_value.wait_for_job.assert_not_called()
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
+ )
+
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_on_kill(self, mock_hook):
job = {}
@@ -642,7 +871,7 @@ class TestDataprocSubmitJobOperator(unittest.TestCase):
impersonation_chain=IMPERSONATION_CHAIN,
cancel_on_kill=False,
)
- op.execute(context={})
+ op.execute(context=self.mock_context)
op.on_kill()
mock_hook.return_value.cancel_job.assert_not_called()
@@ -653,10 +882,77 @@ class TestDataprocSubmitJobOperator(unittest.TestCase):
project_id=GCP_PROJECT, location=GCP_LOCATION, job_id=job_id
)
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_operator_extra_links(self, mock_hook):
+ mock_hook.return_value.project_id = GCP_PROJECT
+ op = DataprocSubmitJobOperator(
+ task_id=TASK_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ job={},
+ gcp_conn_id=GCP_CONN_ID,
+ dag=self.dag,
+ )
+
+ serialized_dag = SerializedDAG.to_dict(self.dag)
+ deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+ deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+ # Assert operator links for serialized_dag
+ self.assertEqual(
+ serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
+ [{"airflow.providers.google.cloud.operators.dataproc.DataprocJobLink": {}}],
+ )
+
+ # Assert operator link types are preserved during deserialization
+ self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocJobLink)
+
+ ti = TaskInstance(task=op, execution_date=DEFAULT_DATE)
+
+ # Assert operator link is empty when no XCom push occurred
+ self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "")
+
+ # Assert operator link is empty for deserialized task when no XCom push occurred
+ self.assertEqual(deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "")
+
+ ti.xcom_push(key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED)
-class TestDataprocUpdateClusterOperator(unittest.TestCase):
+ # Assert operator links are preserved in deserialized tasks
+ self.assertEqual(
+ deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name),
+ DATAPROC_JOB_LINK_EXPECTED,
+ )
+ # Assert operator links after execution
+ self.assertEqual(
+ op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name),
+ DATAPROC_JOB_LINK_EXPECTED,
+ )
+ # Check for negative case
+ self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name), "")
+
+
+class TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
+ self.extra_links_manager_mock.attach_mock(mock_hook, 'hook')
+ mock_hook.return_value.update_cluster.result.return_value = None
+ cluster_decommission_timeout = {"graceful_decommission_timeout": "600s"}
+ update_cluster_args = {
+ 'location': GCP_LOCATION,
+ 'project_id': GCP_PROJECT,
+ 'cluster_name': CLUSTER_NAME,
+ 'cluster': CLUSTER,
+ 'update_mask': UPDATE_MASK,
+ 'graceful_decommission_timeout': cluster_decommission_timeout,
+ 'request_id': REQUEST_ID,
+ 'retry': RETRY,
+ 'timeout': TIMEOUT,
+ 'metadata': METADATA,
+ }
+ expected_calls = self.extra_links_expected_calls_base + [
+ call.hook().update_cluster(**update_cluster_args)
+ ]
+
op = DataprocUpdateClusterOperator(
task_id=TASK_ID,
location=GCP_LOCATION,
@@ -664,7 +960,7 @@ class TestDataprocUpdateClusterOperator(unittest.TestCase):
cluster=CLUSTER,
update_mask=UPDATE_MASK,
request_id=REQUEST_ID,
- graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
+ graceful_decommission_timeout=cluster_decommission_timeout,
project_id=GCP_PROJECT,
gcp_conn_id=GCP_CONN_ID,
retry=RETRY,
@@ -672,21 +968,71 @@ class TestDataprocUpdateClusterOperator(unittest.TestCase):
metadata=METADATA,
impersonation_chain=IMPERSONATION_CHAIN,
)
- op.execute(context={})
+ op.execute(context=self.mock_context)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
- mock_hook.return_value.update_cluster.assert_called_once_with(
+ mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args)
+
+ # Test whether the xcom push happens before updating the cluster
+ self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
+
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="cluster_conf",
+ value=DATAPROC_CLUSTER_CONF_EXPECTED,
+ execution_date=None,
+ )
+
+ def test_operator_extra_links(self):
+ op = DataprocUpdateClusterOperator(
+ task_id=TASK_ID,
location=GCP_LOCATION,
- project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME,
cluster=CLUSTER,
update_mask=UPDATE_MASK,
graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
- request_id=REQUEST_ID,
- retry=RETRY,
- timeout=TIMEOUT,
- metadata=METADATA,
+ project_id=GCP_PROJECT,
+ gcp_conn_id=GCP_CONN_ID,
+ dag=self.dag,
)
+ serialized_dag = SerializedDAG.to_dict(self.dag)
+ deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+ deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+ # Assert operator links for serialized_dag
+ self.assertEqual(
+ serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
+ [{"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}],
+ )
+
+ # Assert operator link types are preserved during deserialization
+ self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink)
+
+ ti = TaskInstance(task=op, execution_date=DEFAULT_DATE)
+
+ # Assert operator link is empty when no XCom push occurred
+ self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "")
+
+ # Assert operator link is empty for deserialized task when no XCom push occurred
+ self.assertEqual(
+ deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+ "",
+ )
+
+ ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+
+ # Assert operator links are preserved in deserialized tasks
+ self.assertEqual(
+ deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+ DATAPROC_CLUSTER_LINK_EXPECTED,
+ )
+ # Assert operator links after execution
+ self.assertEqual(
+ op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+ DATAPROC_CLUSTER_LINK_EXPECTED,
+ )
+ # Check for negative case
+ self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "")
+
class TestDataprocWorkflowTemplateInstantiateOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -787,7 +1133,7 @@ class TestDataProcHiveOperator(unittest.TestCase):
variables=self.variables,
impersonation_chain=IMPERSONATION_CHAIN,
)
- op.execute(context={})
+ op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION
@@ -846,7 +1192,7 @@ class TestDataProcPigOperator(unittest.TestCase):
variables=self.variables,
impersonation_chain=IMPERSONATION_CHAIN,
)
- op.execute(context={})
+ op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION
@@ -911,7 +1257,7 @@ class TestDataProcSparkSqlOperator(unittest.TestCase):
variables=self.variables,
impersonation_chain=IMPERSONATION_CHAIN,
)
- op.execute(context={})
+ op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION
@@ -937,7 +1283,7 @@ class TestDataProcSparkSqlOperator(unittest.TestCase):
variables=self.variables,
impersonation_chain=IMPERSONATION_CHAIN,
)
- op.execute(context={})
+ op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id="other-project", job=self.other_project_job, location=GCP_LOCATION
@@ -963,12 +1309,14 @@ class TestDataProcSparkSqlOperator(unittest.TestCase):
assert self.job == job
-class TestDataProcSparkOperator(unittest.TestCase):
+class TestDataProcSparkOperator(DataprocJobTestBase):
main_class = "org.apache.spark.examples.SparkPi"
jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"]
- job_id = "uuid_id"
job = {
- "reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
+ "reference": {
+ "project_id": GCP_PROJECT,
+ "job_id": "{{task.task_id}}_{{ds_nodash}}_" + TEST_JOB_ID,
+ },
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"spark_job": {"jar_file_uris": jars, "main_class": main_class},
@@ -985,9 +1333,11 @@ class TestDataProcSparkOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook, mock_uuid):
- mock_uuid.return_value = self.job_id
+ mock_uuid.return_value = TEST_JOB_ID
mock_hook.return_value.project_id = GCP_PROJECT
- mock_uuid.return_value = self.job_id
+ mock_uuid.return_value = TEST_JOB_ID
+ mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID
+ self.extra_links_manager_mock.attach_mock(mock_hook, 'hook')
op = DataprocSubmitSparkJobOperator(
task_id=TASK_ID,
@@ -999,6 +1349,68 @@ class TestDataProcSparkOperator(unittest.TestCase):
job = op.generate_job()
assert self.job == job
+ op.execute(context=self.mock_context)
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
+ )
+
+ # Test whether xcom push occurs before polling for job
+ self.extra_links_manager_mock.assert_has_calls(self.extra_links_expected_calls, any_order=False)
+
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_operator_extra_links(self, mock_hook):
+ mock_hook.return_value.project_id = GCP_PROJECT
+
+ op = DataprocSubmitSparkJobOperator(
+ task_id=TASK_ID,
+ region=GCP_LOCATION,
+ gcp_conn_id=GCP_CONN_ID,
+ main_class=self.main_class,
+ dataproc_jars=self.jars,
+ dag=self.dag,
+ )
+
+ serialized_dag = SerializedDAG.to_dict(self.dag)
+ deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+ deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+ # Assert operator links for serialized DAG
+ self.assertEqual(
+ serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
+ [{"airflow.providers.google.cloud.operators.dataproc.DataprocJobLink": {}}],
+ )
+
+ # Assert operator link types are preserved during deserialization
+ self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocJobLink)
+
+ ti = TaskInstance(task=op, execution_date=DEFAULT_DATE)
+
+ # Assert operator link is empty when no XCom push occurred
+ self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "")
+
+ # Assert operator link is empty for deserialized task when no XCom push occurred
+ self.assertEqual(deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "")
+
+ ti.xcom_push(key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED)
+
+ # Assert operator links after task execution
+ self.assertEqual(
+ op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name),
+ DATAPROC_JOB_LINK_EXPECTED,
+ )
+
+ # Assert operator links are preserved in deserialized tasks
+ self.assertEqual(
+ deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name),
+ DATAPROC_JOB_LINK_EXPECTED,
+ )
+
+ # Assert for negative case
+ self.assertEqual(
+ deserialized_task.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name),
+ "",
+ )
+
class TestDataProcHadoopOperator(unittest.TestCase):
args = ["wordcount", "gs://pub/shakespeare/rose.txt"]