Posted to commits@airflow.apache.org by po...@apache.org on 2022/08/22 19:23:53 UTC
[airflow] branch main updated: Dataproc submit job operator async (#25302)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ecf0460b7d Dataproc submit job operator async (#25302)
ecf0460b7d is described below
commit ecf0460b7d9c9e9b6462c3dfa92cdf4e373dbfd5
Author: Bartosz Jankiewicz <bj...@users.noreply.github.com>
AuthorDate: Mon Aug 22 21:23:46 2022 +0200
Dataproc submit job operator async (#25302)
---
airflow/providers/google/cloud/hooks/dataproc.py | 746 +++++++++++++++++++++
.../providers/google/cloud/operators/dataproc.py | 76 ++-
.../providers/google/cloud/triggers/dataproc.py | 86 +++
.../operators/cloud/dataproc.rst | 8 +
.../providers/google/cloud/hooks/test_dataproc.py | 406 ++++++++++-
.../google/cloud/operators/test_dataproc.py | 48 +-
.../dataproc/example_dataproc_spark_deferrable.py | 111 +++
7 files changed, 1477 insertions(+), 4 deletions(-)
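For readers skimming the diff: the new ``deferrable`` flag lets the task release its worker slot while the Dataproc job runs; the polling work moves to the triggerer process. A minimal usage sketch under assumed values (the DAG id, project, region, and job config below are illustrative, not taken from this commit):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator

    # Illustrative job config -- replace project/cluster/jar with your own.
    SPARK_JOB = {
        "reference": {"project_id": "my-project"},
        "placement": {"cluster_name": "my-cluster"},
        "spark_job": {
            "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
            "main_class": "org.apache.spark.examples.SparkPi",
        },
    }

    with DAG("dataproc_spark_deferrable_demo", start_date=datetime(2022, 8, 1), schedule_interval=None) as dag:
        DataprocSubmitJobOperator(
            task_id="spark_task",
            job=SPARK_JOB,
            region="europe-west1",
            project_id="my-project",
            deferrable=True,  # defer to the triggerer instead of blocking a worker
            polling_interval_seconds=30,  # how often the trigger polls job state
        )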
diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py
index a8acac0ea5..8e916ff0e5 100644
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -26,16 +26,21 @@ from google.api_core.client_options import ClientOptions
from google.api_core.exceptions import ServerError
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.api_core.operation import Operation
+from google.api_core.operation_async import AsyncOperation
from google.api_core.retry import Retry
from google.cloud.dataproc_v1 import (
Batch,
+ BatchControllerAsyncClient,
BatchControllerClient,
Cluster,
+ ClusterControllerAsyncClient,
ClusterControllerClient,
Job,
+ JobControllerAsyncClient,
JobControllerClient,
JobStatus,
WorkflowTemplate,
+ WorkflowTemplateServiceAsyncClient,
WorkflowTemplateServiceClient,
)
from google.protobuf.duration_pb2 import Duration
@@ -200,6 +205,14 @@ class DataprocHook(GoogleBaseHook):
keyword arguments rather than positional.
"""
+ def __init__(
+ self,
+ gcp_conn_id: str = 'google_cloud_default',
+ delegate_to: Optional[str] = None,
+ impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ ) -> None:
+ super().__init__(gcp_conn_id, delegate_to, impersonation_chain)
+
def get_cluster_client(self, region: Optional[str] = None) -> ClusterControllerClient:
"""Returns ClusterControllerClient."""
client_options = None
@@ -958,3 +971,736 @@ class DataprocHook(GoogleBaseHook):
metadata=metadata,
)
return result
+
+
+class DataprocAsyncHook(GoogleBaseHook):
+ """
+ Asynchronous Hook for Google Cloud Dataproc APIs.
+
+ All the methods in the hook where project_id is used must be called with
+ keyword arguments rather than positional.
+ """
+
+ def __init__(
+ self,
+ gcp_conn_id: str = 'google_cloud_default',
+ delegate_to: Optional[str] = None,
+ impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ ) -> None:
+ super().__init__(gcp_conn_id, delegate_to, impersonation_chain)
+
+ def get_cluster_client(self, region: Optional[str] = None) -> ClusterControllerAsyncClient:
+ """Returns ClusterControllerAsyncClient."""
+ client_options = None
+ if region and region != 'global':
+ client_options = ClientOptions(api_endpoint=f'{region}-dataproc.googleapis.com:443')
+
+ return ClusterControllerAsyncClient(
+ credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
+ )
+
+ def get_template_client(self, region: Optional[str] = None) -> WorkflowTemplateServiceAsyncClient:
+ """Returns WorkflowTemplateServiceAsyncClient."""
+ client_options = None
+ if region and region != 'global':
+ client_options = ClientOptions(api_endpoint=f'{region}-dataproc.googleapis.com:443')
+
+ return WorkflowTemplateServiceAsyncClient(
+ credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
+ )
+
+ def get_job_client(self, region: Optional[str] = None) -> JobControllerAsyncClient:
+ """Returns JobControllerAsyncClient."""
+ client_options = None
+ if region and region != 'global':
+ client_options = ClientOptions(api_endpoint=f'{region}-dataproc.googleapis.com:443')
+
+ return JobControllerAsyncClient(
+ credentials=self.get_credentials(),
+ client_info=CLIENT_INFO,
+ client_options=client_options,
+ )
+
+ def get_batch_client(self, region: Optional[str] = None) -> BatchControllerAsyncClient:
+ """Returns BatchControllerAsyncClient"""
+ client_options = None
+ if region and region != 'global':
+ client_options = ClientOptions(api_endpoint=f'{region}-dataproc.googleapis.com:443')
+
+ return BatchControllerAsyncClient(
+ credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
+ )
+
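All four client getters above share one rule: any region other than 'global' is routed to the region-specific Dataproc endpoint, otherwise the client falls back to the global default. A standalone sketch of that selection logic (extracted for illustration, not part of the commit):

    from typing import Optional

    def dataproc_api_endpoint(region: Optional[str]) -> Optional[str]:
        """Return the regional Dataproc endpoint, or None for the global default."""
        if region and region != "global":
            return f"{region}-dataproc.googleapis.com:443"
        return None

    assert dataproc_api_endpoint("europe-west1") == "europe-west1-dataproc.googleapis.com:443"
    assert dataproc_api_endpoint("global") is None
    assert dataproc_api_endpoint(None) is None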
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def create_cluster(
+ self,
+ region: str,
+ project_id: str,
+ cluster_name: str,
+ cluster_config: Union[Dict, Cluster, None] = None,
+ virtual_cluster_config: Optional[Dict] = None,
+ labels: Optional[Dict[str, str]] = None,
+ request_id: Optional[str] = None,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ):
+ """
+ Creates a cluster in a project.
+
+ :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param cluster_name: Name of the cluster to create.
+ :param labels: Labels that will be assigned to the created cluster.
+ :param cluster_config: Required. The cluster config to create.
+ If a dict is provided, it must be of the same form as the protobuf message
+ :class:`~google.cloud.dataproc_v1.types.ClusterConfig`
+ :param virtual_cluster_config: Optional. The virtual cluster config, used when creating a Dataproc
+ cluster that does not directly control the underlying compute resources, for example, when
+ creating a `Dataproc-on-GKE cluster`
+ :class:`~google.cloud.dataproc_v1.types.VirtualClusterConfig`
+ :param request_id: Optional. A unique id used to identify the request. If the server receives two
+ ``CreateClusterRequest`` requests with the same id, then the second request will be ignored and
+ the first ``google.longrunning.Operation`` created and stored in the backend is returned.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ # Dataproc labels must conform to the following regex:
+ # [a-z]([-a-z0-9]*[a-z0-9])? (current airflow version string follows
+ # semantic versioning spec: x.y.z).
+ labels = labels or {}
+ labels.update({'airflow-version': 'v' + airflow_version.replace('.', '-').replace('+', '-')})
+
+ cluster = {
+ "project_id": project_id,
+ "cluster_name": cluster_name,
+ }
+ if virtual_cluster_config is not None:
+ cluster['virtual_cluster_config'] = virtual_cluster_config # type: ignore
+ if cluster_config is not None:
+ cluster['config'] = cluster_config # type: ignore
+ cluster['labels'] = labels # type: ignore
+
+ client = self.get_cluster_client(region=region)
+ result = await client.create_cluster(
+ request={
+ 'project_id': project_id,
+ 'region': region,
+ 'cluster': cluster,
+ 'request_id': request_id,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def delete_cluster(
+ self,
+ region: str,
+ cluster_name: str,
+ project_id: str,
+ cluster_uuid: Optional[str] = None,
+ request_id: Optional[str] = None,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ):
+ """
+ Deletes a cluster in a project.
+
+ :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param cluster_name: Required. The cluster name.
+ :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail
+ if a cluster with the specified UUID does not exist.
+ :param request_id: Optional. A unique id used to identify the request. If the server receives two
+ ``DeleteClusterRequest`` requests with the same id, then the second request will be ignored and
+ the first ``google.longrunning.Operation`` created and stored in the backend is returned.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ client = self.get_cluster_client(region=region)
+ result = await client.delete_cluster(
+ request={
+ 'project_id': project_id,
+ 'region': region,
+ 'cluster_name': cluster_name,
+ 'cluster_uuid': cluster_uuid,
+ 'request_id': request_id,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def diagnose_cluster(
+ self,
+ region: str,
+ cluster_name: str,
+ project_id: str,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ):
+ """
+ Gets cluster diagnostic information. After the operation completes, the
+ GCS URI of the diagnostic output is returned.
+
+ :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param cluster_name: Required. The cluster name.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ client = self.get_cluster_client(region=region)
+ operation = await client.diagnose_cluster(
+ request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ await operation.result()
+ gcs_uri = str(operation.operation.response.value)
+ return gcs_uri
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def get_cluster(
+ self,
+ region: str,
+ cluster_name: str,
+ project_id: str,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ):
+ """
+ Gets the resource representation for a cluster in a project.
+
+ :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param cluster_name: Required. The cluster name.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ client = self.get_cluster_client(region=region)
+ result = await client.get_cluster(
+ request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def list_clusters(
+ self,
+ region: str,
+ filter_: str,
+ project_id: str,
+ page_size: Optional[int] = None,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ):
+ """
+ Lists all regions/{region}/clusters in a project.
+
+ :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param filter_: Optional. A filter constraining the clusters to list. Filters are case-sensitive.
+ :param page_size: The maximum number of resources contained in the underlying API response. If page
+ streaming is performed per-resource, this parameter does not affect the return value. If page
+ streaming is performed per-page, this determines the maximum number of resources in a page.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ client = self.get_cluster_client(region=region)
+ result = await client.list_clusters(
+ request={'project_id': project_id, 'region': region, 'filter': filter_, 'page_size': page_size},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def update_cluster(
+ self,
+ cluster_name: str,
+ cluster: Union[Dict, Cluster],
+ update_mask: Union[Dict, FieldMask],
+ project_id: str,
+ region: str,
+ graceful_decommission_timeout: Optional[Union[Dict, Duration]] = None,
+ request_id: Optional[str] = None,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ):
+ """
+ Updates a cluster in a project.
+
+ :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param cluster_name: Required. The cluster name.
+ :param cluster: Required. The changes to the cluster.
+
+ If a dict is provided, it must be of the same form as the protobuf message
+ :class:`~google.cloud.dataproc_v1.types.Cluster`
+ :param update_mask: Required. Specifies the path, relative to ``Cluster``, of the field to update. For
+ example, to change the number of workers in a cluster to 5, the ``update_mask`` parameter would be
+ specified as ``config.worker_config.num_instances``, and the ``PATCH`` request body would specify
+ the new value, as follows:
+
+ ::
+
+ { "config":{ "workerConfig":{ "numInstances":"5" } } }
+
+ Similarly, to change the number of preemptible workers in a cluster to 5, the ``update_mask``
+ parameter would be ``config.secondary_worker_config.num_instances``, and the ``PATCH`` request
+ body would be set as follows:
+
+ ::
+
+ { "config":{ "secondaryWorkerConfig":{ "numInstances":"5" } } }
+
+ If a dict is provided, it must be of the same form as the protobuf message
+ :class:`~google.cloud.dataproc_v1.types.FieldMask`
+ :param graceful_decommission_timeout: Optional. Timeout for graceful YARN decommissioning. Graceful
+ decommissioning allows removing nodes from the cluster without interrupting jobs in progress.
+ Timeout specifies how long to wait for jobs in progress to finish before forcefully removing nodes
+ (and potentially interrupting jobs). Default timeout is 0 (for forceful decommission), and the
+ maximum allowed timeout is 1 day.
+
+ Only supported on Dataproc image versions 1.2 and higher.
+
+ If a dict is provided, it must be of the same form as the protobuf message
+ :class:`~google.cloud.dataproc_v1.types.Duration`
+ :param request_id: Optional. A unique id used to identify the request. If the server receives two
+ ``UpdateClusterRequest`` requests with the same id, then the second request will be ignored and
+ the first ``google.longrunning.Operation`` created and stored in the backend is returned.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ if region is None:
+ raise TypeError("missing 1 required keyword argument: 'region'")
+ client = self.get_cluster_client(region=region)
+ operation = await client.update_cluster(
+ request={
+ 'project_id': project_id,
+ 'region': region,
+ 'cluster_name': cluster_name,
+ 'cluster': cluster,
+ 'update_mask': update_mask,
+ 'graceful_decommission_timeout': graceful_decommission_timeout,
+ 'request_id': request_id,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return operation
+
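To make the ``update_mask`` mechanics above concrete, here is a hedged sketch of the arguments one might pass to resize a cluster to five workers (project, region, and cluster names are placeholders):

    from google.protobuf.field_mask_pb2 import FieldMask

    # Only the masked path is read from `cluster`; all other fields are ignored.
    update_mask = FieldMask(paths=["config.worker_config.num_instances"])
    cluster = {"config": {"worker_config": {"num_instances": 5}}}

    # Inside an async context:
    # operation = await hook.update_cluster(
    #     project_id="my-project",
    #     region="europe-west1",
    #     cluster_name="my-cluster",
    #     cluster=cluster,
    #     update_mask=update_mask,
    # )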
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def create_workflow_template(
+ self,
+ template: Union[Dict, WorkflowTemplate],
+ project_id: str,
+ region: str,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> WorkflowTemplate:
+ """
+ Creates a new workflow template.
+
+ :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param template: The Dataproc workflow template to create. If a dict is provided,
+ it must be of the same form as the protobuf message WorkflowTemplate.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ if region is None:
+ raise TypeError("missing 1 required keyword argument: 'region'")
+ metadata = metadata or ()
+ client = self.get_template_client(region)
+ parent = f'projects/{project_id}/regions/{region}'
+ return await client.create_workflow_template(
+ request={'parent': parent, 'template': template}, retry=retry, timeout=timeout, metadata=metadata
+ )
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def instantiate_workflow_template(
+ self,
+ template_name: str,
+ project_id: str,
+ region: str,
+ version: Optional[int] = None,
+ request_id: Optional[str] = None,
+ parameters: Optional[Dict[str, str]] = None,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ):
+ """
+ Instantiates a template and begins execution.
+
+ :param template_name: Name of template to instantiate.
+ :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param version: Optional. The version of workflow template to instantiate. If specified,
+ the workflow will be instantiated only if the current version of
+ the workflow template has the supplied version.
+ This option cannot be used to instantiate a previous version of
+ workflow template.
+ :param request_id: Optional. A tag that prevents multiple concurrent workflow instances
+ with the same tag from running. This mitigates risk of concurrent
+ instances started due to retries.
+ :param parameters: Optional. Map from parameter names to values that should be used for those
+ parameters. Values may not exceed 100 characters.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ if region is None:
+ raise TypeError("missing 1 required keyword argument: 'region'")
+ metadata = metadata or ()
+ client = self.get_template_client(region)
+ name = f'projects/{project_id}/regions/{region}/workflowTemplates/{template_name}'
+ operation = await client.instantiate_workflow_template(
+ request={'name': name, 'version': version, 'request_id': request_id, 'parameters': parameters},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return operation
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def instantiate_inline_workflow_template(
+ self,
+ template: Union[Dict, WorkflowTemplate],
+ project_id: str,
+ region: str,
+ request_id: Optional[str] = None,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ):
+ """
+ Instantiates a template and begins execution.
+
+ :param template: The workflow template to instantiate. If a dict is provided,
+ it must be of the same form as the protobuf message WorkflowTemplate
+ :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param request_id: Optional. A tag that prevents multiple concurrent workflow instances
+ with the same tag from running. This mitigates risk of concurrent
+ instances started due to retries.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ if region is None:
+ raise TypeError("missing 1 required keyword argument: 'region'")
+ metadata = metadata or ()
+ client = self.get_template_client(region)
+ parent = f'projects/{project_id}/regions/{region}'
+ operation = await client.instantiate_inline_workflow_template(
+ request={'parent': parent, 'template': template, 'request_id': request_id},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return operation
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def get_job(
+ self,
+ job_id: str,
+ project_id: str,
+ region: str,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> Job:
+ """
+ Gets the resource representation for a job in a project.
+
+ :param job_id: ID of the Dataproc job.
+ :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ if region is None:
+ raise TypeError("missing 1 required keyword argument: 'region'")
+ client = self.get_job_client(region=region)
+ job = await client.get_job(
+ request={'project_id': project_id, 'region': region, 'job_id': job_id},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return job
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def submit_job(
+ self,
+ job: Union[dict, Job],
+ project_id: str,
+ region: str,
+ request_id: Optional[str] = None,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> Job:
+ """
+ Submits a job to a cluster.
+
+ :param job: The job resource. If a dict is provided,
+ it must be of the same form as the protobuf message Job
+ :param project_id: Required. The ID of the Google Cloud project the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param request_id: Optional. A unique id used to identify the request. If the server receives two
+ ``SubmitJobRequest`` requests with the same id, then the second request will be ignored and
+ the first ``Job`` created and stored in the backend is returned.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ if region is None:
+ raise TypeError("missing 1 required keyword argument: 'region'")
+ client = self.get_job_client(region=region)
+ return await client.submit_job(
+ request={'project_id': project_id, 'region': region, 'job': job, 'request_id': request_id},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
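A hedged sketch of driving ``submit_job`` and ``get_job`` directly from an event loop, mirroring what the trigger introduced below does (the hook instance and job payload are assumed to be configured by the caller):

    import asyncio

    from google.cloud.dataproc_v1 import JobStatus

    TERMINAL_STATES = (JobStatus.State.DONE, JobStatus.State.ERROR, JobStatus.State.CANCELLED)

    async def submit_and_wait(hook, job, project_id, region, poll_seconds=10):
        """Submit a Dataproc job and poll until it reaches a terminal state."""
        submitted = await hook.submit_job(job=job, project_id=project_id, region=region)
        job_id = submitted.reference.job_id
        while True:
            current = await hook.get_job(job_id=job_id, project_id=project_id, region=region)
            if current.status.state in TERMINAL_STATES:
                return current
            await asyncio.sleep(poll_seconds)

    # asyncio.run(submit_and_wait(DataprocAsyncHook(), SPARK_JOB, "my-project", "europe-west1"))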
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def cancel_job(
+ self,
+ job_id: str,
+ project_id: str,
+ region: Optional[str] = None,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> Job:
+ """
+ Starts a job cancellation request.
+
+ :param project_id: Required. The ID of the Google Cloud project that the job belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param job_id: Required. The job ID.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ client = self.get_job_client(region=region)
+
+ job = await client.cancel_job(
+ request={'project_id': project_id, 'region': region, 'job_id': job_id},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return job
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def create_batch(
+ self,
+ region: str,
+ project_id: str,
+ batch: Union[Dict, Batch],
+ batch_id: Optional[str] = None,
+ request_id: Optional[str] = None,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> AsyncOperation:
+ """
+ Creates a batch workload.
+
+ :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param batch: Required. The batch to create.
+ :param batch_id: Optional. The ID to use for the batch, which will become the final component
+ of the batch's resource name.
+ This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/.
+ :param request_id: Optional. A unique id used to identify the request. If the server receives two
+ ``CreateBatchRequest`` requests with the same id, then the second request will be ignored and
+ the first ``google.longrunning.Operation`` created and stored in the backend is returned.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ client = self.get_batch_client(region)
+ parent = f'projects/{project_id}/regions/{region}'
+
+ result = await client.create_batch(
+ request={
+ 'parent': parent,
+ 'batch': batch,
+ 'batch_id': batch_id,
+ 'request_id': request_id,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def delete_batch(
+ self,
+ batch_id: str,
+ region: str,
+ project_id: str,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ """
+ Deletes the batch workload resource.
+
+ :param batch_id: Required. The ID of the batch workload to delete.
+ :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ client = self.get_batch_client(region)
+ name = f"projects/{project_id}/regions/{region}/batches/{batch_id}"
+
+ await client.delete_batch(
+ request={
+ 'name': name,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def get_batch(
+ self,
+ batch_id: str,
+ region: str,
+ project_id: str,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> Batch:
+ """
+ Gets the batch workload resource representation.
+
+ :param batch_id: Required. The ID of the batch workload to retrieve.
+ :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ client = self.get_batch_client(region)
+ name = f"projects/{project_id}/regions/{region}/batches/{batch_id}"
+
+ result = await client.get_batch(
+ request={
+ 'name': name,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def list_batches(
+ self,
+ region: str,
+ project_id: str,
+ page_size: Optional[int] = None,
+ page_token: Optional[str] = None,
+ retry: Union[Retry, _MethodDefault] = DEFAULT,
+ timeout: Optional[float] = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ):
+ """
+ Lists batch workloads.
+
+ :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+ :param region: Required. The Cloud Dataproc region in which to handle the request.
+ :param page_size: Optional. The maximum number of batches to return in each response. The service may
+ return fewer than this value. The default page size is 20; the maximum page size is 1000.
+ :param page_token: Optional. A page token received from a previous ``ListBatches`` call.
+ Provide this token to retrieve the subsequent page.
+ :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ client = self.get_batch_client(region)
+ parent = f'projects/{project_id}/regions/{region}'
+
+ result = await client.list_batches(
+ request={
+ 'parent': parent,
+ 'page_size': page_size,
+ 'page_token': page_token,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
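``list_batches`` returns an async pager that fetches further pages transparently; a short consumption sketch (assuming application-default credentials and illustrative project/region values):

    import asyncio

    from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook

    async def print_batches():
        hook = DataprocAsyncHook()
        pager = await hook.list_batches(project_id="my-project", region="europe-west1")
        async for batch in pager:  # the pager issues ListBatches calls lazily
            print(batch.name, batch.state)

    # asyncio.run(print_batches())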
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 69ba3943e8..07ce34dc4f 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -32,7 +32,7 @@ from google.api_core import operation # type: ignore
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.api_core.retry import Retry, exponential_sleep_generator
-from google.cloud.dataproc_v1 import Batch, Cluster
+from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus
from google.protobuf.duration_pb2 import Duration
from google.protobuf.field_mask_pb2 import FieldMask
@@ -50,6 +50,7 @@ from airflow.providers.google.cloud.links.dataproc import (
DataprocLink,
DataprocListLink,
)
+from airflow.providers.google.cloud.triggers.dataproc import DataprocBaseTrigger
from airflow.utils import timezone
if TYPE_CHECKING:
@@ -867,6 +868,9 @@ class DataprocJobBaseOperator(BaseOperator):
:param asynchronous: Flag to return after submitting the job to the Dataproc API.
This is useful for submitting long running jobs and
waiting on them asynchronously using the DataprocJobSensor
+ :param deferrable: Run operator in deferrable mode.
+ :param polling_interval_seconds: Time in seconds between polls for job completion.
+ The value is considered only when running in deferrable mode. Must be greater than 0.
:var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API.
This is useful for identifying or linking to the job in the Google Cloud Console
@@ -894,9 +898,13 @@ class DataprocJobBaseOperator(BaseOperator):
job_error_states: Optional[Set[str]] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
asynchronous: bool = False,
+ deferrable: bool = False,
+ polling_interval_seconds: int = 10,
**kwargs,
) -> None:
super().__init__(**kwargs)
+ if deferrable and polling_interval_seconds <= 0:
+ raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.labels = labels
@@ -914,6 +922,8 @@ class DataprocJobBaseOperator(BaseOperator):
self.job: Optional[dict] = None
self.dataproc_job_id = None
self.asynchronous = asynchronous
+ self.deferrable = deferrable
+ self.polling_interval_seconds = polling_interval_seconds
def create_job_template(self) -> DataProcJobBuilder:
"""Initialize `self.job_template` with default values"""
@@ -958,6 +968,19 @@ class DataprocJobBaseOperator(BaseOperator):
context=context, task_instance=self, url=DATAPROC_JOB_LOG_LINK, resource=job_id
)
+ if self.deferrable:
+ self.defer(
+ trigger=DataprocBaseTrigger(
+ job_id=job_id,
+ project_id=self.project_id,
+ region=self.region,
+ delegate_to=self.delegate_to,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ polling_interval_seconds=self.polling_interval_seconds,
+ ),
+ method_name="execute_complete",
+ )
if not self.asynchronous:
self.log.info('Waiting for job %s to complete', job_id)
self.hook.wait_for_job(job_id=job_id, region=self.region, project_id=self.project_id)
@@ -966,6 +989,20 @@ class DataprocJobBaseOperator(BaseOperator):
else:
raise AirflowException("Create a job template before")
+ def execute_complete(self, context, event=None) -> None:
+ """
+ Callback for when the trigger fires; returns immediately.
+ Raises an exception if the job finished in the ERROR or CANCELLED state;
+ otherwise assumes execution was successful.
+ """
+ job_state = event["job_state"]
+ job_id = event["job_id"]
+ if job_state == JobStatus.State.ERROR:
+ raise AirflowException(f'Job failed:\n{job_id}')
+ if job_state == JobStatus.State.CANCELLED:
+ raise AirflowException(f'Job was cancelled:\n{job_id}')
+ self.log.info("%s completed successfully.", self.task_id)
+
def on_kill(self) -> None:
"""
Callback called when the operator is killed.
@@ -1771,6 +1808,9 @@ class DataprocSubmitJobOperator(BaseOperator):
:param asynchronous: Flag to return after submitting the job to the Dataproc API.
This is useful for submitting long running jobs and
waiting on them asynchronously using the DataprocJobSensor
+ :param deferrable: Run operator in deferrable mode.
+ :param polling_interval_seconds: Time in seconds between polls for job completion.
+ The value is considered only when running in deferrable mode. Must be greater than 0.
+ :param cancel_on_kill: Flag which indicates whether to cancel the hook's job when on_kill is called
+ :param wait_timeout: How many seconds to wait for the job to be ready. Used only if ``asynchronous`` is False
"""
@@ -1793,11 +1833,15 @@ class DataprocSubmitJobOperator(BaseOperator):
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
asynchronous: bool = False,
+ deferrable: bool = False,
+ polling_interval_seconds: int = 10,
cancel_on_kill: bool = True,
wait_timeout: Optional[int] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
+ if deferrable and polling_interval_seconds <= 0:
+ raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
self.project_id = project_id
self.region = region
self.job = job
@@ -1808,6 +1852,8 @@ class DataprocSubmitJobOperator(BaseOperator):
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.asynchronous = asynchronous
+ self.deferrable = deferrable
+ self.polling_interval_seconds = polling_interval_seconds
self.cancel_on_kill = cancel_on_kill
self.hook: Optional[DataprocHook] = None
self.job_id: Optional[str] = None
@@ -1833,7 +1879,19 @@ class DataprocSubmitJobOperator(BaseOperator):
)
self.job_id = new_job_id
- if not self.asynchronous:
+ if self.deferrable:
+ self.defer(
+ trigger=DataprocBaseTrigger(
+ job_id=self.job_id,
+ project_id=self.project_id,
+ region=self.region,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ polling_interval_seconds=self.polling_interval_seconds,
+ ),
+ method_name="execute_complete",
+ )
+ elif not self.asynchronous:
self.log.info('Waiting for job %s to complete', new_job_id)
self.hook.wait_for_job(
job_id=new_job_id, region=self.region, project_id=self.project_id, timeout=self.wait_timeout
@@ -1842,6 +1900,20 @@ class DataprocSubmitJobOperator(BaseOperator):
return self.job_id
+ def execute_complete(self, context, event=None) -> None:
+ """
+ Callback for when the trigger fires; returns immediately.
+ Raises an exception if the job finished in the ERROR or CANCELLED state;
+ otherwise assumes execution was successful.
+ """
+ job_state = event["job_state"]
+ job_id = event["job_id"]
+ if job_state == JobStatus.State.ERROR:
+ raise AirflowException(f'Job failed:\n{job_id}')
+ if job_state == JobStatus.State.CANCELLED:
+ raise AirflowException(f'Job was cancelled:\n{job_id}')
+ self.log.info("%s completed successfully.", self.task_id)
+
def on_kill(self):
if self.job_id and self.cancel_on_kill:
self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id, region=self.region)
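For orientation, the deferral round-trip introduced here is: ``defer`` raises ``TaskDeferred``, the triggerer runs ``DataprocBaseTrigger.run`` until it yields a ``TriggerEvent``, and the task then resumes in ``execute_complete`` with the event payload. A sketch of the payload contract as defined by this commit (the job id value is illustrative):

    from google.cloud.dataproc_v1 import JobStatus

    def resolve_event(event: dict) -> str:
        """Mirrors execute_complete's decision logic, for illustration only."""
        if event["job_state"] == JobStatus.State.ERROR:
            raise RuntimeError(f"Job failed: {event['job_id']}")
        if event["job_state"] == JobStatus.State.CANCELLED:
            raise RuntimeError(f"Job was cancelled: {event['job_id']}")
        return f"{event['job_id']} completed successfully"

    print(resolve_event({"job_id": "job-1234", "job_state": JobStatus.State.DONE}))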
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
new file mode 100644
index 0000000000..cdeb1feb8f
--- /dev/null
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -0,0 +1,86 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+"""This module contains Google Dataproc triggers."""
+
+import asyncio
+from typing import Optional, Sequence, Union
+
+from google.cloud.dataproc_v1 import JobStatus
+
+from airflow import AirflowException
+from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class DataprocBaseTrigger(BaseTrigger):
+ """
+ Trigger that periodically polls the Dataproc API to check job status.
+ The implementation leverages asynchronous transport.
+ """
+
+ def __init__(
+ self,
+ job_id: str,
+ region: str,
+ project_id: Optional[str] = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ delegate_to: Optional[str] = None,
+ polling_interval_seconds: int = 30,
+ ):
+ super().__init__()
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.job_id = job_id
+ self.project_id = project_id
+ self.region = region
+ self.polling_interval_seconds = polling_interval_seconds
+ self.delegate_to = delegate_to
+ self.hook = DataprocAsyncHook(
+ delegate_to=self.delegate_to,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ def serialize(self):
+ return (
+ "airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger",
+ {
+ "job_id": self.job_id,
+ "project_id": self.project_id,
+ "region": self.region,
+ "gcp_conn_id": self.gcp_conn_id,
+ "delegate_to": self.delegate_to,
+ "impersonation_chain": self.impersonation_chain,
+ "polling_interval_seconds": self.polling_interval_seconds,
+ },
+ )
+
+ async def run(self):
+ while True:
+ job = await self.hook.get_job(project_id=self.project_id, region=self.region, job_id=self.job_id)
+ state = job.status.state
+ self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
+ if state == JobStatus.State.ERROR:
+ raise AirflowException(f"Dataproc job execution failed {self.job_id}")
+ if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED):
+ break
+ await asyncio.sleep(self.polling_interval_seconds)
+ yield TriggerEvent({"job_id": self.job_id, "job_state": state})
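``serialize`` is what lets the triggerer process rebuild this object after the worker has deferred: Airflow persists the classpath plus keyword arguments and re-instantiates the trigger from them. A simplified sketch of that round-trip (the real triggerer does this internally; ``rehydrate`` is an illustrative helper, not an Airflow API):

    from importlib import import_module

    def rehydrate(classpath: str, kwargs: dict):
        """Re-instantiate a trigger from its serialize() output."""
        module_path, class_name = classpath.rsplit(".", 1)
        trigger_cls = getattr(import_module(module_path), class_name)
        return trigger_cls(**kwargs)

    # classpath, kwargs = trigger.serialize()
    # trigger_copy = rehydrate(classpath, kwargs)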
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
index 4bcb3dc089..924ee7d24d 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
@@ -174,6 +174,14 @@ Example of the configuration for a Spark Job:
:start-after: [START how_to_cloud_dataproc_spark_config]
:end-before: [END how_to_cloud_dataproc_spark_config]
+Example of the configuration for a Spark Job running in `deferrable mode <https://airflow.apache.org/docs/apache-airflow/stable/concepts/deferring.html>`__:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_spark_deferrable.py
+ :language: python
+ :dedent: 0
+ :start-after: [START how_to_cloud_dataproc_spark_deferrable_config]
+ :end-before: [END how_to_cloud_dataproc_spark_deferrable_config]
+
Example of the configuration for a Hive Job:
.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_hive.py
diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py b/tests/providers/google/cloud/hooks/test_dataproc.py
index 2713d7f0bb..536bb017ba 100644
--- a/tests/providers/google/cloud/hooks/test_dataproc.py
+++ b/tests/providers/google/cloud/hooks/test_dataproc.py
@@ -26,7 +26,7 @@ from google.cloud.dataproc_v1 import JobStatus
from parameterized import parameterized
from airflow.exceptions import AirflowException
-from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder
+from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook, DataProcJobBuilder
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.version import version
@@ -462,6 +462,410 @@ class TestDataprocHook(unittest.TestCase):
)
+class TestDataprocAsyncHook(unittest.TestCase):
+ def setUp(self):
+ with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_init):
+ self.hook = DataprocAsyncHook(gcp_conn_id="test")
+
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @mock.patch(DATAPROC_STRING.format("ClusterControllerAsyncClient"))
+ def test_get_cluster_client(self, mock_client, mock_get_credentials):
+ self.hook.get_cluster_client(region=GCP_LOCATION)
+ mock_client.assert_called_once_with(
+ credentials=mock_get_credentials.return_value,
+ client_info=CLIENT_INFO,
+ client_options=None,
+ )
+
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @mock.patch(DATAPROC_STRING.format("ClusterControllerAsyncClient"))
+ def test_get_cluster_client_region(self, mock_client, mock_get_credentials):
+ self.hook.get_cluster_client(region='region1')
+ mock_client.assert_called_once_with(
+ credentials=mock_get_credentials.return_value,
+ client_info=CLIENT_INFO,
+ client_options=ANY,
+ )
+
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceAsyncClient"))
+ def test_get_template_client_global(self, mock_client, mock_get_credentials):
+ _ = self.hook.get_template_client()
+ mock_client.assert_called_once_with(
+ credentials=mock_get_credentials.return_value,
+ client_info=CLIENT_INFO,
+ client_options=None,
+ )
+
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceAsyncClient"))
+ def test_get_template_client_region(self, mock_client, mock_get_credentials):
+ _ = self.hook.get_template_client(region='region1')
+ mock_client.assert_called_once_with(
+ credentials=mock_get_credentials.return_value,
+ client_info=CLIENT_INFO,
+ client_options=ANY,
+ )
+
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @mock.patch(DATAPROC_STRING.format("JobControllerAsyncClient"))
+ def test_get_job_client(self, mock_client, mock_get_credentials):
+ self.hook.get_job_client(region=GCP_LOCATION)
+ mock_client.assert_called_once_with(
+ credentials=mock_get_credentials.return_value,
+ client_info=CLIENT_INFO,
+ client_options=None,
+ )
+
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @mock.patch(DATAPROC_STRING.format("JobControllerAsyncClient"))
+ def test_get_job_client_region(self, mock_client, mock_get_credentials):
+ self.hook.get_job_client(region='region1')
+ mock_client.assert_called_once_with(
+ credentials=mock_get_credentials.return_value,
+ client_info=CLIENT_INFO,
+ client_options=ANY,
+ )
+
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @mock.patch(DATAPROC_STRING.format("BatchControllerAsyncClient"))
+ def test_get_batch_client(self, mock_client, mock_get_credentials):
+ self.hook.get_batch_client(region=GCP_LOCATION)
+ mock_client.assert_called_once_with(
+ credentials=mock_get_credentials.return_value,
+ client_info=CLIENT_INFO,
+ client_options=None,
+ )
+
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @mock.patch(DATAPROC_STRING.format("BatchControllerAsyncClient"))
+ def test_get_batch_client_region(self, mock_client, mock_get_credentials):
+ self.hook.get_batch_client(region='region1')
+ mock_client.assert_called_once_with(
+ credentials=mock_get_credentials.return_value, client_info=CLIENT_INFO, client_options=ANY
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_cluster_client"))
+ async def test_create_cluster(self, mock_client):
+ await self.hook.create_cluster(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ cluster_config=CLUSTER_CONFIG,
+ labels=LABELS,
+ )
+ mock_client.assert_called_once_with(region=GCP_LOCATION)
+ mock_client.return_value.create_cluster.assert_called_once_with(
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster=CLUSTER,
+ request_id=None,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_cluster_client"))
+ async def test_delete_cluster(self, mock_client):
+ await self.hook.delete_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
+ mock_client.assert_called_once_with(region=GCP_LOCATION)
+ mock_client.return_value.delete_cluster.assert_called_once_with(
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ cluster_uuid=None,
+ request_id=None,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_cluster_client"))
+ async def test_diagnose_cluster(self, mock_client):
+ await self.hook.diagnose_cluster(
+ project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME
+ )
+ mock_client.assert_called_once_with(region=GCP_LOCATION)
+ mock_client.return_value.diagnose_cluster.assert_called_once_with(
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.diagnose_cluster.return_value.result.assert_called_once_with()
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_cluster_client"))
+ async def test_get_cluster(self, mock_client):
+ await self.hook.get_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
+ mock_client.assert_called_once_with(region=GCP_LOCATION)
+ mock_client.return_value.get_cluster.assert_called_once_with(
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_cluster_client"))
+ async def test_list_clusters(self, mock_client):
+ filter_ = "filter"
+
+ await self.hook.list_clusters(project_id=GCP_PROJECT, region=GCP_LOCATION, filter_=filter_)
+ mock_client.assert_called_once_with(region=GCP_LOCATION)
+ mock_client.return_value.list_clusters.assert_called_once_with(
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ filter=filter_,
+ page_size=None,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_cluster_client"))
+ async def test_update_cluster(self, mock_client):
+ update_mask = "update-mask"
+ await self.hook.update_cluster(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster=CLUSTER,
+ cluster_name=CLUSTER_NAME,
+ update_mask=update_mask,
+ )
+ mock_client.assert_called_once_with(region=GCP_LOCATION)
+ mock_client.return_value.update_cluster.assert_called_once_with(
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster=CLUSTER,
+ cluster_name=CLUSTER_NAME,
+ update_mask=update_mask,
+ graceful_decommission_timeout=None,
+ request_id=None,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_cluster_client"))
+ def test_update_cluster_missing_region(self, mock_client):
+ with pytest.raises(TypeError):
+ self.hook.update_cluster(
+ project_id=GCP_PROJECT,
+ cluster=CLUSTER,
+ cluster_name=CLUSTER_NAME,
+ update_mask="update-mask",
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_template_client"))
+ async def test_create_workflow_template(self, mock_client):
+ template = {"test": "test"}
+ parent = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}'
+ await self.hook.create_workflow_template(
+ region=GCP_LOCATION, template=template, project_id=GCP_PROJECT
+ )
+ mock_client.return_value.create_workflow_template.assert_called_once_with(
+ request=dict(parent=parent, template=template), retry=DEFAULT, timeout=None, metadata=()
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_template_client"))
+ async def test_instantiate_workflow_template(self, mock_client):
+ template_name = "template_name"
+ name = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}/workflowTemplates/{template_name}'
+ await self.hook.instantiate_workflow_template(
+ region=GCP_LOCATION, template_name=template_name, project_id=GCP_PROJECT
+ )
+ mock_client.return_value.instantiate_workflow_template.assert_called_once_with(
+ request=dict(name=name, version=None, parameters=None, request_id=None),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_template_client"))
+ def test_instantiate_workflow_template_missing_region(self, mock_client):
+ with pytest.raises(TypeError):
+ self.hook.instantiate_workflow_template(template_name="template_name", project_id=GCP_PROJECT)
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_template_client"))
+ async def test_instantiate_inline_workflow_template(self, mock_client):
+ template = {"test": "test"}
+ parent = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}'
+ await self.hook.instantiate_inline_workflow_template(
+ region=GCP_LOCATION, template=template, project_id=GCP_PROJECT
+ )
+ mock_client.return_value.instantiate_inline_workflow_template.assert_called_once_with(
+ request=dict(parent=parent, template=template, request_id=None),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_template_client"))
+ def test_instantiate_inline_workflow_template_missing_region(self, mock_client):
+ with pytest.raises(TypeError):
+ self.hook.instantiate_inline_workflow_template(template={"test": "test"}, project_id=GCP_PROJECT)
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_job_client"))
+ async def test_get_job(self, mock_client):
+ await self.hook.get_job(region=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
+ mock_client.assert_called_once_with(region=GCP_LOCATION)
+ mock_client.return_value.get_job.assert_called_once_with(
+ request=dict(
+ region=GCP_LOCATION,
+ job_id=JOB_ID,
+ project_id=GCP_PROJECT,
+ ),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_job_client"))
+ def test_get_job_missing_region(self, mock_client):
+ with pytest.raises(TypeError):
+ self.hook.get_job(job_id=JOB_ID, project_id=GCP_PROJECT)
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_job_client"))
+ async def test_submit_job(self, mock_client):
+ await self.hook.submit_job(region=GCP_LOCATION, job=JOB, project_id=GCP_PROJECT)
+ mock_client.assert_called_once_with(region=GCP_LOCATION)
+ mock_client.return_value.submit_job.assert_called_once_with(
+ request=dict(
+ region=GCP_LOCATION,
+ job=JOB,
+ project_id=GCP_PROJECT,
+ request_id=None,
+ ),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_job_client"))
+ def test_submit_job_missing_region(self, mock_client):
+ with pytest.raises(TypeError):
+ self.hook.submit_job(job=JOB, project_id=GCP_PROJECT)
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_job_client"))
+ async def test_cancel_job(self, mock_client):
+ await self.hook.cancel_job(region=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
+ mock_client.assert_called_once_with(region=GCP_LOCATION)
+ mock_client.return_value.cancel_job.assert_called_once_with(
+ request=dict(
+ region=GCP_LOCATION,
+ job_id=JOB_ID,
+ project_id=GCP_PROJECT,
+ ),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_batch_client"))
+ async def test_create_batch(self, mock_client):
+ await self.hook.create_batch(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ batch=BATCH,
+ batch_id=BATCH_ID,
+ )
+ mock_client.assert_called_once_with(GCP_LOCATION)
+ mock_client.return_value.create_batch.assert_called_once_with(
+ request=dict(
+ parent=PARENT.format(GCP_PROJECT, GCP_LOCATION),
+ batch=BATCH,
+ batch_id=BATCH_ID,
+ request_id=None,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_batch_client"))
+ async def test_delete_batch(self, mock_client):
+ await self.hook.delete_batch(
+ batch_id=BATCH_ID,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ )
+ mock_client.assert_called_once_with(GCP_LOCATION)
+ mock_client.return_value.delete_batch.assert_called_once_with(
+ request=dict(
+ name=BATCH_NAME.format(GCP_PROJECT, GCP_LOCATION, BATCH_ID),
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_batch_client"))
+ async def test_get_batch(self, mock_client):
+ await self.hook.get_batch(
+ batch_id=BATCH_ID,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ )
+ mock_client.assert_called_once_with(GCP_LOCATION)
+ mock_client.return_value.get_batch.assert_called_once_with(
+ request=dict(
+ name=BATCH_NAME.format(GCP_PROJECT, GCP_LOCATION, BATCH_ID),
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_batch_client"))
+ async def test_list_batches(self, mock_client):
+ await self.hook.list_batches(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ )
+ mock_client.assert_called_once_with(GCP_LOCATION)
+ mock_client.return_value.list_batches.assert_called_once_with(
+ request=dict(
+ parent=PARENT.format(GCP_PROJECT, GCP_LOCATION),
+ page_size=None,
+ page_token=None,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
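+
+ # Usage sketch (assumed, not exercised by these tests): a deferrable
+ # trigger would poll job state by awaiting the async hook directly, e.g.
+ #
+ #     hook = DataprocAsyncHook(gcp_conn_id="google_cloud_default")
+ #     job = await hook.get_job(region=region, job_id=job_id, project_id=project_id)
+ #     state = job.status.state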
+
+
class TestDataProcJobBuilder(unittest.TestCase):
def setUp(self) -> None:
self.job_type = "test"
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index a7101ca558..ca3173b9e2 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -26,7 +26,7 @@ from google.api_core.retry import Retry
from google.cloud.dataproc_v1 import Batch
from airflow import AirflowException
-from airflow.exceptions import AirflowTaskTimeout
+from airflow.exceptions import AirflowTaskTimeout, TaskDeferred
from airflow.models import DAG, DagBag
from airflow.providers.google.cloud.operators.dataproc import (
DATAPROC_CLUSTER_LINK,
@@ -52,6 +52,8 @@ from airflow.providers.google.cloud.operators.dataproc import (
DataprocSubmitSparkSqlJobOperator,
DataprocUpdateClusterOperator,
)
+from airflow.providers.google.cloud.triggers.dataproc import DataprocBaseTrigger
+from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils.timezone import datetime
from airflow.version import version as airflow_version
@@ -62,6 +64,7 @@ cluster_params = inspect.signature(ClusterGenerator.__init__).parameters
AIRFLOW_VERSION = "v" + airflow_version.replace(".", "-").replace("+", "-")
DATAPROC_PATH = "airflow.providers.google.cloud.operators.dataproc.{}"
+DATAPROC_TRIGGERS_PATH = "airflow.providers.google.cloud.triggers.dataproc.{}"
TASK_ID = "task-id"
GCP_PROJECT = "test-project"
@@ -904,6 +907,49 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
)
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+ def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
+ job = {}
+ mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID
+
+ op = DataprocSubmitJobOperator(
+ task_id=TASK_ID,
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ job=job,
+ gcp_conn_id=GCP_CONN_ID,
+ retry=RETRY,
+ asynchronous=True,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ request_id=REQUEST_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ deferrable=True,
+ )
+ with pytest.raises(TaskDeferred) as exc:
+ op.execute(mock.MagicMock())
+
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ mock_hook.return_value.submit_job.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ job=job,
+ request_id=REQUEST_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ mock_hook.return_value.wait_for_job.assert_not_called()
+
+ self.mock_ti.xcom_push.assert_not_called()
+
+ assert isinstance(exc.value.trigger, DataprocBaseTrigger)
+ assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
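+
+ # Resume path (sketch, not asserted here): when the trigger fires, the
+ # operator is re-entered via execute_complete(), which should raise on a
+ # terminal ERROR/CANCELLED job state and otherwise return the job id.
+ # The event keys below are assumptions, not checked by this test:
+ #
+ #     def execute_complete(self, context, event):
+ #         if event["job_state"] == JobStatus.State.ERROR:
+ #             raise AirflowException(f"Job failed: {event['job_id']}")
+ #         return event["job_id"]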
+
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_on_kill(self, mock_hook):
job = {}
diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_deferrable.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_deferrable.py
new file mode 100644
index 0000000000..65c8621b05
--- /dev/null
+++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_spark_deferrable.py
@@ -0,0 +1,111 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for DataprocSubmitJobOperator running a Spark job
+in deferrable mode.
+"""
+
+import os
+from datetime import datetime
+
+from airflow import models
+from airflow.providers.google.cloud.operators.dataproc import (
+ DataprocCreateClusterOperator,
+ DataprocDeleteClusterOperator,
+ DataprocSubmitJobOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "dataproc_spark"
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "")
+
+CLUSTER_NAME = f"cluster-dataproc-spark-{ENV_ID}"
+REGION = "europe-west1"
+ZONE = "europe-west1-b"
+
+
+# Cluster definition
+CLUSTER_CONFIG = {
+ "master_config": {
+ "num_instances": 1,
+ "machine_type_uri": "n1-standard-4",
+ "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 1024},
+ },
+ "worker_config": {
+ "num_instances": 2,
+ "machine_type_uri": "n1-standard-4",
+ "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 1024},
+ },
+}
+
+TIMEOUT = {"seconds": 1 * 24 * 60 * 60}
+
+# Jobs definitions
+# [START how_to_cloud_dataproc_spark_deferrable_config]
+SPARK_JOB = {
+ "reference": {"project_id": PROJECT_ID},
+ "placement": {"cluster_name": CLUSTER_NAME},
+ "spark_job": {
+ "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
+ "main_class": "org.apache.spark.examples.SparkPi",
+ },
+}
+# [END how_to_cloud_dataproc_spark_deferrable_config]
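+
+# Note: deferrable operators release their worker slot while the job runs;
+# the wait is handled by the triggerer process, which must be running
+# alongside the scheduler for the deferred task to resume.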
+
+
+with models.DAG(
+ DAG_ID,
+ schedule_interval="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "dataproc"],
+) as dag:
+ create_cluster = DataprocCreateClusterOperator(
+ task_id="create_cluster",
+ project_id=PROJECT_ID,
+ cluster_config=CLUSTER_CONFIG,
+ region=REGION,
+ cluster_name=CLUSTER_NAME,
+ )
+
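+ # Same job and cluster as the non-deferrable Spark example; the only
+ # functional change is deferrable=True, which submits the job and then
+ # defers instead of polling from the worker.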
+ spark_task = DataprocSubmitJobOperator(
+ task_id="spark_task", job=SPARK_JOB, region=REGION, project_id=PROJECT_ID, deferrable=True
+ )
+
+ delete_cluster = DataprocDeleteClusterOperator(
+ task_id="delete_cluster",
+ project_id=PROJECT_ID,
+ cluster_name=CLUSTER_NAME,
+ region=REGION,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ create_cluster >> spark_task >> delete_cluster
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs the watcher in order to properly mark success/failure
+ # when the "teardown" task with trigger rule is part of the DAG.
+ list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)