You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by tu...@apache.org on 2021/01/18 16:49:33 UTC
[airflow] branch master updated: Refactor DataprocOperators to
support google-cloud-dataproc 2.0 (#13256)
This is an automated email from the ASF dual-hosted git repository.
turbaszek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 309788e Refactor DataprocOperators to support google-cloud-dataproc 2.0 (#13256)
309788e is described below
commit 309788e5e2023c598095a4ee00df417d94b6a5df
Author: Tomek Urbaszek <tu...@gmail.com>
AuthorDate: Mon Jan 18 17:49:19 2021 +0100
Refactor DataprocOperators to support google-cloud-dataproc 2.0 (#13256)
---
airflow/providers/google/ADDITIONAL_INFO.md | 2 +
airflow/providers/google/cloud/hooks/dataproc.py | 104 ++++++++---------
.../providers/google/cloud/operators/dataproc.py | 30 +++--
airflow/providers/google/cloud/sensors/dataproc.py | 12 +-
setup.py | 2 +-
.../providers/google/cloud/hooks/test_dataproc.py | 129 ++++++++++++---------
.../google/cloud/operators/test_dataproc.py | 14 ++-
.../google/cloud/sensors/test_dataproc.py | 8 +-
8 files changed, 157 insertions(+), 144 deletions(-)
diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md
index c696e1b..16a6683 100644
--- a/airflow/providers/google/ADDITIONAL_INFO.md
+++ b/airflow/providers/google/ADDITIONAL_INFO.md
@@ -32,11 +32,13 @@ Details are covered in the UPDATING.md files for each library, but there are som
| [``google-cloud-automl``](https://pypi.org/project/google-cloud-automl/) | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-automl/blob/master/UPGRADING.md) |
| [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
| [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
+| [``google-cloud-dataproc``](https://pypi.org/project/google-cloud-dataproc/) | ``>=1.0.1,<2.0.0`` | ``>=2.2.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-dataproc/blob/master/UPGRADING.md) |
| [``google-cloud-kms``](https://pypi.org/project/google-cloud-kms/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
| [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
| [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
| [``google-cloud-tasks``](https://pypi.org/project/google-cloud-tasks/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-tasks/blob/master/UPGRADING.md) |
+
### The field names use the snake_case convention
If your DAG uses an object from the above-mentioned libraries passed by XCom, it is necessary to update the naming convention of the fields that are read. Previously, the fields used the camelCase convention; now they use the snake_case convention.
diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py
index 12d5941..35d4786 100644
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -26,18 +26,16 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from google.api_core.exceptions import ServerError
from google.api_core.retry import Retry
from google.cloud.dataproc_v1beta2 import ( # pylint: disable=no-name-in-module
- ClusterControllerClient,
- JobControllerClient,
- WorkflowTemplateServiceClient,
-)
-from google.cloud.dataproc_v1beta2.types import ( # pylint: disable=no-name-in-module
Cluster,
- Duration,
- FieldMask,
+ ClusterControllerClient,
Job,
+ JobControllerClient,
JobStatus,
WorkflowTemplate,
+ WorkflowTemplateServiceClient,
)
+from google.protobuf.duration_pb2 import Duration
+from google.protobuf.field_mask_pb2 import FieldMask
from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
@@ -291,10 +289,12 @@ class DataprocHook(GoogleBaseHook):
client = self.get_cluster_client(location=region)
result = client.create_cluster(
- project_id=project_id,
- region=region,
- cluster=cluster,
- request_id=request_id,
+ request={
+ 'project_id': project_id,
+ 'region': region,
+ 'cluster': cluster,
+ 'request_id': request_id,
+ },
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -340,11 +340,13 @@ class DataprocHook(GoogleBaseHook):
"""
client = self.get_cluster_client(location=region)
result = client.delete_cluster(
- project_id=project_id,
- region=region,
- cluster_name=cluster_name,
- cluster_uuid=cluster_uuid,
- request_id=request_id,
+ request={
+ 'project_id': project_id,
+ 'region': region,
+ 'cluster_name': cluster_name,
+ 'cluster_uuid': cluster_uuid,
+ 'request_id': request_id,
+ },
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -382,9 +384,7 @@ class DataprocHook(GoogleBaseHook):
"""
client = self.get_cluster_client(location=region)
operation = client.diagnose_cluster(
- project_id=project_id,
- region=region,
- cluster_name=cluster_name,
+ request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -423,9 +423,7 @@ class DataprocHook(GoogleBaseHook):
"""
client = self.get_cluster_client(location=region)
result = client.get_cluster(
- project_id=project_id,
- region=region,
- cluster_name=cluster_name,
+ request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -467,10 +465,7 @@ class DataprocHook(GoogleBaseHook):
"""
client = self.get_cluster_client(location=region)
result = client.list_clusters(
- project_id=project_id,
- region=region,
- filter_=filter_,
- page_size=page_size,
+ request={'project_id': project_id, 'region': region, 'filter': filter_, 'page_size': page_size},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -551,13 +546,15 @@ class DataprocHook(GoogleBaseHook):
"""
client = self.get_cluster_client(location=location)
operation = client.update_cluster(
- project_id=project_id,
- region=location,
- cluster_name=cluster_name,
- cluster=cluster,
- update_mask=update_mask,
- graceful_decommission_timeout=graceful_decommission_timeout,
- request_id=request_id,
+ request={
+ 'project_id': project_id,
+ 'region': location,
+ 'cluster_name': cluster_name,
+ 'cluster': cluster,
+ 'update_mask': update_mask,
+ 'graceful_decommission_timeout': graceful_decommission_timeout,
+ 'request_id': request_id,
+ },
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -593,10 +590,11 @@ class DataprocHook(GoogleBaseHook):
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
+ metadata = metadata or ()
client = self.get_template_client(location)
- parent = client.region_path(project_id, location)
+ parent = f'projects/{project_id}/regions/{location}'
return client.create_workflow_template(
- parent=parent, template=template, retry=retry, timeout=timeout, metadata=metadata
+ request={'parent': parent, 'template': template}, retry=retry, timeout=timeout, metadata=metadata
)
@GoogleBaseHook.fallback_to_default_project_id
@@ -643,13 +641,11 @@ class DataprocHook(GoogleBaseHook):
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
+ metadata = metadata or ()
client = self.get_template_client(location)
- name = client.workflow_template_path(project_id, location, template_name)
+ name = f'projects/{project_id}/regions/{location}/workflowTemplates/{template_name}'
operation = client.instantiate_workflow_template(
- name=name,
- version=version,
- parameters=parameters,
- request_id=request_id,
+ request={'name': name, 'version': version, 'request_id': request_id, 'parameters': parameters},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -690,12 +686,11 @@ class DataprocHook(GoogleBaseHook):
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
+ metadata = metadata or ()
client = self.get_template_client(location)
- parent = client.region_path(project_id, location)
+ parent = f'projects/{project_id}/regions/{location}'
operation = client.instantiate_inline_workflow_template(
- parent=parent,
- template=template,
- request_id=request_id,
+ request={'parent': parent, 'template': template, 'request_id': request_id},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -722,19 +717,19 @@ class DataprocHook(GoogleBaseHook):
"""
state = None
start = time.monotonic()
- while state not in (JobStatus.ERROR, JobStatus.DONE, JobStatus.CANCELLED):
+ while state not in (JobStatus.State.ERROR, JobStatus.State.DONE, JobStatus.State.CANCELLED):
if timeout and start + timeout < time.monotonic():
raise AirflowException(f"Timeout: dataproc job {job_id} is not ready after {timeout}s")
time.sleep(wait_time)
try:
- job = self.get_job(location=location, job_id=job_id, project_id=project_id)
+ job = self.get_job(project_id=project_id, location=location, job_id=job_id)
state = job.status.state
except ServerError as err:
self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err)
- if state == JobStatus.ERROR:
+ if state == JobStatus.State.ERROR:
raise AirflowException(f'Job failed:\n{job}')
- if state == JobStatus.CANCELLED:
+ if state == JobStatus.State.CANCELLED:
raise AirflowException(f'Job was cancelled:\n{job}')
@GoogleBaseHook.fallback_to_default_project_id
@@ -767,9 +762,7 @@ class DataprocHook(GoogleBaseHook):
"""
client = self.get_job_client(location=location)
job = client.get_job(
- project_id=project_id,
- region=location,
- job_id=job_id,
+ request={'project_id': project_id, 'region': location, 'job_id': job_id},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -812,10 +805,7 @@ class DataprocHook(GoogleBaseHook):
"""
client = self.get_job_client(location=location)
return client.submit_job(
- project_id=project_id,
- region=location,
- job=job,
- request_id=request_id,
+ request={'project_id': project_id, 'region': location, 'job': job, 'request_id': request_id},
retry=retry,
timeout=timeout,
metadata=metadata,
@@ -884,9 +874,7 @@ class DataprocHook(GoogleBaseHook):
client = self.get_job_client(location=location)
job = client.cancel_job(
- project_id=project_id,
- region=location,
- job_id=job_id,
+ request={'project_id': project_id, 'region': location, 'job_id': job_id},
retry=retry,
timeout=timeout,
metadata=metadata,
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 839a624..a7d1379 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -17,7 +17,6 @@
# under the License.
#
"""This module contains Google Dataproc operators."""
-# pylint: disable=C0302
import inspect
import ntpath
@@ -31,12 +30,9 @@ from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.retry import Retry, exponential_sleep_generator
-from google.cloud.dataproc_v1beta2.types import ( # pylint: disable=no-name-in-module
- Cluster,
- Duration,
- FieldMask,
-)
-from google.protobuf.json_format import MessageToDict
+from google.cloud.dataproc_v1beta2 import Cluster # pylint: disable=no-name-in-module
+from google.protobuf.duration_pb2 import Duration
+from google.protobuf.field_mask_pb2 import FieldMask
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -562,7 +558,7 @@ class DataprocCreateClusterOperator(BaseOperator):
)
def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
- if cluster.status.state != cluster.status.ERROR:
+ if cluster.status.state != cluster.status.State.ERROR:
return
self.log.info("Cluster is in ERROR state")
gcs_uri = hook.diagnose_cluster(
@@ -590,7 +586,7 @@ class DataprocCreateClusterOperator(BaseOperator):
time_left = self.timeout
cluster = self._get_cluster(hook)
for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
- if cluster.status.state != cluster.status.CREATING:
+ if cluster.status.state != cluster.status.State.CREATING:
break
if time_left < 0:
raise AirflowException(f"Cluster {self.cluster_name} is still CREATING state, aborting")
@@ -613,18 +609,18 @@ class DataprocCreateClusterOperator(BaseOperator):
# Check if cluster is not in ERROR state
self._handle_error_state(hook, cluster)
- if cluster.status.state == cluster.status.CREATING:
+ if cluster.status.state == cluster.status.State.CREATING:
# Wait for cluster to be created
cluster = self._wait_for_cluster_in_creating_state(hook)
self._handle_error_state(hook, cluster)
- elif cluster.status.state == cluster.status.DELETING:
+ elif cluster.status.state == cluster.status.State.DELETING:
# Wait for cluster to be deleted
self._wait_for_cluster_in_deleting_state(hook)
# Create new cluster
cluster = self._create_cluster(hook)
self._handle_error_state(hook, cluster)
- return MessageToDict(cluster)
+ return Cluster.to_dict(cluster)
class DataprocScaleClusterOperator(BaseOperator):
@@ -1855,7 +1851,7 @@ class DataprocSubmitJobOperator(BaseOperator):
:type wait_timeout: int
"""
- template_fields = ('project_id', 'location', 'job', 'impersonation_chain')
+ template_fields = ('project_id', 'location', 'job', 'impersonation_chain', 'request_id')
template_fields_renderers = {"job": "json"}
@apply_defaults
@@ -1941,14 +1937,14 @@ class DataprocUpdateClusterOperator(BaseOperator):
example, to change the number of workers in a cluster to 5, the ``update_mask`` parameter would be
specified as ``config.worker_config.num_instances``, and the ``PATCH`` request body would specify the
new value. If a dict is provided, it must be of the same form as the protobuf message
- :class:`~google.cloud.dataproc_v1beta2.types.FieldMask`
- :type update_mask: Union[Dict, google.cloud.dataproc_v1beta2.types.FieldMask]
+ :class:`~google.protobuf.field_mask_pb2.FieldMask`
+ :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
:param graceful_decommission_timeout: Optional. Timeout for graceful YARN decommissioning. Graceful
decommissioning allows removing nodes from the cluster without interrupting jobs in progress. Timeout
specifies how long to wait for jobs in progress to finish before forcefully removing nodes (and
potentially interrupting jobs). Default timeout is 0 (for forceful decommission), and the maximum
allowed timeout is 1 day.
- :type graceful_decommission_timeout: Union[Dict, google.cloud.dataproc_v1beta2.types.Duration]
+ :type graceful_decommission_timeout: Union[Dict, google.protobuf.duration_pb2.Duration]
:param request_id: Optional. A unique id used to identify the request. If the server receives two
``UpdateClusterRequest`` requests with the same id, then the second request will be ignored and the
first ``google.longrunning.Operation`` created and stored in the backend is returned.
@@ -1974,7 +1970,7 @@ class DataprocUpdateClusterOperator(BaseOperator):
:type impersonation_chain: Union[str, Sequence[str]]
"""
- template_fields = ('impersonation_chain',)
+ template_fields = ('impersonation_chain', 'cluster_name')
@apply_defaults
def __init__( # pylint: disable=too-many-arguments
diff --git a/airflow/providers/google/cloud/sensors/dataproc.py b/airflow/providers/google/cloud/sensors/dataproc.py
index 1777a22..93656df 100644
--- a/airflow/providers/google/cloud/sensors/dataproc.py
+++ b/airflow/providers/google/cloud/sensors/dataproc.py
@@ -65,14 +65,18 @@ class DataprocJobSensor(BaseSensorOperator):
job = hook.get_job(job_id=self.dataproc_job_id, location=self.location, project_id=self.project_id)
state = job.status.state
- if state == JobStatus.ERROR:
+ if state == JobStatus.State.ERROR:
raise AirflowException(f'Job failed:\n{job}')
- elif state in {JobStatus.CANCELLED, JobStatus.CANCEL_PENDING, JobStatus.CANCEL_STARTED}:
+ elif state in {
+ JobStatus.State.CANCELLED,
+ JobStatus.State.CANCEL_PENDING,
+ JobStatus.State.CANCEL_STARTED,
+ }:
raise AirflowException(f'Job was cancelled:\n{job}')
- elif JobStatus.DONE == state:
+ elif JobStatus.State.DONE == state:
self.log.debug("Job %s completed successfully.", self.dataproc_job_id)
return True
- elif JobStatus.ATTEMPT_FAILURE == state:
+ elif JobStatus.State.ATTEMPT_FAILURE == state:
self.log.debug("Job %s attempt has failed.", self.dataproc_job_id)
self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)
diff --git a/setup.py b/setup.py
index eba4a7a..da29cd1 100644
--- a/setup.py
+++ b/setup.py
@@ -286,7 +286,7 @@ google = [
'google-cloud-bigtable>=1.0.0,<2.0.0',
'google-cloud-container>=0.1.1,<2.0.0',
'google-cloud-datacatalog>=3.0.0,<4.0.0',
- 'google-cloud-dataproc>=1.0.1,<2.0.0',
+ 'google-cloud-dataproc>=2.2.0,<3.0.0',
'google-cloud-dlp>=0.11.0,<2.0.0',
'google-cloud-kms>=2.0.0,<3.0.0',
'google-cloud-language>=1.1.1,<2.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py b/tests/providers/google/cloud/hooks/test_dataproc.py
index d09c91e..6842acc 100644
--- a/tests/providers/google/cloud/hooks/test_dataproc.py
+++ b/tests/providers/google/cloud/hooks/test_dataproc.py
@@ -20,7 +20,7 @@ import unittest
from unittest import mock
import pytest
-from google.cloud.dataproc_v1beta2.types import JobStatus # pylint: disable=no-name-in-module
+from google.cloud.dataproc_v1beta2 import JobStatus # pylint: disable=no-name-in-module
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder
@@ -43,8 +43,6 @@ CLUSTER = {
"project_id": GCP_PROJECT,
}
-PARENT = "parent"
-NAME = "name"
BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
DATAPROC_STRING = "airflow.providers.google.cloud.hooks.dataproc.{}"
@@ -113,11 +111,13 @@ class TestDataprocHook(unittest.TestCase):
)
mock_client.assert_called_once_with(location=GCP_LOCATION)
mock_client.return_value.create_cluster.assert_called_once_with(
- project_id=GCP_PROJECT,
- region=GCP_LOCATION,
- cluster=CLUSTER,
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster=CLUSTER,
+ request_id=None,
+ ),
metadata=None,
- request_id=None,
retry=None,
timeout=None,
)
@@ -127,12 +127,14 @@ class TestDataprocHook(unittest.TestCase):
self.hook.delete_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
mock_client.assert_called_once_with(location=GCP_LOCATION)
mock_client.return_value.delete_cluster.assert_called_once_with(
- project_id=GCP_PROJECT,
- region=GCP_LOCATION,
- cluster_name=CLUSTER_NAME,
- cluster_uuid=None,
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ cluster_uuid=None,
+ request_id=None,
+ ),
metadata=None,
- request_id=None,
retry=None,
timeout=None,
)
@@ -142,9 +144,11 @@ class TestDataprocHook(unittest.TestCase):
self.hook.diagnose_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
mock_client.assert_called_once_with(location=GCP_LOCATION)
mock_client.return_value.diagnose_cluster.assert_called_once_with(
- project_id=GCP_PROJECT,
- region=GCP_LOCATION,
- cluster_name=CLUSTER_NAME,
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ ),
metadata=None,
retry=None,
timeout=None,
@@ -156,9 +160,11 @@ class TestDataprocHook(unittest.TestCase):
self.hook.get_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
mock_client.assert_called_once_with(location=GCP_LOCATION)
mock_client.return_value.get_cluster.assert_called_once_with(
- project_id=GCP_PROJECT,
- region=GCP_LOCATION,
- cluster_name=CLUSTER_NAME,
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ ),
metadata=None,
retry=None,
timeout=None,
@@ -171,10 +177,12 @@ class TestDataprocHook(unittest.TestCase):
self.hook.list_clusters(project_id=GCP_PROJECT, region=GCP_LOCATION, filter_=filter_)
mock_client.assert_called_once_with(location=GCP_LOCATION)
mock_client.return_value.list_clusters.assert_called_once_with(
- project_id=GCP_PROJECT,
- region=GCP_LOCATION,
- filter_=filter_,
- page_size=None,
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ filter=filter_,
+ page_size=None,
+ ),
metadata=None,
retry=None,
timeout=None,
@@ -192,14 +200,16 @@ class TestDataprocHook(unittest.TestCase):
)
mock_client.assert_called_once_with(location=GCP_LOCATION)
mock_client.return_value.update_cluster.assert_called_once_with(
- project_id=GCP_PROJECT,
- region=GCP_LOCATION,
- cluster=CLUSTER,
- cluster_name=CLUSTER_NAME,
- update_mask=update_mask,
- graceful_decommission_timeout=None,
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster=CLUSTER,
+ cluster_name=CLUSTER_NAME,
+ update_mask=update_mask,
+ graceful_decommission_timeout=None,
+ request_id=None,
+ ),
metadata=None,
- request_id=None,
retry=None,
timeout=None,
)
@@ -207,44 +217,45 @@ class TestDataprocHook(unittest.TestCase):
@mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
def test_create_workflow_template(self, mock_client):
template = {"test": "test"}
- mock_client.return_value.region_path.return_value = PARENT
+ parent = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}'
self.hook.create_workflow_template(location=GCP_LOCATION, template=template, project_id=GCP_PROJECT)
- mock_client.return_value.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
mock_client.return_value.create_workflow_template.assert_called_once_with(
- parent=PARENT, template=template, retry=None, timeout=None, metadata=None
+ request=dict(parent=parent, template=template), retry=None, timeout=None, metadata=()
)
@mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
def test_instantiate_workflow_template(self, mock_client):
template_name = "template_name"
- mock_client.return_value.workflow_template_path.return_value = NAME
+ name = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}/workflowTemplates/{template_name}'
self.hook.instantiate_workflow_template(
location=GCP_LOCATION, template_name=template_name, project_id=GCP_PROJECT
)
- mock_client.return_value.workflow_template_path.assert_called_once_with(
- GCP_PROJECT, GCP_LOCATION, template_name
- )
mock_client.return_value.instantiate_workflow_template.assert_called_once_with(
- name=NAME, version=None, parameters=None, request_id=None, retry=None, timeout=None, metadata=None
+ request=dict(name=name, version=None, parameters=None, request_id=None),
+ retry=None,
+ timeout=None,
+ metadata=(),
)
@mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
def test_instantiate_inline_workflow_template(self, mock_client):
template = {"test": "test"}
- mock_client.return_value.region_path.return_value = PARENT
+ parent = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}'
self.hook.instantiate_inline_workflow_template(
location=GCP_LOCATION, template=template, project_id=GCP_PROJECT
)
- mock_client.return_value.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
mock_client.return_value.instantiate_inline_workflow_template.assert_called_once_with(
- parent=PARENT, template=template, request_id=None, retry=None, timeout=None, metadata=None
+ request=dict(parent=parent, template=template, request_id=None),
+ retry=None,
+ timeout=None,
+ metadata=(),
)
@mock.patch(DATAPROC_STRING.format("DataprocHook.get_job"))
def test_wait_for_job(self, mock_get_job):
mock_get_job.side_effect = [
- mock.MagicMock(status=mock.MagicMock(state=JobStatus.RUNNING)),
- mock.MagicMock(status=mock.MagicMock(state=JobStatus.ERROR)),
+ mock.MagicMock(status=mock.MagicMock(state=JobStatus.State.RUNNING)),
+ mock.MagicMock(status=mock.MagicMock(state=JobStatus.State.ERROR)),
]
with pytest.raises(AirflowException):
self.hook.wait_for_job(job_id=JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT, wait_time=0)
@@ -259,9 +270,11 @@ class TestDataprocHook(unittest.TestCase):
self.hook.get_job(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
mock_client.assert_called_once_with(location=GCP_LOCATION)
mock_client.return_value.get_job.assert_called_once_with(
- region=GCP_LOCATION,
- job_id=JOB_ID,
- project_id=GCP_PROJECT,
+ request=dict(
+ region=GCP_LOCATION,
+ job_id=JOB_ID,
+ project_id=GCP_PROJECT,
+ ),
retry=None,
timeout=None,
metadata=None,
@@ -272,10 +285,12 @@ class TestDataprocHook(unittest.TestCase):
self.hook.submit_job(location=GCP_LOCATION, job=JOB, project_id=GCP_PROJECT)
mock_client.assert_called_once_with(location=GCP_LOCATION)
mock_client.return_value.submit_job.assert_called_once_with(
- region=GCP_LOCATION,
- job=JOB,
- project_id=GCP_PROJECT,
- request_id=None,
+ request=dict(
+ region=GCP_LOCATION,
+ job=JOB,
+ project_id=GCP_PROJECT,
+ request_id=None,
+ ),
retry=None,
timeout=None,
metadata=None,
@@ -297,9 +312,11 @@ class TestDataprocHook(unittest.TestCase):
self.hook.cancel_job(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
mock_client.assert_called_once_with(location=GCP_LOCATION)
mock_client.return_value.cancel_job.assert_called_once_with(
- region=GCP_LOCATION,
- job_id=JOB_ID,
- project_id=GCP_PROJECT,
+ request=dict(
+ region=GCP_LOCATION,
+ job_id=JOB_ID,
+ project_id=GCP_PROJECT,
+ ),
retry=None,
timeout=None,
metadata=None,
@@ -311,9 +328,11 @@ class TestDataprocHook(unittest.TestCase):
self.hook.cancel_job(job_id=JOB_ID, project_id=GCP_PROJECT)
mock_client.assert_called_once_with(location='global')
mock_client.return_value.cancel_job.assert_called_once_with(
- region='global',
- job_id=JOB_ID,
- project_id=GCP_PROJECT,
+ request=dict(
+ region='global',
+ job_id=JOB_ID,
+ project_id=GCP_PROJECT,
+ ),
retry=None,
timeout=None,
metadata=None,
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 8c06ef7..e1c712e 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -217,8 +217,9 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
assert_warning("Default region value", warnings)
assert op_default_region.region == 'global'
+ @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
- def test_execute(self, mock_hook):
+ def test_execute(self, mock_hook, to_dict_mock):
op = DataprocCreateClusterOperator(
task_id=TASK_ID,
region=GCP_LOCATION,
@@ -246,9 +247,11 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
timeout=TIMEOUT,
metadata=METADATA,
)
+ to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result())
+ @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
- def test_execute_if_cluster_exists(self, mock_hook):
+ def test_execute_if_cluster_exists(self, mock_hook, to_dict_mock):
mock_hook.return_value.create_cluster.side_effect = [AlreadyExists("test")]
mock_hook.return_value.get_cluster.return_value.status.state = 0
op = DataprocCreateClusterOperator(
@@ -286,6 +289,7 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
timeout=TIMEOUT,
metadata=METADATA,
)
+ to_dict_mock.assert_called_once_with(mock_hook.return_value.get_cluster.return_value)
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_if_cluster_exists_do_not_use(self, mock_hook):
@@ -313,7 +317,7 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
mock_hook.return_value.create_cluster.side_effect = [AlreadyExists("test")]
cluster_status = mock_hook.return_value.get_cluster.return_value.status
cluster_status.state = 0
- cluster_status.ERROR = 0
+ cluster_status.State.ERROR = 0
op = DataprocCreateClusterOperator(
task_id=TASK_ID,
@@ -348,11 +352,11 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
):
cluster = mock.MagicMock()
cluster.status.state = 0
- cluster.status.DELETING = 0
+ cluster.status.State.DELETING = 0 # pylint: disable=no-member
cluster2 = mock.MagicMock()
cluster2.status.state = 0
- cluster2.status.ERROR = 0
+ cluster2.status.State.ERROR = 0 # pylint: disable=no-member
mock_create_cluster.side_effect = [AlreadyExists("test"), cluster2]
mock_generator.return_value = [0]
diff --git a/tests/providers/google/cloud/sensors/test_dataproc.py b/tests/providers/google/cloud/sensors/test_dataproc.py
index 1ce8eea..6f2991a 100644
--- a/tests/providers/google/cloud/sensors/test_dataproc.py
+++ b/tests/providers/google/cloud/sensors/test_dataproc.py
@@ -45,7 +45,7 @@ class TestDataprocJobSensor(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_done(self, mock_hook):
- job = self.create_job(JobStatus.DONE)
+ job = self.create_job(JobStatus.State.DONE)
job_id = "job_id"
mock_hook.return_value.get_job.return_value = job
@@ -66,7 +66,7 @@ class TestDataprocJobSensor(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_error(self, mock_hook):
- job = self.create_job(JobStatus.ERROR)
+ job = self.create_job(JobStatus.State.ERROR)
job_id = "job_id"
mock_hook.return_value.get_job.return_value = job
@@ -88,7 +88,7 @@ class TestDataprocJobSensor(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_wait(self, mock_hook):
- job = self.create_job(JobStatus.RUNNING)
+ job = self.create_job(JobStatus.State.RUNNING)
job_id = "job_id"
mock_hook.return_value.get_job.return_value = job
@@ -109,7 +109,7 @@ class TestDataprocJobSensor(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_cancelled(self, mock_hook):
- job = self.create_job(JobStatus.CANCELLED)
+ job = self.create_job(JobStatus.State.CANCELLED)
job_id = "job_id"
mock_hook.return_value.get_job.return_value = job