Posted to commits@airflow.apache.org by "ahidalgob (via GitHub)" <gi...@apache.org> on 2023/08/07 19:41:34 UTC

[GitHub] [airflow] ahidalgob commented on a diff in pull request #32606: Add `CloudBatchHook` and operators

ahidalgob commented on code in PR #32606:
URL: https://github.com/apache/airflow/pull/32606#discussion_r1286265897


##########
tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py:
##########
@@ -0,0 +1,202 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that uses Google Cloud Batch Operators.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from google.cloud import batch_v1
+
+from airflow import models
+from airflow.operators.python import PythonOperator
+from airflow.providers.google.cloud.operators.cloud_batch import (
+    CloudBatchDeleteJobOperator,
+    CloudBatchListJobsOperator,
+    CloudBatchListTasksOperator,
+    CloudBatchSubmitJobOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+DAG_ID = "example_cloud_batch"
+
+region = "us-central1"
+job_name_prefix = "batch-system-test-job"
+job1_name = f"{job_name_prefix}1"
+job2_name = f"{job_name_prefix}2"
+
+submit1_task_name = "submit-job1"
+submit2_task_name = "submit-job2"
+
+delete1_task_name = "delete-job1"
+delete2_task_name = "delete-job2"
+
+list_jobs_task_name = "list-jobs"
+list_tasks_task_name = "list-tasks"
+
+clean1_task_name = "clean-job1"
+clean2_task_name = "clean-job2"
+
+
+def _assert_jobs(ti):
+    job_names = ti.xcom_pull(task_ids=[list_jobs_task_name], key="return_value")
+    job_names_str = job_names[0][0]["name"].split("/")[-1] + " " + job_names[0][1]["name"].split("/")[-1]
+    assert job1_name in job_names_str
+    assert job2_name in job_names_str
+
+
+def _assert_tasks(ti):
+    tasks_names = ti.xcom_pull(task_ids=[list_tasks_task_name], key="return_value")
+    assert len(tasks_names[0]) == 2
+    assert "tasks/0" in tasks_names[0][0]["name"]
+    assert "tasks/1" in tasks_names[0][1]["name"]
+
+
+# [START howto_operator_batch_job_creation]
+def _create_job():
+    runnable = batch_v1.Runnable()
+    runnable.container = batch_v1.Runnable.Container()
+    runnable.container.image_uri = "gcr.io/google-containers/busybox"
+    runnable.container.entrypoint = "/bin/sh"
+    runnable.container.commands = [
+        "-c",
+        "echo Hello world! This is task ${BATCH_TASK_INDEX}.\
+          This job has a total of ${BATCH_TASK_COUNT} tasks.",
+    ]
+
+    task = batch_v1.TaskSpec()
+    task.runnables = [runnable]
+
+    resources = batch_v1.ComputeResource()
+    resources.cpu_milli = 2000
+    resources.memory_mib = 16
+    task.compute_resource = resources
+    task.max_retry_count = 2
+
+    group = batch_v1.TaskGroup()
+    group.task_count = 2
+    group.task_spec = task
+    policy = batch_v1.AllocationPolicy.InstancePolicy()
+    policy.machine_type = "e2-standard-4"
+    instances = batch_v1.AllocationPolicy.InstancePolicyOrTemplate()
+    instances.policy = policy
+    allocation_policy = batch_v1.AllocationPolicy()
+    allocation_policy.instances = [instances]
+
+    job = batch_v1.Job()
+    job.task_groups = [group]
+    job.allocation_policy = allocation_policy
+    job.labels = {"env": "testing", "type": "container"}
+
+    job.logs_policy = batch_v1.LogsPolicy()
+    job.logs_policy.destination = batch_v1.LogsPolicy.Destination.CLOUD_LOGGING
+
+    return job
+
+
+# [END howto_operator_batch_job_creation]
+
+
+with models.DAG(
+    DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    tags=["example"],
+) as dag:
+
+    # [START howto_operator_batch_submit_job]
+    submit1 = CloudBatchSubmitJobOperator(
+        task_id=submit1_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job1_name,
+        job=_create_job(),
+        dag=dag,
+        deferrable=False,
+    )
+    # [END howto_operator_batch_submit_job]
+
+    # [START howto_operator_batch_submit_job_deferrable_mode]
+    submit2 = CloudBatchSubmitJobOperator(
+        task_id=submit2_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job2_name,
+        job=batch_v1.Job.to_dict(_create_job()),
+        dag=dag,
+        deferrable=True,
+    )
+    # [END howto_operator_batch_submit_job_deferrable_mode]
+
+    # [START howto_operator_batch_list_tasks]
+    list_tasks = CloudBatchListTasksOperator(
+        task_id=list_tasks_task_name, project_id=PROJECT_ID, region=region, job_name=job1_name, dag=dag
+    )
+    # [END howto_operator_batch_list_tasks]
+
+    assert_tasks = PythonOperator(task_id="assert-tasks", python_callable=_assert_tasks, dag=dag)
+
+    # [START howto_operator_batch_list_jobs]
+    list_jobs = CloudBatchListJobsOperator(
+        task_id=list_jobs_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        limit=2,
+        filter=f"name:projects/{PROJECT_ID}/locations/{region}/jobs/{job_name_prefix}*",
+        dag=dag,
+    )
+    # [END howto_operator_batch_list_jobs]
+
+    get_name = PythonOperator(task_id="assert-jobs", python_callable=_assert_jobs, dag=dag)
+
+    # [START howto_operator_delete_job]
+    delete_job1 = CloudBatchDeleteJobOperator(
+        task_id="delete-job1",
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job1_name,
+        dag=dag,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+    # [END howto_operator_delete_job]
+
+    delete_job2 = CloudBatchDeleteJobOperator(
+        task_id="delete-job2",
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job2_name,
+        dag=dag,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    ((submit1, submit2) >> list_tasks >> assert_tasks >> list_jobs >> get_name >> (delete_job1, delete_job2))

Review Comment:
   Most code uses [ ] instead of ( )
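
   For example, a sketch of the same dependencies expressed with lists:

       [submit1, submit2] >> list_tasks >> assert_tasks >> list_jobs >> get_name >> [delete_job1, delete_job2]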



##########
airflow/config_templates/default_airflow.cfg:
##########
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-

Review Comment:
   :smiley: 



##########
airflow/providers/google/cloud/hooks/cloud_batch.py:
##########
@@ -0,0 +1,215 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import itertools
+import json
+from time import sleep
+from typing import Iterable, Sequence
+
+from google.api_core import operation  # type: ignore
+from google.cloud.batch import ListJobsRequest, ListTasksRequest
+from google.cloud.batch_v1 import (
+    BatchServiceAsyncClient,
+    BatchServiceClient,
+    CreateJobRequest,
+    Job,
+    JobStatus,
+    Task,
+)
+from google.cloud.batch_v1.services.batch_service import pagers
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+
+
+class CloudBatchHook(GoogleBaseHook):
+    """
+    Hook for the Google Cloud Batch service.
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account.
+    """
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        if kwargs.get("delegate_to") is not None:
+            raise RuntimeError(
+                "The `delegate_to` parameter has been deprecated before and finally removed in this version"
+                " of Google Provider. You MUST convert it to `impersonate_chain`"
+            )
+        super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
+        self._client: BatchServiceClient | None = None
+
+    def get_conn(self) -> BatchServiceClient:
+        """
+        Retrieves the connection to Google Cloud Batch.
+
+        :return: A BatchServiceClient instance.
+        """
+        if self._client is None:
+            self._client = BatchServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+        return self._client
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def submit_build_job(

Review Comment:
   Why is it "build" job?



##########
airflow/providers/google/cloud/operators/cloud_batch.py:
##########
@@ -0,0 +1,298 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Sequence
+
+from google.api_core import operation  # type: ignore
+from google.cloud.batch_v1 import Job, Task
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.cloud_batch import CloudBatchHook
+from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+from airflow.providers.google.cloud.triggers.cloud_batch import CloudBatchJobFinishedTrigger
+from airflow.utils.context import Context
+
+
+class CloudBatchSubmitJobOperator(GoogleCloudBaseOperator):
+    """
+    Submit a job and wait for its completion.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param job_name: Required. The name of the job to create.
+    :param job: Required. The job descriptor containing the configuration of the job to submit.
+    :param polling_period_seconds: Optional. Controls the rate of polling for the result of a
+        deferrable run. By default, the trigger will poll every 10 seconds.
+    :param timeout_seconds: The timeout for this request.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param deferrable: Run the operator in deferrable mode.
+
+    """
+
+    template_fields = ("project_id", "region", "gcp_conn_id", "impersonation_chain", "job_name")
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        job_name: str,
+        job: dict | Job,
+        polling_period_seconds: float = 10,
+        timeout_seconds: float | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.job_name = job_name
+        self.job = job
+        self.polling_period_seconds = polling_period_seconds
+        self.timeout_seconds = timeout_seconds
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+
+    def execute(self, context):
+        hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+        job = hook.submit_build_job(job_name=self.job_name, job=self.job, region=self.region)
+
+        if not self.deferrable:
+            completed_job = hook.wait_for_job(
+                job_name=job.name,
+                polling_period_seconds=self.polling_period_seconds,
+                timeout=self.timeout_seconds,
+            )
+
+            return Job.to_dict(completed_job)
+
+        else:
+            self.defer(
+                trigger=CloudBatchJobFinishedTrigger(
+                    job_name=job.name,
+                    project_id=self.project_id,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    location=self.region,
+                    polling_period_seconds=self.polling_period_seconds,
+                    timeout=self.timeout_seconds,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context: Context, event: dict):
+        job_status = event["status"]
+        if job_status == "success":
+            hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+            job = hook.get_job(job_name=event["job_name"])
+            return Job.to_dict(job)
+        else:
+            raise AirflowException(f"Unexpected error in the operation: {event['message']}")
+
+
+class CloudBatchDeleteJobOperator(GoogleCloudBaseOperator):
+    """
+    Deletes a job and waits for the operation to complete.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param job_name: Required. The name of the job to be deleted.
+    :param timeout: The timeout for this request.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+
+    """
+
+    template_fields = ("project_id", "region", "gcp_conn_id", "impersonation_chain", "job_name")
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        job_name: str,
+        timeout: float | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.job_name = job_name
+        self.timeout = timeout
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+
+        operation = hook.delete_job(job_name=self.job_name, region=self.region, project_id=self.project_id)
+
+        self._wait_for_operation(operation)
+
+    def _wait_for_operation(self, operation: operation.Operation):
+        try:
+            return operation.result(timeout=self.timeout)
+        except Exception:
+            error = operation.exception(timeout=self.timeout)
+            raise AirflowException(error)
+
+
+class CloudBatchListJobsOperator(GoogleCloudBaseOperator):
+    """
+    List Cloud Batch jobs.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param filter: The filter based on which to list the jobs. If left empty, all the jobs are listed.
+    :param limit: The number of jobs to list. If left empty,
+        all the jobs matching the filter will be returned.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+
+    """
+
+    template_fields = (
+        "project_id",
+        "region",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        gcp_conn_id: str = "google_cloud_default",
+        filter: str | None = None,
+        limit: int | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.filter = filter
+        self.limit = limit
+        if limit is not None and limit < 0:
+            raise AirflowException("The limit for the list jobs request should be greater or equal to zero")
+
+    def execute(self, context):
+        hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+
+        jobs_list = hook.list_jobs(
+            region=self.region, project_id=self.project_id, filter=self.filter, limit=self.limit
+        )
+
+        return [Job.to_dict(job) for job in jobs_list]
+
+
+class CloudBatchListTasksOperator(GoogleCloudBaseOperator):
+    """
+    List Cloud Batch tasks for a given job.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param job_name: Required. The name of the job for which to list tasks.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param filter: The filter based on which to list the tasks. If left empty, all the tasks are listed.
+    :param group_name: The name of the group that owns the task. By default it's `group0`.
+    :param limit: The number of tasks to list.
+        If left empty, all the tasks matching the filter will be returned.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+
+    """
+
+    template_fields = ("project_id", "region", "job_name", "gcp_conn_id", "impersonation_chain", "group_name")
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        job_name: str,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        group_name: str = "group0",
+        filter: str | None = None,
+        limit: int | None = None,
+        **kwargs,
+    ) -> None:
+
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.job_name = job_name
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.group_name = group_name
+        self.filter = filter
+        self.limit = limit
+        if limit is not None and limit < 0:
+            raise AirflowException("The limit for the list jobs request should be greater or equal to zero")
+
+    def execute(self, context):

Review Comment:
   Missing type hints
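
   For example, a sketch of the annotated signature (`Context` is already imported in this module;
   the return type assumes the method returns serialized task dicts, mirroring the jobs operator):

       def execute(self, context: Context) -> list[dict]: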



##########
docs/apache-airflow-providers-google/operators/cloud/cloud_batch.rst:
##########


Review Comment:
   In Google operators doc we include `/operators/_partials/prerequisite_tasks.rst` as prerequisites
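
   i.e. something like adding `.. include:: /operators/_partials/prerequisite_tasks.rst` under a
   "Prerequisite Tasks" heading near the top of the page.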



##########
airflow/providers/google/cloud/triggers/cloud_batch.py:
##########
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from typing import Any, AsyncIterator, Sequence
+
+from google.cloud.batch_v1 import Job, JobStatus
+
+from airflow.providers.google.cloud.hooks.cloud_batch import CloudBatchAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+DEFAULT_BATCH_LOCATION = "us-central1"
+
+
+class CloudBatchJobFinishedTrigger(BaseTrigger):
+    """Cloud Batch trigger to check if templated job has been finished.
+
+    :param job_name: Required. Name of the job.
+    :param project_id: Required. The Google Cloud project ID in which the job was started.
+    :param location: Optional. The location where the job is executed. If set to None,
+        the value of DEFAULT_BATCH_LOCATION will be used.
+    :param gcp_conn_id: The connection ID to use when connecting to Google Cloud.
+    :param impersonation_chain: Optional. Service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param polling_period_seconds: Polling period in seconds to check for the job status.
+    :param timeout: Optional. The time to wait, in seconds, for the job to finish. If None, the
+        trigger waits indefinitely.
+
+    """
+
+    def __init__(
+        self,
+        job_name: str,
+        project_id: str | None,
+        location: str = DEFAULT_BATCH_LOCATION,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        polling_period_seconds: float = 10,
+        timeout: float | None = None,
+    ):
+        super().__init__()
+        self.project_id = project_id
+        self.job_name = job_name
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.timeout = timeout
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes class arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.cloud_batch.CloudBatchJobFinishedTrigger",
+            {
+                "project_id": self.project_id,
+                "job_name": self.job_name,
+                "location": self.location,
+                "gcp_conn_id": self.gcp_conn_id,
+                "polling_period_seconds": self.polling_period_seconds,
+                "timeout": self.timeout,
+                "impersonation_chain": self.impersonation_chain,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """
+        Main loop of the class, which fetches the job status and yields a corresponding TriggerEvent.
+
+        If the job succeeds, a TriggerEvent with a success status is yielded; if the job fails,
+        one with an error status; and if the job is being deleted, one with a deleted status.
+        In any other case, the trigger waits for the amount of time
+        stored in self.polling_period_seconds and polls again.
+        """
+        timeout = self.timeout
+        hook = self._get_async_hook()
+        while timeout is None or timeout > 0:
+
+            try:
+                job: Job = await hook.get_batch_job(job_name=self.job_name)
+
+                status: JobStatus.State = job.status.state
+                if status == JobStatus.State.SUCCEEDED:
+                    yield TriggerEvent(
+                        {
+                            "job_name": self.job_name,
+                            "status": "success",
+                            "message": "Job completed",
+                        }
+                    )
+                    return
+                elif status == JobStatus.State.FAILED:
+                    yield TriggerEvent(
+                        {
+                            "job_name": self.job_name,
+                            "status": "error",
+                            "message": f"Batch job with name {self.job_name} has failed its execution",
+                        }
+                    )
+                    return
+                elif status == JobStatus.State.DELETION_IN_PROGRESS:
+                    yield TriggerEvent(
+                        {
+                            "job_name": self.job_name,
+                            "status": "deleted",
+                            "message": f"Batch job with name {self.job_name} is being deleted",
+                        }
+                    )
+                    return
+                else:
+                    self.log.info("Job is still running...")
+                    self.log.info("Current job status is: %s", status)
+                    self.log.info("Sleeping for %s seconds.", self.polling_period_seconds)

Review Comment:
   Maybe better to join into only one log statement
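
   For example, one possible phrasing:

       self.log.info(
           "Job is still running with status %s; sleeping for %s seconds.",
           status,
           self.polling_period_seconds,
       )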



##########
airflow/providers/google/cloud/operators/cloud_batch.py:
##########
@@ -0,0 +1,298 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Sequence
+
+from google.api_core import operation  # type: ignore
+from google.cloud.batch_v1 import Job, Task
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.cloud_batch import CloudBatchHook
+from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+from airflow.providers.google.cloud.triggers.cloud_batch import CloudBatchJobFinishedTrigger
+from airflow.utils.context import Context
+
+
+class CloudBatchSubmitJobOperator(GoogleCloudBaseOperator):
+    """
+    Submit a job and wait for its completion.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param job_name: Required. The name of the job to create.
+    :param job: Required. The job descriptor containing the configuration of the job to submit.
+    :param polling_period_seconds: Optional. Controls the rate of polling for the result of a
+        deferrable run. By default, the trigger will poll every 10 seconds.
+    :param timeout_seconds: The timeout for this request.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param deferrable: Run the operator in deferrable mode.
+
+    """
+
+    template_fields = ("project_id", "region", "gcp_conn_id", "impersonation_chain", "job_name")
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        job_name: str,
+        job: dict | Job,
+        polling_period_seconds: float = 10,
+        timeout_seconds: float | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.job_name = job_name
+        self.job = job
+        self.polling_period_seconds = polling_period_seconds
+        self.timeout_seconds = timeout_seconds
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+
+    def execute(self, context):
+        hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+        job = hook.submit_build_job(job_name=self.job_name, job=self.job, region=self.region)
+
+        if not self.deferrable:
+            completed_job = hook.wait_for_job(
+                job_name=job.name,
+                polling_period_seconds=self.polling_period_seconds,
+                timeout=self.timeout_seconds,
+            )
+
+            return Job.to_dict(completed_job)
+
+        else:
+            self.defer(
+                trigger=CloudBatchJobFinishedTrigger(
+                    job_name=job.name,
+                    project_id=self.project_id,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    location=self.region,
+                    polling_period_seconds=self.polling_period_seconds,
+                    timeout=self.timeout_seconds,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context: Context, event: dict):
+        job_status = event["status"]
+        if job_status == "success":
+            hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+            job = hook.get_job(job_name=event["job_name"])
+            return Job.to_dict(job)
+        else:
+            raise AirflowException(f"Unexpected error in the operation: {event['message']}")
+
+
+class CloudBatchDeleteJobOperator(GoogleCloudBaseOperator):
+    """
+    Deletes a job and waits for the operation to complete.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param job_name: Required. The name of the job to be deleted.
+    :param timeout: The timeout for this request.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+
+    """
+
+    template_fields = ("project_id", "region", "gcp_conn_id", "impersonation_chain", "job_name")
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        job_name: str,
+        timeout: float | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.job_name = job_name
+        self.timeout = timeout
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+
+        operation = hook.delete_job(job_name=self.job_name, region=self.region, project_id=self.project_id)
+
+        self._wait_for_operation(operation)
+
+    def _wait_for_operation(self, operation: operation.Operation):
+        try:
+            return operation.result(timeout=self.timeout)
+        except Exception:
+            error = operation.exception(timeout=self.timeout)
+            raise AirflowException(error)
+
+
+class CloudBatchListJobsOperator(GoogleCloudBaseOperator):
+    """
+    List Cloud Batch jobs.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param filter: The filter based on which to list the jobs. If left empty, all the jobs are listed.
+    :param limit: The number of jobs to list. If left empty,
+        all the jobs matching the filter will be returned.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+
+    """
+
+    template_fields = (
+        "project_id",
+        "region",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        gcp_conn_id: str = "google_cloud_default",
+        filter: str | None = None,
+        limit: int | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.filter = filter
+        self.limit = limit
+        if limit is not None and limit < 0:
+            raise AirflowException("The limit for the list jobs request should be greater or equal to zero")
+
+    def execute(self, context):

Review Comment:
   Missing type hints
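
   For example, a sketch of the annotated signature (`Context` is already imported in this module,
   and the method returns the serialized job dicts):

       def execute(self, context: Context) -> list[dict]: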



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org