Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2022/02/17 16:49:00 UTC

[GitHub] [airflow] josh-fell commented on a change in pull request #21619: adding dataproc cluster start and stop operators

josh-fell commented on a change in pull request #21619:
URL: https://github.com/apache/airflow/pull/21619#discussion_r809255629



##########
File path: airflow/providers/google/cloud/operators/dataproc.py
##########
@@ -659,6 +680,268 @@ def execute(self, context: 'Context') -> dict:
         return Cluster.to_dict(cluster)
 
 
+class DataprocStartClusterOperator(BaseOperator):
+    """
+    Starts a cluster in a project.
+
+    :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to (templated).
+    :param region: Required. The Cloud Dataproc region in which to handle the request (templated).
+    :param cluster_name: Required. The cluster name (templated).
+    :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail
+        if a cluster with the specified UUID does not exist.
+    :param request_id: Optional. A unique id used to identify the request. If the server receives two
+        ``StartClusterRequest`` requests with the same id, then the second request will be ignored and the
+        first ``google.longrunning.Operation`` created and stored in the backend is returned.
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :param metadata: Additional metadata that is provided to the method.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields: Sequence[str] = (
+        'project_id', 'region', 'cluster_name', 'impersonation_chain')
+
+    operator_extra_links = (DataprocClusterLink(),)
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        cluster_name: str,
+        cluster_uuid: Optional[str] = None,
+        request_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.cluster_name = cluster_name
+        self.cluster_uuid = cluster_uuid
+        self.request_id = request_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def _start_cluster(self, hook: DataprocHook):
+        hook.start_cluster(
+            project_id=self.project_id,
+            region=self.region,
+            cluster_name=self.cluster_name,
+            cluster_uuid=self.cluster_uuid,
+            request_id=self.request_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+    def _get_cluster(self, hook: DataprocHook) -> Cluster:
+        return hook.get_cluster(
+            project_id=self.project_id,
+            region=self.region,
+            cluster_name=self.cluster_name,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+    def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
+        if cluster.status.state != cluster.status.State.ERROR:
+            return
+        self.log.info("Cluster is in ERROR state")
+        gcs_uri = hook.diagnose_cluster(
+            region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
+        )
+        self.log.info(
+            'Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri)
+        raise AirflowException("Cluster was started but is in ERROR state")
+
+    def _wait_for_cluster_in_starting_state(self, hook: DataprocHook) -> Cluster:
+        time_left = self.timeout
+        cluster = self._get_cluster(hook)
+        for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
+            if cluster.status.state != cluster.status.State.STARTING:
+                break
+            if time_left < 0:
+                raise AirflowException(
+                    f"Cluster {self.cluster_name} is still CREATING state, aborting")
+            time.sleep(time_to_sleep)
+            time_left = time_left - time_to_sleep
+            cluster = self._get_cluster(hook)
+        return cluster
+
+    def execute(self, context: 'Context') -> None:
+        self.log.info('Starting cluster: %s', self.cluster_name)
+        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,

Review comment:
       You could simplify the implementation by making `hook` an instance attribute of the operator. With the instance attribute you wouldn't have to pass the object around, and helper methods like `_get_cluster()` and `_start_cluster()` could be replaced by calling `self.hook.get_cluster()` and `self.hook.start_cluster()` directly.
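       Rough sketch of what I mean (assuming `functools.cached_property`, or Airflow's compat import if older Pythons still need support; everything else is taken from the diff above):

```python
from functools import cached_property

from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
from google.cloud.dataproc_v1 import Cluster


class DataprocStartClusterOperator(BaseOperator):
    # __init__ stays exactly as in the diff above.

    @cached_property
    def hook(self) -> DataprocHook:
        # Built lazily on first access and cached for the rest of the task run,
        # so templating/serialization never trigger a connection lookup.
        return DataprocHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

    def _get_cluster(self) -> Cluster:
        # No hook argument to thread through the helpers any more.
        return self.hook.get_cluster(
            project_id=self.project_id,
            region=self.region,
            cluster_name=self.cluster_name,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
```

       `execute()` and `_handle_error_state()` would then call `self.hook.diagnose_cluster()` etc. directly instead of receiving the hook as a parameter.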

##########
File path: airflow/providers/google/cloud/operators/dataproc.py
##########
@@ -698,7 +981,8 @@ class DataprocScaleClusterOperator(BaseOperator):
         account from the list granting this role to the originating account (templated).
     """
 
-    template_fields: Sequence[str] = ('cluster_name', 'project_id', 'region', 'impersonation_chain')
+    template_fields: Sequence[str] = (
+        'cluster_name', 'project_id', 'region', 'impersonation_chain')

Review comment:
       These formatting changes (here and elsewhere) look unrelated, and the static checks in CI probably won't like them. You can run the static checks locally though; check out [this guide](https://github.com/apache/airflow/blob/main/STATIC_CODE_CHECKS.rst) for some pointers.
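       If you have `pre-commit` installed (the guide covers setup), something like `pre-commit run --files airflow/providers/google/cloud/operators/dataproc.py` should surface the same complaints locally before CI does.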

##########
File path: airflow/providers/google/cloud/operators/dataproc.py
##########
@@ -659,6 +680,268 @@ def execute(self, context: 'Context') -> dict:
         return Cluster.to_dict(cluster)
 
 
+class DataprocStartClusterOperator(BaseOperator):
+    """
+    Starts a cluster in a project.
+
+    :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to (templated).
+    :param region: Required. The Cloud Dataproc region in which to handle the request (templated).
+    :param cluster_name: Required. The cluster name (templated).
+    :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail
+        if a cluster with the specified UUID does not exist.
+    :param request_id: Optional. A unique id used to identify the request. If the server receives two
+        ``StartClusterRequest`` requests with the same id, then the second request will be ignored and the
+        first ``google.longrunning.Operation`` created and stored in the backend is returned.
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :param metadata: Additional metadata that is provided to the method.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields: Sequence[str] = (
+        'project_id', 'region', 'cluster_name', 'impersonation_chain')
+
+    operator_extra_links = (DataprocClusterLink(),)
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        cluster_name: str,
+        cluster_uuid: Optional[str] = None,
+        request_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.cluster_name = cluster_name
+        self.cluster_uuid = cluster_uuid
+        self.request_id = request_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def _start_cluster(self, hook: DataprocHook):
+        hook.start_cluster(
+            project_id=self.project_id,
+            region=self.region,
+            cluster_name=self.cluster_name,
+            cluster_uuid=self.cluster_uuid,
+            request_id=self.request_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+    def _get_cluster(self, hook: DataprocHook) -> Cluster:
+        return hook.get_cluster(
+            project_id=self.project_id,
+            region=self.region,
+            cluster_name=self.cluster_name,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+    def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
+        if cluster.status.state != cluster.status.State.ERROR:
+            return
+        self.log.info("Cluster is in ERROR state")
+        gcs_uri = hook.diagnose_cluster(
+            region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
+        )
+        self.log.info(
+            'Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri)
+        raise AirflowException("Cluster was started but is in ERROR state")
+
+    def _wait_for_cluster_in_starting_state(self, hook: DataprocHook) -> Cluster:
+        time_left = self.timeout
+        cluster = self._get_cluster(hook)
+        for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
+            if cluster.status.state != cluster.status.State.STARTING:
+                break
+            if time_left < 0:
+                raise AirflowException(
+                    f"Cluster {self.cluster_name} is still CREATING state, aborting")
+            time.sleep(time_to_sleep)
+            time_left = time_left - time_to_sleep
+            cluster = self._get_cluster(hook)
+        return cluster
+
+    def execute(self, context: 'Context') -> None:
+        self.log.info('Starting cluster: %s', self.cluster_name)
+        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
+                            impersonation_chain=self.impersonation_chain)
+        # Save the data required to display the extra link, regardless of the eventual cluster status
+        self.xcom_push(
+            context,
+            key="cluster_conf",
+            value={
+                "cluster_name": self.cluster_name,
+                "region": self.region,
+                "project_id": self.project_id,
+            },
+        )
+        self._start_cluster(hook)
+        cluster = self._get_cluster(hook)
+        self._handle_error_state(hook, cluster)
+        if cluster.status.state == cluster.status.State.STARTING:
+            # Wait for cluster to be running
+            cluster = self._wait_for_cluster_in_starting_state(hook)
+            self._handle_error_state(hook, cluster)
+
+
+class DataprocStopClusterOperator(BaseOperator):
+    """
+    Stops a cluster in a project.
+
+    :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to (templated).
+    :param region: Required. The Cloud Dataproc region in which to handle the request (templated).
+    :param cluster_name: Required. The cluster name (templated).
+    :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail
+        if a cluster with the specified UUID does not exist.
+    :param request_id: Optional. A unique id used to identify the request. If the server receives two
+        ``StopClusterRequest`` requests with the same id, then the second request will be ignored and the
+        first ``google.longrunning.Operation`` created and stored in the backend is returned.
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :param metadata: Additional metadata that is provided to the method.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields: Sequence[str] = (
+        'project_id', 'region', 'cluster_name', 'impersonation_chain')
+
+    operator_extra_links = (DataprocClusterLink(),)
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        cluster_name: str,
+        cluster_uuid: Optional[str] = None,
+        request_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.cluster_name = cluster_name
+        self.cluster_uuid = cluster_uuid
+        self.request_id = request_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def _stop_cluster(self, hook: DataprocHook):
+        hook.stop_cluster(
+            project_id=self.project_id,
+            region=self.region,
+            cluster_name=self.cluster_name,
+            cluster_uuid=self.cluster_uuid,
+            request_id=self.request_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+    def _get_cluster(self, hook: DataprocHook) -> Cluster:
+        return hook.get_cluster(
+            project_id=self.project_id,
+            region=self.region,
+            cluster_name=self.cluster_name,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+    def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
+        if cluster.status.state != cluster.status.State.ERROR:
+            return
+        self.log.info("Cluster is in ERROR state")
+        gcs_uri = hook.diagnose_cluster(
+            region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
+        )
+        self.log.info(
+            'Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri)
+        raise AirflowException("Cluster was stopped but is in ERROR state")
+
+    def _wait_for_cluster_in_stopping_state(self, hook: DataprocHook) -> Cluster:
+        time_left = self.timeout
+        cluster = self._get_cluster(hook)
+        for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
+            if cluster.status.state != cluster.status.State.STOPPING:
+                break
+            if time_left < 0:
+                raise AirflowException(
+                    f"Cluster {self.cluster_name} is still STOPPING state, aborting")
+            time.sleep(time_to_sleep)
+            time_left = time_left - time_to_sleep
+            cluster = self._get_cluster(hook)
+        return cluster
+
+    def execute(self, context: 'Context') -> None:
+        self.log.info('Stopping cluster: %s', self.cluster_name)
+        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,

Review comment:
       Same comment here re: using an instance attribute for the hook object.
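       If the two operators end up sharing all of this plumbing, it might also be worth pulling it into a small common parent. Very rough sketch; `_DataprocStartStopClusterBase` is just a made-up name:

```python
from functools import cached_property

from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
from google.cloud.dataproc_v1 import Cluster


class _DataprocStartStopClusterBase(BaseOperator):
    """Hypothetical shared base for the start/stop operators."""

    @cached_property
    def hook(self) -> DataprocHook:
        # One hook per task run, shared by every helper method below.
        return DataprocHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

    def _get_cluster(self) -> Cluster:
        # Identical in both operators today, so it lives here once.
        return self.hook.get_cluster(
            project_id=self.project_id,
            region=self.region,
            cluster_name=self.cluster_name,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )
```

       Each subclass would then only differ in which hook method it calls (`start_cluster` vs `stop_cluster`) and which terminal state it polls for.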




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org