Posted to commits@airflow.apache.org by po...@apache.org on 2023/12/19 14:17:05 UTC
(airflow) branch main updated: Implement deferrable mode for BeamRunJavaPipelineOperator (#36122)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 881d88b4da Implement deferrable mode for BeamRunJavaPipelineOperator (#36122)
881d88b4da is described below
commit 881d88b4da90fbc053f9d911b80d1aa015a12e02
Author: max <42...@users.noreply.github.com>
AuthorDate: Tue Dec 19 15:16:57 2023 +0100
Implement deferrable mode for BeamRunJavaPipelineOperator (#36122)
---
airflow/providers/apache/beam/hooks/beam.py | 19 ++
airflow/providers/apache/beam/operators/beam.py | 163 +++++++----
airflow/providers/apache/beam/triggers/beam.py | 177 +++++++++++-
airflow/providers/google/cloud/hooks/dataflow.py | 44 ++-
.../operators/cloud/dataflow.rst | 8 +
tests/providers/apache/beam/hooks/test_beam.py | 44 +++
tests/providers/apache/beam/operators/test_beam.py | 302 +++++++++++++++------
tests/providers/apache/beam/triggers/test_beam.py | 154 +++++++++--
.../providers/google/cloud/hooks/test_dataflow.py | 29 +-
.../cloud/dataflow/example_dataflow_native_java.py | 35 ++-
10 files changed, 793 insertions(+), 182 deletions(-)
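In short, BeamRunJavaPipelineOperator gains a `deferrable` flag: when it is set, the operator hands control to the new BeamJavaPipelineTrigger and frees its worker slot while the Beam/Dataflow job runs. A minimal usage sketch is below; the task id, jar path and options are illustrative placeholders, and the full runnable example lives in the system test updated by this commit.

    from airflow.providers.apache.beam.hooks.beam import BeamRunnerType
    from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator

    # Hypothetical task definition; jar path, options and job class are placeholders.
    wordcount_java = BeamRunJavaPipelineOperator(
        task_id="wordcount_java_deferrable",
        runner=BeamRunnerType.DataflowRunner,
        jar="gs://my-bucket/word-count-beam-bundled-0.1.jar",
        job_class="org.apache.beam.examples.WordCount",
        pipeline_options={"output": "gs://my-bucket/output"},
        dataflow_config={"location": "us-central1"},
        deferrable=True,  # new in this commit
    )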
diff --git a/airflow/providers/apache/beam/hooks/beam.py b/airflow/providers/apache/beam/hooks/beam.py
index efea53560b..29ecfa4651 100644
--- a/airflow/providers/apache/beam/hooks/beam.py
+++ b/airflow/providers/apache/beam/hooks/beam.py
@@ -508,6 +508,25 @@ class BeamAsyncHook(BeamHook):
)
return return_code
+ async def start_java_pipeline_async(self, variables: dict, jar: str, job_class: str | None = None):
+ """
+ Start Apache Beam Java pipeline.
+
+ :param variables: Variables passed to the job.
+ :param jar: Name of the jar for the pipeline.
+ :param job_class: Name of the java class for the pipeline.
+ :return: Beam command execution return code.
+ """
+ if "labels" in variables:
+ variables["labels"] = json.dumps(variables["labels"], separators=(",", ":"))
+
+ command_prefix = ["java", "-cp", jar, job_class] if job_class else ["java", "-jar", jar]
+ return_code = await self.start_pipeline_async(
+ variables=variables,
+ command_prefix=command_prefix,
+ )
+ return return_code
+
async def start_pipeline_async(
self,
variables: dict,
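The new BeamAsyncHook.start_java_pipeline_async mirrors the existing Python variant: it serializes dict-valued labels, builds either a `java -cp` or `java -jar` command, and awaits the shared start_pipeline_async helper. A rough sketch of calling it directly, assuming a jar available on the local filesystem (paths and options are placeholders):

    import asyncio

    from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook

    async def run_wordcount() -> int:
        hook = BeamAsyncHook(runner="DirectRunner")
        # Returns the java process exit code; non-zero means the pipeline failed.
        return await hook.start_java_pipeline_async(
            variables={"output": "/tmp/wordcount", "labels": {"team": "data"}},
            jar="/tmp/word-count-beam-bundled-0.1.jar",
            job_class="org.apache.beam.examples.WordCount",
        )

    return_code = asyncio.run(run_wordcount())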
diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py
index 876f47fa2d..bf75f3caa7 100644
--- a/airflow/providers/apache/beam/operators/beam.py
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -34,7 +34,7 @@ from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType
-from airflow.providers.apache.beam.triggers.beam import BeamPipelineTrigger
+from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger, BeamPythonPipelineTrigger
from airflow.providers.google.cloud.hooks.dataflow import (
DataflowHook,
process_line_and_extract_dataflow_job_id_callback,
@@ -239,6 +239,22 @@ class BeamBasePipelineOperator(BaseOperator, BeamDataflowMixin, ABC):
check_job_status_callback,
)
+ def execute_complete(self, context: Context, event: dict[str, Any]):
+ """
+ Execute when the trigger fires - returns immediately.
+
+ Relies on trigger to throw an exception, otherwise it assumes execution was
+ successful.
+ """
+ if event["status"] == "error":
+ raise AirflowException(event["message"])
+ self.log.info(
+ "%s completed with response %s ",
+ self.task_id,
+ event["message"],
+ )
+ return {"dataflow_job_id": self.dataflow_job_id}
+
class BeamRunPythonPipelineOperator(BeamBasePipelineOperator):
"""
@@ -323,7 +339,7 @@ class BeamRunPythonPipelineOperator(BeamBasePipelineOperator):
self.deferrable = deferrable
def execute(self, context: Context):
- """Execute the Apache Beam Pipeline."""
+ """Execute the Apache Beam Python Pipeline."""
(
self.is_dataflow,
self.dataflow_job_name,
@@ -408,7 +424,7 @@ class BeamRunPythonPipelineOperator(BeamBasePipelineOperator):
)
with self.dataflow_hook.provide_authorized_gcloud():
self.defer(
- trigger=BeamPipelineTrigger(
+ trigger=BeamPythonPipelineTrigger(
variables=self.snake_case_pipeline_options,
py_file=self.py_file,
py_options=self.py_options,
@@ -421,7 +437,7 @@ class BeamRunPythonPipelineOperator(BeamBasePipelineOperator):
)
else:
self.defer(
- trigger=BeamPipelineTrigger(
+ trigger=BeamPythonPipelineTrigger(
variables=self.snake_case_pipeline_options,
py_file=self.py_file,
py_options=self.py_options,
@@ -433,22 +449,6 @@ class BeamRunPythonPipelineOperator(BeamBasePipelineOperator):
method_name="execute_complete",
)
- def execute_complete(self, context: Context, event: dict[str, Any]):
- """
- Execute when the trigger fires - returns immediately.
-
- Relies on trigger to throw an exception, otherwise it assumes execution was
- successful.
- """
- if event["status"] == "error":
- raise AirflowException(event["message"])
- self.log.info(
- "%s completed with response %s ",
- self.task_id,
- event["message"],
- )
- return {"dataflow_job_id": self.dataflow_job_id}
-
def on_kill(self) -> None:
if self.dataflow_hook and self.dataflow_job_id:
self.log.info("Dataflow job with id: `%s` was requested to be cancelled.", self.dataflow_job_id)
@@ -509,6 +509,7 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator):
pipeline_options: dict | None = None,
gcp_conn_id: str = "google_cloud_default",
dataflow_config: DataflowConfiguration | dict | None = None,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(
@@ -521,61 +522,55 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator):
)
self.jar = jar
self.job_class = job_class
+ self.deferrable = deferrable
def execute(self, context: Context):
- """Execute the Apache Beam Pipeline."""
+ """Execute the Apache Beam Python Pipeline."""
(
- is_dataflow,
- dataflow_job_name,
- pipeline_options,
- process_line_callback,
+ self.is_dataflow,
+ self.dataflow_job_name,
+ self.pipeline_options,
+ self.process_line_callback,
_,
) = self._init_pipeline_options()
-
if not self.beam_hook:
raise AirflowException("Beam hook is not defined.")
+ if self.deferrable:
+ asyncio.run(self.execute_async(context))
+ else:
+ return self.execute_sync(context)
+ def execute_sync(self, context: Context):
+ """Execute the Apache Beam Pipeline."""
with ExitStack() as exit_stack:
if self.jar.lower().startswith("gs://"):
gcs_hook = GCSHook(self.gcp_conn_id)
tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.jar))
self.jar = tmp_gcs_file.name
- if is_dataflow and self.dataflow_hook:
- is_running = False
- if self.dataflow_config.check_if_running != CheckJobRunning.IgnoreJob:
- is_running = (
- # The reason for disable=no-value-for-parameter is that project_id parameter is
- # required but here is not passed, moreover it cannot be passed here.
- # This method is wrapped by @_fallback_to_project_id_from_variables decorator which
- # fallback project_id value from variables and raise error if project_id is
- # defined both in variables and as parameter (here is already defined in variables)
- self.dataflow_hook.is_job_dataflow_running(
- name=self.dataflow_config.job_name,
- variables=pipeline_options,
- )
+ if self.is_dataflow and self.dataflow_hook:
+ is_running = self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun
+ while is_running and self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun:
+ # The reason for disable=no-value-for-parameter is that project_id parameter is
+ # required but here is not passed, moreover it cannot be passed here.
+ # This method is wrapped by @_fallback_to_project_id_from_variables decorator which
+ # fallback project_id value from variables and raise error if project_id is
+ # defined both in variables and as parameter (here is already defined in variables)
+ is_running = self.dataflow_hook.is_job_dataflow_running(
+ name=self.dataflow_config.job_name,
+ variables=self.pipeline_options,
)
- while is_running and self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun:
- # The reason for disable=no-value-for-parameter is that project_id parameter is
- # required but here is not passed, moreover it cannot be passed here.
- # This method is wrapped by @_fallback_to_project_id_from_variables decorator which
- # fallback project_id value from variables and raise error if project_id is
- # defined both in variables and as parameter (here is already defined in variables)
-
- is_running = self.dataflow_hook.is_job_dataflow_running(
- name=self.dataflow_config.job_name,
- variables=pipeline_options,
- )
+
if not is_running:
- pipeline_options["jobName"] = dataflow_job_name
+ self.pipeline_options["jobName"] = self.dataflow_job_name
with self.dataflow_hook.provide_authorized_gcloud():
self.beam_hook.start_java_pipeline(
- variables=pipeline_options,
+ variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
- process_line_callback=process_line_callback,
+ process_line_callback=self.process_line_callback,
)
- if dataflow_job_name and self.dataflow_config.location:
+ if self.dataflow_job_name and self.dataflow_config.location:
multiple_jobs = self.dataflow_config.multiple_jobs or False
DataflowJobLink.persist(
self,
@@ -585,7 +580,7 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator):
self.dataflow_job_id,
)
self.dataflow_hook.wait_for_done(
- job_name=dataflow_job_name,
+ job_name=self.dataflow_job_name,
location=self.dataflow_config.location,
job_id=self.dataflow_job_id,
multiple_jobs=multiple_jobs,
@@ -594,11 +589,65 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator):
return {"dataflow_job_id": self.dataflow_job_id}
else:
self.beam_hook.start_java_pipeline(
- variables=pipeline_options,
+ variables=self.pipeline_options,
jar=self.jar,
job_class=self.job_class,
- process_line_callback=process_line_callback,
+ process_line_callback=self.process_line_callback,
+ )
+
+ async def execute_async(self, context: Context):
+ # Use the running event loop to offload blocking I/O to the default executor
+ loop = asyncio.get_event_loop()
+ if self.jar.lower().startswith("gs://"):
+ gcs_hook = GCSHook(self.gcp_conn_id)
+ # Running synchronous `enter_context()` method in a separate
+ # thread using the default executor `None`. The `run_in_executor()` function returns the
+ # file object, which is created using gcs function `provide_file()`, asynchronously.
+ # This means we can perform asynchronous operations with this file.
+ create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar)
+ tmp_gcs_file: IO[str] = await loop.run_in_executor(
+ None, contextlib.ExitStack().enter_context, create_tmp_file_call
+ )
+ self.jar = tmp_gcs_file.name
+
+ if self.is_dataflow and self.dataflow_hook:
+ DataflowJobLink.persist(
+ self,
+ context,
+ self.dataflow_config.project_id,
+ self.dataflow_config.location,
+ self.dataflow_job_id,
+ )
+ with self.dataflow_hook.provide_authorized_gcloud():
+ self.pipeline_options["jobName"] = self.dataflow_job_name
+ self.defer(
+ trigger=BeamJavaPipelineTrigger(
+ variables=self.pipeline_options,
+ jar=self.jar,
+ job_class=self.job_class,
+ runner=self.runner,
+ check_if_running=self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun,
+ project_id=self.dataflow_config.project_id,
+ location=self.dataflow_config.location,
+ job_name=self.dataflow_job_name,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.dataflow_config.impersonation_chain,
+ poll_sleep=self.dataflow_config.poll_sleep,
+ cancel_timeout=self.dataflow_config.cancel_timeout,
+ ),
+ method_name="execute_complete",
)
+ else:
+ self.defer(
+ trigger=BeamJavaPipelineTrigger(
+ variables=self.pipeline_options,
+ jar=self.jar,
+ job_class=self.job_class,
+ runner=self.runner,
+ check_if_running=self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun,
+ ),
+ method_name="execute_complete",
+ )
def on_kill(self) -> None:
if self.dataflow_hook and self.dataflow_job_id:
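With this change execute_complete lives on BeamBasePipelineOperator, so the Python and Java operators resume from deferral through the same callback. The triggers are expected to yield events with the shape sketched below (messages are illustrative); an "error" status makes the callback raise AirflowException, anything else returns the dataflow_job_id for XCom.

    # Shape of the TriggerEvent payloads consumed by execute_complete (illustrative values).
    success_event = {"status": "success", "message": "Pipeline has finished SUCCESSFULLY"}
    error_event = {"status": "error", "message": "Operation failed"}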
diff --git a/airflow/providers/apache/beam/triggers/beam.py b/airflow/providers/apache/beam/triggers/beam.py
index 0d201cd8c9..4caa46d1e5 100644
--- a/airflow/providers/apache/beam/triggers/beam.py
+++ b/airflow/providers/apache/beam/triggers/beam.py
@@ -16,15 +16,33 @@
# under the License.
from __future__ import annotations
-from typing import Any, AsyncIterator
+import asyncio
+import warnings
+from typing import Any, AsyncIterator, Sequence
+from google.cloud.dataflow_v1beta3 import ListJobsRequest
+
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook
+from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
-class BeamPipelineTrigger(BaseTrigger):
+class BeamPipelineBaseTrigger(BaseTrigger):
+ """Base class for Beam Pipeline Triggers."""
+
+ @staticmethod
+ def _get_async_hook(*args, **kwargs) -> BeamAsyncHook:
+ return BeamAsyncHook(*args, **kwargs)
+
+ @staticmethod
+ def _get_sync_dataflow_hook(**kwargs) -> AsyncDataflowHook:
+ return AsyncDataflowHook(**kwargs)
+
+
+class BeamPythonPipelineTrigger(BeamPipelineBaseTrigger):
"""
- Trigger to perform checking the pipeline status until it reaches terminate state.
+ Trigger that checks the status of the Python pipeline until it reaches a terminal state.
:param variables: Variables passed to the pipeline.
:param py_file: Path to the python file to execute.
@@ -35,12 +53,10 @@ class BeamPipelineTrigger(BaseTrigger):
:param py_requirements: Additional python package(s) to install.
If a value is passed to this parameter, a new virtual environment has been created with
additional packages installed.
-
You could also install the apache-beam package if it is not installed on your system, or you want
to use a different version.
:param py_system_site_packages: Whether to include system_site_packages in your virtualenv.
See virtualenv documentation for more information.
-
This option is only relevant if the ``py_requirements`` parameter is not None.
:param runner: Runner on which pipeline will be run. By default, "DirectRunner" is being used.
Other possible options: DataflowRunner, SparkRunner, FlinkRunner, PortableRunner.
@@ -68,9 +84,9 @@ class BeamPipelineTrigger(BaseTrigger):
self.runner = runner
def serialize(self) -> tuple[str, dict[str, Any]]:
- """Serialize BeamPipelineTrigger arguments and classpath."""
+ """Serialize BeamPythonPipelineTrigger arguments and classpath."""
return (
- "airflow.providers.apache.beam.triggers.beam.BeamPipelineTrigger",
+ "airflow.providers.apache.beam.triggers.beam.BeamPythonPipelineTrigger",
{
"variables": self.variables,
"py_file": self.py_file,
@@ -84,7 +100,7 @@ class BeamPipelineTrigger(BaseTrigger):
async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current pipeline status and yields a TriggerEvent."""
- hook = self._get_async_hook()
+ hook = self._get_async_hook(runner=self.runner)
try:
return_code = await hook.start_python_pipeline_async(
variables=self.variables,
@@ -109,5 +125,146 @@ class BeamPipelineTrigger(BaseTrigger):
yield TriggerEvent({"status": "error", "message": "Operation failed"})
return
- def _get_async_hook(self) -> BeamAsyncHook:
- return BeamAsyncHook(runner=self.runner)
+
+class BeamJavaPipelineTrigger(BeamPipelineBaseTrigger):
+ """
+ Trigger that checks the status of the Java pipeline until it reaches a terminal state.
+
+ :param variables: Variables passed to the job.
+ :param jar: Name of the jar for the pipeline.
+ :param job_class: Optional. Name of the java class for the pipeline.
+ :param runner: Runner on which pipeline will be run. By default, "DirectRunner" is being used.
+ Other possible options: DataflowRunner, SparkRunner, FlinkRunner, PortableRunner.
+ See: :class:`~providers.apache.beam.hooks.beam.BeamRunnerType`
+ See: https://beam.apache.org/documentation/runners/capability-matrix/
+ :param check_if_running: Optional. Before running job, validate that a previous run is not in process.
+ :param project_id: Optional. The Google Cloud project ID in which to start a job.
+ :param location: Optional. Job location.
+ :param job_name: Optional. The 'jobName' to use when executing the Dataflow job.
+ :param gcp_conn_id: Optional. The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional. GCP service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ :param poll_sleep: Optional. The time in seconds to sleep between polling GCP for the dataflow job status.
+ Default value is 10s.
+ :param cancel_timeout: Optional. How long (in seconds) operator should wait for the pipeline to be
+ successfully cancelled when task is being killed. Default value is 300s.
+ """
+
+ def __init__(
+ self,
+ variables: dict,
+ jar: str,
+ job_class: str | None = None,
+ runner: str = "DirectRunner",
+ check_if_running: bool = False,
+ project_id: str | None = None,
+ location: str | None = None,
+ job_name: str | None = None,
+ gcp_conn_id: str | None = None,
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_sleep: int = 10,
+ cancel_timeout: int | None = None,
+ ):
+ super().__init__()
+ self.variables = variables
+ self.jar = jar
+ self.job_class = job_class
+ self.runner = runner
+ self.check_if_running = check_if_running
+ self.project_id = project_id
+ self.location = location
+ self.job_name = job_name
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_sleep = poll_sleep
+ self.cancel_timeout = cancel_timeout
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serialize BeamJavaPipelineTrigger arguments and classpath."""
+ return (
+ "airflow.providers.apache.beam.triggers.beam.BeamJavaPipelineTrigger",
+ {
+ "variables": self.variables,
+ "jar": self.jar,
+ "job_class": self.job_class,
+ "runner": self.runner,
+ "check_if_running": self.check_if_running,
+ "project_id": self.project_id,
+ "location": self.location,
+ "job_name": self.job_name,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_sleep": self.poll_sleep,
+ "cancel_timeout": self.cancel_timeout,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
+ """Get current Java pipeline status and yields a TriggerEvent."""
+ hook = self._get_async_hook(runner=self.runner)
+
+ return_code = 0
+ if self.check_if_running:
+ dataflow_hook = self._get_sync_dataflow_hook(
+ gcp_conn_id=self.gcp_conn_id,
+ poll_sleep=self.poll_sleep,
+ impersonation_chain=self.impersonation_chain,
+ cancel_timeout=self.cancel_timeout,
+ )
+ is_running = True
+ while is_running:
+ try:
+ jobs = await dataflow_hook.list_jobs(
+ project_id=self.project_id,
+ location=self.location,
+ jobs_filter=ListJobsRequest.Filter.ACTIVE,
+ )
+ is_running = bool([job async for job in jobs if job.name == self.job_name])
+ except Exception as e:
+ self.log.exception(f"Exception occurred while requesting jobs with name {self.job_name}")
+ yield TriggerEvent({"status": "error", "message": str(e)})
+ return
+ if is_running:
+ await asyncio.sleep(self.poll_sleep)
+ try:
+ return_code = await hook.start_java_pipeline_async(
+ variables=self.variables, jar=self.jar, job_class=self.job_class
+ )
+ except Exception as e:
+ self.log.exception("Exception occurred while starting the Java pipeline")
+ yield TriggerEvent({"status": "error", "message": str(e)})
+
+ if return_code == 0:
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": "Pipeline has finished SUCCESSFULLY",
+ }
+ )
+ else:
+ yield TriggerEvent({"status": "error", "message": "Operation failed"})
+ return
+
+
+class BeamPipelineTrigger(BeamPythonPipelineTrigger):
+ """
+ Trigger that checks the status of the Python pipeline until it reaches a terminal state.
+
+ This class is deprecated. Please use
+ :class:`airflow.providers.apache.beam.triggers.beam.BeamPythonPipelineTrigger`
+ instead.
+ """
+
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ "`BeamPipelineTrigger` is deprecated. Please use `BeamPythonPipelineTrigger`.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ super().__init__(*args, **kwargs)
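serialize() returns the trigger classpath plus keyword arguments, which is how the triggerer process rebuilds the trigger before running it. A sketch of that round trip using Airflow's import_string helper (pipeline options, jar path and job class are placeholders):

    from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger
    from airflow.utils.module_loading import import_string

    classpath, kwargs = BeamJavaPipelineTrigger(
        variables={"output": "gs://bucket/output"},
        jar="gs://bucket/pipeline.jar",
        job_class="org.example.WordCount",
    ).serialize()

    # The triggerer reconstructs an equivalent trigger instance from the serialized form.
    rehydrated = import_string(classpath)(**kwargs)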
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index 6c051ce619..486f9225a2 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -27,9 +27,10 @@ import time
import uuid
import warnings
from copy import deepcopy
-from typing import Any, Callable, Generator, Sequence, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Callable, Generator, Sequence, TypeVar, cast
from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState, JobsV1Beta3AsyncClient, JobView
+from google.cloud.dataflow_v1beta3.types.jobs import ListJobsRequest
from googleapiclient.discovery import build
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -42,6 +43,10 @@ from airflow.providers.google.common.hooks.base_google import (
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.timeout import timeout
+if TYPE_CHECKING:
+ from google.cloud.dataflow_v1beta3.services.jobs_v1_beta3.pagers import ListJobsAsyncPager
+
+
# This is the default location
# https://cloud.google.com/dataflow/pipelines/specifying-exec-params
DEFAULT_DATAFLOW_LOCATION = "us-central1"
@@ -219,7 +224,7 @@ class _DataflowJobsController(LoggingMixin):
def is_job_running(self) -> bool:
"""
- Helper method to check if jos is still running in dataflow.
+ Helper method to check if job is still running in dataflow.
:return: True if job is running.
"""
@@ -1313,3 +1318,38 @@ class AsyncDataflowHook(GoogleBaseAsyncHook):
)
state = job.current_state
return state
+
+ async def list_jobs(
+ self,
+ jobs_filter: int | None = None,
+ project_id: str | None = PROVIDE_PROJECT_ID,
+ location: str | None = DEFAULT_DATAFLOW_LOCATION,
+ page_size: int | None = None,
+ page_token: str | None = None,
+ ) -> ListJobsAsyncPager:
+ """List jobs.
+
+ For detail see:
+ https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.types.ListJobsRequest
+
+ :param jobs_filter: Optional. This field filters out and returns jobs in the specified job state.
+ :param project_id: Optional. The Google Cloud project ID in which to start a job.
+ If set to None or missing, the default project_id from the Google Cloud connection is used.
+ :param location: Optional. The location of the Dataflow job (for example europe-west1).
+ :param page_size: Optional. If there are many jobs, limit response to at most this many.
+ :param page_token: Optional. Set this to the 'next_page_token' field of a previous response to request
+ additional results in a long list.
+ """
+ project_id = project_id or (await self.get_project_id())
+ client = await self.initialize_client(JobsV1Beta3AsyncClient)
+ request: ListJobsRequest = ListJobsRequest(
+ {
+ "project_id": project_id,
+ "location": location,
+ "filter": jobs_filter,
+ "page_size": page_size,
+ "page_token": page_token,
+ }
+ )
+ page_result: ListJobsAsyncPager = await client.list_jobs(request=request)
+ return page_result
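list_jobs returns a ListJobsAsyncPager, which BeamJavaPipelineTrigger iterates to decide whether a job with the configured name is already active before starting a new one. A rough sketch of that check outside the trigger (project, location and job name are placeholders):

    import asyncio

    from google.cloud.dataflow_v1beta3 import ListJobsRequest

    from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook

    async def job_is_active(job_name: str) -> bool:
        hook = AsyncDataflowHook(gcp_conn_id="google_cloud_default")
        pager = await hook.list_jobs(
            project_id="my-project",
            location="us-central1",
            jobs_filter=ListJobsRequest.Filter.ACTIVE,
        )
        # The pager is async-iterable; a non-empty match means the job is still running.
        return bool([job async for job in pager if job.name == job_name])

    asyncio.run(job_is_active("example-wordcount"))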
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
index 3851cb8fa8..905bb5790b 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
@@ -87,6 +87,14 @@ Here is an example of creating and running a pipeline in Java with jar stored on
:start-after: [START howto_operator_start_java_job_jar_on_gcs]
:end-before: [END howto_operator_start_java_job_jar_on_gcs]
+Here is an example of creating and running a pipeline in Java with jar stored on GCS in deferrable mode:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_start_java_job_jar_on_gcs_deferrable]
+ :end-before: [END howto_operator_start_java_job_jar_on_gcs_deferrable]
+
Here is an example of creating and running a pipeline in Java with jar stored on local file system:
.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py
diff --git a/tests/providers/apache/beam/hooks/test_beam.py b/tests/providers/apache/beam/hooks/test_beam.py
index ce33b683e2..10e2808a09 100644
--- a/tests/providers/apache/beam/hooks/test_beam.py
+++ b/tests/providers/apache/beam/hooks/test_beam.py
@@ -42,12 +42,19 @@ TEST_JOB_ID = "test-job-id"
GO_FILE = "/path/to/file.go"
DEFAULT_RUNNER = "DirectRunner"
BEAM_STRING = "airflow.providers.apache.beam.hooks.beam.{}"
+BEAM_VARIABLES = {"output": "gs://test/output", "labels": {"foo": "bar"}}
BEAM_VARIABLES_PY = {"output": "gs://test/output", "labels": {"foo": "bar"}}
BEAM_VARIABLES_JAVA = {
"output": "gs://test/output",
"labels": {"foo": "bar"},
}
+BEAM_VARIABLES_JAVA_STRING_LABELS = {
+ "output": "gs://test/output",
+ "labels": '{"foo":"bar"}',
+}
BEAM_VARIABLES_GO = {"output": "gs://test/output", "labels": {"foo": "bar"}}
+PIPELINE_COMMAND_PREFIX = ["a", "b", "c"]
+WORKING_DIRECTORY = "test_wd"
APACHE_BEAM_V_2_14_0_JAVA_SDK_LOG = f""""\
Dataflow SDK version: 2.14.0
@@ -418,6 +425,25 @@ class TestBeamOptionsToArgs:
class TestBeamAsyncHook:
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.apache.beam.hooks.beam.BeamAsyncHook.run_beam_command_async")
+ async def test_start_pipeline_async(self, mock_runner):
+ expected_cmd = [
+ *PIPELINE_COMMAND_PREFIX,
+ f"--runner={DEFAULT_RUNNER}",
+ *beam_options_to_args(BEAM_VARIABLES),
+ ]
+ hook = BeamAsyncHook(runner=DEFAULT_RUNNER)
+ await hook.start_pipeline_async(
+ variables=BEAM_VARIABLES,
+ command_prefix=PIPELINE_COMMAND_PREFIX,
+ working_directory=WORKING_DIRECTORY,
+ )
+
+ mock_runner.assert_called_once_with(
+ cmd=expected_cmd, working_directory=WORKING_DIRECTORY, log=hook.log
+ )
+
@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.beam.hooks.beam.BeamAsyncHook.run_beam_command_async")
@mock.patch("airflow.providers.apache.beam.hooks.beam.BeamAsyncHook._create_tmp_dir")
@@ -583,3 +609,21 @@ class TestBeamAsyncHook:
)
mock_runner.assert_not_called()
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "job_class, command_prefix",
+ [
+ (JOB_CLASS, ["java", "-cp", JAR_FILE, JOB_CLASS]),
+ (None, ["java", "-jar", JAR_FILE]),
+ ],
+ )
+ @mock.patch("airflow.providers.apache.beam.hooks.beam.BeamAsyncHook.start_pipeline_async")
+ async def test_start_java_pipeline_async(self, mock_start_pipeline, job_class, command_prefix):
+ variables = copy.deepcopy(BEAM_VARIABLES_JAVA)
+ hook = BeamAsyncHook(runner=DEFAULT_RUNNER)
+ await hook.start_java_pipeline_async(variables=variables, jar=JAR_FILE, job_class=job_class)
+
+ mock_start_pipeline.assert_called_once_with(
+ variables=BEAM_VARIABLES_JAVA_STRING_LABELS, command_prefix=command_prefix
+ )
diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py
index 8b6f57cccc..538c2417a0 100644
--- a/tests/providers/apache/beam/operators/test_beam.py
+++ b/tests/providers/apache/beam/operators/test_beam.py
@@ -24,11 +24,12 @@ import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.apache.beam.operators.beam import (
+ BeamBasePipelineOperator,
BeamRunGoPipelineOperator,
BeamRunJavaPipelineOperator,
BeamRunPythonPipelineOperator,
)
-from airflow.providers.apache.beam.triggers.beam import BeamPipelineTrigger
+from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger, BeamPythonPipelineTrigger
from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration
from airflow.version import version
@@ -60,6 +61,35 @@ EXPECTED_ADDITIONAL_OPTIONS = {
"labels": {"foo": "bar", "airflow-version": TEST_VERSION},
}
TEST_IMPERSONATION_ACCOUNT = "test@impersonation.com"
+BEAM_OPERATOR_PATH = "airflow.providers.apache.beam.operators.beam.{}"
+
+
+class TestBeamBasePipelineOperator:
+ def setup_method(self):
+ self.operator = BeamBasePipelineOperator(
+ task_id=TASK_ID,
+ runner=DEFAULT_RUNNER,
+ )
+
+ def test_async_execute_should_throw_exception(self):
+ """Tests that an AirflowException is raised in case of error event"""
+
+ with pytest.raises(AirflowException):
+ self.operator.execute_complete(
+ context=mock.MagicMock(), event={"status": "error", "message": "test failure message"}
+ )
+
+ def test_async_execute_logging_should_execute_successfully(self):
+ """Asserts that logging occurs as expected"""
+
+ with mock.patch.object(self.operator.log, "info") as mock_log_info:
+ self.operator.execute_complete(
+ context=mock.MagicMock(),
+ event={"status": "success", "message": "Pipeline has finished SUCCESSFULLY"},
+ )
+ mock_log_info.assert_called_with(
+ "%s completed with response %s ", TASK_ID, "Pipeline has finished SUCCESSFULLY"
+ )
class TestBeamRunPythonPipelineOperator:
@@ -82,8 +112,8 @@ class TestBeamRunPythonPipelineOperator:
assert self.operator.default_pipeline_options == PY_DEFAULT_OPTIONS
assert self.operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_exec_direct_runner(self, gcs_hook, beam_hook_mock):
"""Test BeamHook is created and the right args are passed to
start_python_workflow.
@@ -111,10 +141,10 @@ class TestBeamRunPythonPipelineOperator:
process_line_callback=None,
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
"""Test DataflowHook is created and the right args are passed to
start_python_dataflow.
@@ -164,10 +194,10 @@ class TestBeamRunPythonPipelineOperator:
)
dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
self.operator.runner = "DataflowRunner"
dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job
@@ -178,9 +208,9 @@ class TestBeamRunPythonPipelineOperator:
job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_on_kill_direct_runner(self, _, dataflow_mock, __):
dataflow_cancel_job = dataflow_mock.return_value.cancel_job
self.operator.execute(None)
@@ -207,8 +237,8 @@ class TestBeamRunJavaPipelineOperator:
assert self.operator.jar == JAR_FILE
assert self.operator.pipeline_options == ADDITIONAL_OPTIONS
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_exec_direct_runner(self, gcs_hook, beam_hook_mock):
"""Test BeamHook is created and the right args are passed to
start_java_workflow.
@@ -226,10 +256,10 @@ class TestBeamRunJavaPipelineOperator:
process_line_callback=None,
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
"""Test DataflowHook is created and the right args are passed to
start_java_dataflow.
@@ -274,10 +304,10 @@ class TestBeamRunJavaPipelineOperator:
process_line_callback=mock.ANY,
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
self.operator.runner = "DataflowRunner"
dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False
@@ -289,9 +319,9 @@ class TestBeamRunJavaPipelineOperator:
job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_on_kill_direct_runner(self, _, dataflow_mock, __):
dataflow_cancel_job = dataflow_mock.return_value.cancel_job
self.operator.execute(None)
@@ -386,8 +416,8 @@ class TestBeamRunGoPipelineOperator:
"tempfile.TemporaryDirectory",
return_value=MagicMock(__enter__=MagicMock(return_value="/tmp/apache-beam-go")),
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_exec_direct_runner_with_gcs_go_file(self, gcs_hook, beam_hook_mock, _):
"""Test BeamHook is created and the right args are passed to
start_go_workflow.
@@ -413,8 +443,8 @@ class TestBeamRunGoPipelineOperator:
should_init_module=True,
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch("tempfile.TemporaryDirectory")
def test_exec_direct_runner_with_gcs_launcher_binary(
self, mock_tmp_dir, mock_beam_hook, mock_gcs_hook, tmp_path
@@ -468,7 +498,7 @@ class TestBeamRunGoPipelineOperator:
process_line_callback=None,
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch("airflow.providers.google.go_module_utils.init_module")
def test_exec_direct_runner_with_local_go_file(self, init_module, beam_hook_mock):
"""
@@ -490,7 +520,7 @@ class TestBeamRunGoPipelineOperator:
should_init_module=False,
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
def test_exec_direct_runner_with_local_launcher_binary(self, mock_beam_hook):
"""
Test start_go_pipeline_with_binary is called with a local launcher binary.
@@ -513,14 +543,14 @@ class TestBeamRunGoPipelineOperator:
process_line_callback=None,
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(
"tempfile.TemporaryDirectory",
return_value=MagicMock(__enter__=MagicMock(return_value="/tmp/apache-beam-go")),
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_exec_dataflow_runner_with_go_file(
self, gcs_hook, dataflow_hook_mock, beam_hook_mock, _, persist_link_mock
):
@@ -575,11 +605,11 @@ class TestBeamRunGoPipelineOperator:
)
dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("tempfile.TemporaryDirectory")
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("tempfile.TemporaryDirectory"))
def test_exec_dataflow_runner_with_launcher_binary_and_worker_binary(
self, mock_tmp_dir, mock_beam_hook, mock_gcs_hook, mock_dataflow_hook, mock_persist_link, tmp_path
):
@@ -672,10 +702,10 @@ class TestBeamRunGoPipelineOperator:
project_id=dataflow_config.project_id,
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
self.operator.runner = "DataflowRunner"
dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job
@@ -686,9 +716,9 @@ class TestBeamRunGoPipelineOperator:
job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_on_kill_direct_runner(self, _, dataflow_mock, __):
dataflow_cancel_job = dataflow_mock.return_value.cancel_job
self.operator.execute(None)
@@ -717,59 +747,161 @@ class TestBeamRunPythonPipelineOperatorAsync:
assert self.operator.default_pipeline_options == DEFAULT_OPTIONS
assert self.operator.pipeline_options == EXPECTED_ADDITIONAL_OPTIONS
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_async_execute_should_execute_successfully(self, gcs_hook, beam_hook_mock):
"""
- Asserts that a task is deferred and the BeamPipelineTrigger will be fired
+ Asserts that a task is deferred and the BeamPythonPipelineTrigger will be fired
when the BeamRunPythonPipelineOperator is executed in deferrable mode when deferrable=True.
"""
with pytest.raises(TaskDeferred) as exc:
self.operator.execute(context=mock.MagicMock())
- assert isinstance(exc.value.trigger, BeamPipelineTrigger), "Trigger is not a BeamPipelineTrigger"
+ assert isinstance(
+ exc.value.trigger, BeamPythonPipelineTrigger
+ ), "Trigger is not a BeamPythonPipelineTrigger"
- def test_async_execute_should_throw_exception(self):
- """Tests that an AirflowException is raised in case of error event"""
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ def test_async_execute_direct_runner(self, gcs_hook, beam_hook_mock):
+ """
+ Test BeamHook is created and the right args are passed to
+ start_python_workflow when executing direct runner.
+ """
+ gcs_provide_file = gcs_hook.return_value.provide_file
+ with pytest.raises(TaskDeferred):
+ self.operator.execute(context=mock.MagicMock())
+ beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
+ gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
- with pytest.raises(AirflowException):
- self.operator.execute_complete(
- context=mock.MagicMock(), event={"status": "error", "message": "test failure message"}
- )
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
+ """
+ Test DataflowHook is created and the right args are passed to
+ start_python_dataflow when executing Dataflow runner.
+ """
- def test_async_execute_logging_should_execute_successfully(self):
- """Asserts that logging occurs as expected"""
+ dataflow_config = DataflowConfiguration(impersonation_chain=TEST_IMPERSONATION_ACCOUNT)
+ self.operator.runner = "DataflowRunner"
+ self.operator.dataflow_config = dataflow_config
+ gcs_provide_file = gcs_hook.return_value.provide_file
+ magic_mock = mock.MagicMock()
+ with pytest.raises(TaskDeferred):
+ self.operator.execute(context=magic_mock)
- with mock.patch.object(self.operator.log, "info") as mock_log_info:
- self.operator.execute_complete(
- context=mock.MagicMock(),
- event={"status": "success", "message": "Pipeline has finished SUCCESSFULLY"},
- )
- mock_log_info.assert_called_with(
- "%s completed with response %s ", TASK_ID, "Pipeline has finished SUCCESSFULLY"
+ job_name = dataflow_hook_mock.build_dataflow_job_name.return_value
+ dataflow_hook_mock.assert_called_once_with(
+ gcp_conn_id=dataflow_config.gcp_conn_id,
+ poll_sleep=dataflow_config.poll_sleep,
+ impersonation_chain=dataflow_config.impersonation_chain,
+ drain_pipeline=dataflow_config.drain_pipeline,
+ cancel_timeout=dataflow_config.cancel_timeout,
+ wait_until_finished=dataflow_config.wait_until_finished,
+ )
+ expected_options = {
+ "project": dataflow_hook_mock.return_value.project_id,
+ "job_name": job_name,
+ "staging_location": "gs://test/staging",
+ "output": "gs://test/output",
+ "labels": {"foo": "bar", "airflow-version": TEST_VERSION},
+ "region": "us-central1",
+ "impersonate_service_account": TEST_IMPERSONATION_ACCOUNT,
+ }
+ gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
+ persist_link_mock.assert_called_once_with(
+ self.operator,
+ magic_mock,
+ expected_options["project"],
+ expected_options["region"],
+ self.operator.dataflow_job_id,
+ )
+ beam_hook_mock.return_value.start_python_pipeline.assert_not_called()
+ dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
+
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
+ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
+ self.operator.runner = "DataflowRunner"
+ dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job
+ with pytest.raises(TaskDeferred):
+ self.operator.execute(context=mock.MagicMock())
+ self.operator.dataflow_job_id = JOB_ID
+ self.operator.on_kill()
+ dataflow_cancel_job.assert_called_once_with(
+ job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ def test_on_kill_direct_runner(self, _, dataflow_mock, __):
+ dataflow_cancel_job = dataflow_mock.return_value.cancel_job
+ with pytest.raises(TaskDeferred):
+ self.operator.execute(mock.MagicMock())
+ self.operator.on_kill()
+ dataflow_cancel_job.assert_not_called()
+
+
+class TestBeamRunJavaPipelineOperatorAsync:
+ def setup_method(self):
+ self.operator = BeamRunJavaPipelineOperator(
+ task_id=TASK_ID,
+ jar=JAR_FILE,
+ job_class=JOB_CLASS,
+ default_pipeline_options=DEFAULT_OPTIONS,
+ pipeline_options=ADDITIONAL_OPTIONS,
+ deferrable=True,
+ )
+
+ def test_init(self):
+ """Test BeamRunJavaPipelineOperator instance is properly initialized."""
+ assert self.operator.task_id == TASK_ID
+ assert self.operator.jar == JAR_FILE
+ assert self.operator.runner == DEFAULT_RUNNER
+ assert self.operator.job_class == JOB_CLASS
+ assert self.operator.default_pipeline_options == DEFAULT_OPTIONS
+ assert self.operator.pipeline_options == ADDITIONAL_OPTIONS
+
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ def test_async_execute_should_execute_successfully(self, gcs_hook, beam_hook_mock):
+ """
+ Asserts that a task is deferred and the BeamJavaPipelineTrigger will be fired
+ when the BeamRunJavaPipelineOperator is executed in deferrable mode.
+ """
+ with pytest.raises(TaskDeferred) as exc:
+ self.operator.execute(context=mock.MagicMock())
+
+ assert isinstance(
+ exc.value.trigger, BeamJavaPipelineTrigger
+ ), "Trigger is not a BeamPJavaPipelineTrigger"
+
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_async_execute_direct_runner(self, gcs_hook, beam_hook_mock):
"""
Test BeamHook is created and the right args are passed to
- start_python_workflow when executing direct runner.
+ start_java_pipeline when executing direct runner.
"""
gcs_provide_file = gcs_hook.return_value.provide_file
with pytest.raises(TaskDeferred):
self.operator.execute(context=mock.MagicMock())
beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
- gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
+ gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
"""
Test DataflowHook is created and the right args are passed to
- start_python_dataflow when executing Dataflow runner.
+ start_java_pipeline when executing Dataflow runner.
"""
dataflow_config = DataflowConfiguration(impersonation_chain=TEST_IMPERSONATION_ACCOUNT)
@@ -798,7 +930,7 @@ class TestBeamRunPythonPipelineOperatorAsync:
"region": "us-central1",
"impersonate_service_account": TEST_IMPERSONATION_ACCOUNT,
}
- gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
+ gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
persist_link_mock.assert_called_once_with(
self.operator,
magic_mock,
@@ -809,10 +941,10 @@ class TestBeamRunPythonPipelineOperatorAsync:
beam_hook_mock.return_value.start_python_pipeline.assert_not_called()
dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowJobLink.persist")
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
self.operator.runner = "DataflowRunner"
dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job
@@ -824,9 +956,9 @@ class TestBeamRunPythonPipelineOperatorAsync:
job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id
)
- @mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
- @mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
+ @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
+ @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_on_kill_direct_runner(self, _, dataflow_mock, __):
dataflow_cancel_job = dataflow_mock.return_value.cancel_job
with pytest.raises(TaskDeferred):
diff --git a/tests/providers/apache/beam/triggers/test_beam.py b/tests/providers/apache/beam/triggers/test_beam.py
index 82e56ff3ec..6bd1b4bc66 100644
--- a/tests/providers/apache/beam/triggers/test_beam.py
+++ b/tests/providers/apache/beam/triggers/test_beam.py
@@ -20,17 +20,21 @@ from unittest import mock
import pytest
-from airflow.providers.apache.beam.triggers.beam import BeamPipelineTrigger
+from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger, BeamPythonPipelineTrigger
from airflow.triggers.base import TriggerEvent
-HOOK_STATUS_STR = "airflow.providers.apache.beam.hooks.beam.BeamAsyncHook.start_python_pipeline_async"
-CLASSPATH = "airflow.providers.apache.beam.triggers.beam.BeamPipelineTrigger"
+HOOK_STATUS_STR_PYTHON = "airflow.providers.apache.beam.hooks.beam.BeamAsyncHook.start_python_pipeline_async"
+HOOK_STATUS_STR_JAVA = "airflow.providers.apache.beam.hooks.beam.BeamAsyncHook.start_java_pipeline_async"
+CLASSPATH_PYTHON = "airflow.providers.apache.beam.triggers.beam.BeamPythonPipelineTrigger"
+CLASSPATH_JAVA = "airflow.providers.apache.beam.triggers.beam.BeamJavaPipelineTrigger"
TASK_ID = "test_task"
LOCATION = "test-location"
INSTANCE_NAME = "airflow-test-instance"
INSTANCE = {"type": "BASIC", "displayName": INSTANCE_NAME}
PROJECT_ID = "test_project_id"
+TEST_GCP_CONN_ID = "test_gcp_conn_id"
+TEST_IMPERSONATION_CHAIN = "A,B,C"
TEST_VARIABLES = {"output": "gs://bucket_test/output", "labels": {"airflow-version": "v2-7-0-dev0"}}
TEST_PY_FILE = "apache_beam.examples.wordcount"
TEST_PY_OPTIONS: list[str] = []
@@ -38,11 +42,17 @@ TEST_PY_INTERPRETER = "python3"
TEST_PY_REQUIREMENTS = ["apache-beam[gcp]==2.46.0"]
TEST_PY_PACKAGES = False
TEST_RUNNER = "DirectRunner"
+TEST_JAR_FILE = "example.jar"
+TEST_JOB_CLASS = "TestClass"
+TEST_CHECK_IF_RUNNING = False
+TEST_JOB_NAME = "test_job_name"
+TEST_POLL_SLEEP = 10
+TEST_CANCEL_TIMEOUT = 300
@pytest.fixture
-def trigger():
- return BeamPipelineTrigger(
+def python_trigger():
+ return BeamPythonPipelineTrigger(
variables=TEST_VARIABLES,
py_file=TEST_PY_FILE,
py_options=TEST_PY_OPTIONS,
@@ -53,14 +63,32 @@ def trigger():
)
-class TestBeamPipelineTrigger:
- def test_beam_trigger_serialization_should_execute_successfully(self, trigger):
+@pytest.fixture
+def java_trigger():
+ return BeamJavaPipelineTrigger(
+ variables=TEST_VARIABLES,
+ jar=TEST_JAR_FILE,
+ job_class=TEST_JOB_CLASS,
+ runner=TEST_RUNNER,
+ check_if_running=TEST_CHECK_IF_RUNNING,
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ job_name=TEST_JOB_NAME,
+ gcp_conn_id=TEST_GCP_CONN_ID,
+ impersonation_chain=TEST_IMPERSONATION_CHAIN,
+ poll_sleep=TEST_POLL_SLEEP,
+ cancel_timeout=TEST_CANCEL_TIMEOUT,
+ )
+
+
+class TestBeamPythonPipelineTrigger:
+ def test_beam_trigger_serialization_should_execute_successfully(self, python_trigger):
"""
- Asserts that the BeamPipelineTrigger correctly serializes its arguments
+ Asserts that the BeamPythonPipelineTrigger correctly serializes its arguments
and classpath.
"""
- classpath, kwargs = trigger.serialize()
- assert classpath == CLASSPATH
+ classpath, kwargs = python_trigger.serialize()
+ assert classpath == CLASSPATH_PYTHON
assert kwargs == {
"variables": TEST_VARIABLES,
"py_file": TEST_PY_FILE,
@@ -72,36 +100,118 @@ class TestBeamPipelineTrigger:
}
@pytest.mark.asyncio
- @mock.patch(HOOK_STATUS_STR)
- async def test_beam_trigger_on_success_should_execute_successfully(self, mock_pipeline_status, trigger):
+ @mock.patch(HOOK_STATUS_STR_PYTHON)
+ async def test_beam_trigger_on_success_should_execute_successfully(
+ self, mock_pipeline_status, python_trigger
+ ):
"""
- Tests the BeamPipelineTrigger only fires once the job execution reaches a successful state.
+ Tests the BeamPythonPipelineTrigger only fires once the job execution reaches a successful state.
"""
mock_pipeline_status.return_value = 0
- generator = trigger.run()
+ generator = python_trigger.run()
actual = await generator.asend(None)
assert TriggerEvent({"status": "success", "message": "Pipeline has finished SUCCESSFULLY"}) == actual
@pytest.mark.asyncio
- @mock.patch(HOOK_STATUS_STR)
- async def test_beam_trigger_error_should_execute_successfully(self, mock_pipeline_status, trigger):
+ @mock.patch(HOOK_STATUS_STR_PYTHON)
+ async def test_beam_trigger_error_should_execute_successfully(self, mock_pipeline_status, python_trigger):
"""
- Test that BeamPipelineTrigger fires the correct event in case of an error.
+ Test that BeamPythonPipelineTrigger fires the correct event in case of an error.
"""
mock_pipeline_status.return_value = 1
- generator = trigger.run()
+ generator = python_trigger.run()
actual = await generator.asend(None)
assert TriggerEvent({"status": "error", "message": "Operation failed"}) == actual
@pytest.mark.asyncio
- @mock.patch(HOOK_STATUS_STR)
- async def test_beam_trigger_exception_should_execute_successfully(self, mock_pipeline_status, trigger):
+ @mock.patch(HOOK_STATUS_STR_PYTHON)
+ async def test_beam_trigger_exception_should_execute_successfully(
+ self, mock_pipeline_status, python_trigger
+ ):
"""
- Test that BeamPipelineTrigger fires the correct event in case of an error.
+ Test that BeamPythonPipelineTrigger fires the correct event in case of an error.
"""
mock_pipeline_status.side_effect = Exception("Test exception")
- generator = trigger.run()
+ generator = python_trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
+
+
+class TestBeamJavaPipelineTrigger:
+ def test_beam_trigger_serialization_should_execute_successfully(self, java_trigger):
+ """
+ Asserts that the BeamJavaPipelineTrigger correctly serializes its arguments
+ and classpath.
+ """
+ classpath, kwargs = java_trigger.serialize()
+ assert classpath == CLASSPATH_JAVA
+ assert kwargs == {
+ "variables": TEST_VARIABLES,
+ "jar": TEST_JAR_FILE,
+ "job_class": TEST_JOB_CLASS,
+ "runner": TEST_RUNNER,
+ "check_if_running": TEST_CHECK_IF_RUNNING,
+ "project_id": PROJECT_ID,
+ "location": LOCATION,
+ "job_name": TEST_JOB_NAME,
+ "gcp_conn_id": TEST_GCP_CONN_ID,
+ "impersonation_chain": TEST_IMPERSONATION_CHAIN,
+ "poll_sleep": TEST_POLL_SLEEP,
+ "cancel_timeout": TEST_CANCEL_TIMEOUT,
+ }
+
+ @pytest.mark.asyncio
+ @mock.patch(HOOK_STATUS_STR_JAVA)
+ async def test_beam_trigger_on_success_should_execute_successfully(
+ self, mock_pipeline_status, java_trigger
+ ):
+ """
+ Tests the BeamJavaPipelineTrigger only fires once the job execution reaches a successful state.
+ """
+ mock_pipeline_status.return_value = 0
+ generator = java_trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "success", "message": "Pipeline has finished SUCCESSFULLY"}) == actual
+
+ @pytest.mark.asyncio
+ @mock.patch(HOOK_STATUS_STR_JAVA)
+ async def test_beam_trigger_error_should_execute_successfully(self, mock_pipeline_status, java_trigger):
+ """
+ Test that BeamJavaPipelineTrigger fires the correct event in case of an error.
+ """
+ mock_pipeline_status.return_value = 1
+
+ generator = java_trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "error", "message": "Operation failed"}) == actual
+
+ @pytest.mark.asyncio
+ @mock.patch(HOOK_STATUS_STR_JAVA)
+ async def test_beam_trigger_exception_should_execute_successfully(
+ self, mock_pipeline_status, java_trigger
+ ):
+ """
+ Test that BeamJavaPipelineTrigger fires the correct event in case of an error.
+ """
+ mock_pipeline_status.side_effect = Exception("Test exception")
+
+ generator = java_trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.list_jobs")
+ async def test_beam_trigger_exception_list_jobs_should_execute_successfully(
+ self, mock_list_jobs, java_trigger
+ ):
+ """
+ Test that BeamJavaPipelineTrigger fires the correct event when listing jobs raises an error.
+ """
+ mock_list_jobs.side_effect = Exception("Test exception")
+
+ java_trigger.check_if_running = True
+ generator = java_trigger.run()
actual = await generator.asend(None)
assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index 22d4051c43..e26713502a 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -27,7 +27,7 @@ from unittest.mock import MagicMock
from uuid import UUID
import pytest
-from google.cloud.dataflow_v1beta3 import GetJobRequest, JobView
+from google.cloud.dataflow_v1beta3 import GetJobRequest, JobView, ListJobsRequest
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.apache.beam.hooks.beam import BeamHook, run_beam_command
@@ -89,6 +89,7 @@ BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
DATAFLOW_STRING = "airflow.providers.google.cloud.hooks.dataflow.{}"
TEST_PROJECT = "test-project"
TEST_JOB_ID = "test-job-id"
+TEST_JOBS_FILTER = ListJobsRequest.Filter.ACTIVE
TEST_LOCATION = "custom-location"
DEFAULT_PY_INTERPRETER = "python3"
TEST_FLEX_PARAMETERS = {
@@ -1949,7 +1950,7 @@ class TestAsyncHook:
)
@pytest.mark.asyncio
- @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.initialize_client")
+ @mock.patch(DATAFLOW_STRING.format("AsyncDataflowHook.initialize_client"))
async def test_get_job(self, initialize_client_mock, hook, make_mock_awaitable):
client = initialize_client_mock.return_value
make_mock_awaitable(client.get_job, None)
@@ -1972,3 +1973,27 @@ class TestAsyncHook:
client.get_job.assert_called_once_with(
request=request,
)
+
+ @pytest.mark.asyncio
+ @mock.patch(DATAFLOW_STRING.format("AsyncDataflowHook.initialize_client"))
+ async def test_list_jobs(self, initialize_client_mock, hook, make_mock_awaitable):
+ client = initialize_client_mock.return_value
+ make_mock_awaitable(client.list_jobs, None)
+
+ await hook.list_jobs(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ jobs_filter=TEST_JOBS_FILTER,
+ )
+
+ request = ListJobsRequest(
+ {
+ "project_id": TEST_PROJECT_ID,
+ "location": TEST_LOCATION,
+ "filter": TEST_JOBS_FILTER,
+ "page_size": None,
+ "page_token": None,
+ }
+ )
+ initialize_client_mock.assert_called_once()
+ client.list_jobs.assert_called_once_with(request=request)
diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py b/tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py
index 53b33b89e9..12047bae5b 100644
--- a/tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py
+++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_native_java.py
@@ -48,12 +48,12 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
DAG_ID = "dataflow_native_java"
BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}"
-PUBLIC_BUCKET = "system-tests-resources"
+PUBLIC_BUCKET = "airflow-system-tests-resources"
JAR_FILE_NAME = "word-count-beam-bundled-0.1.jar"
-REMOTE_JAR_FILE_PATH = f"{DAG_ID}/{JAR_FILE_NAME}"
+REMOTE_JAR_FILE_PATH = f"dataflow/java/{JAR_FILE_NAME}"
GCS_OUTPUT = f"gs://{BUCKET_NAME}"
-GCS_JAR = f"gs://{PUBLIC_BUCKET}/{REMOTE_JAR_FILE_PATH}"
+GCS_JAR = f"gs://{PUBLIC_BUCKET}/dataflow/java/{JAR_FILE_NAME}"
LOCATION = "europe-west3"
with DAG(
@@ -105,11 +105,38 @@ with DAG(
)
# [END howto_operator_start_java_job_jar_on_gcs]
+ # [START howto_operator_start_java_job_jar_on_gcs_deferrable]
+ start_java_deferrable = BeamRunJavaPipelineOperator(
+ runner=BeamRunnerType.DataflowRunner,
+ task_id="start-java-job-deferrable",
+ jar=GCS_JAR,
+ pipeline_options={
+ "output": GCS_OUTPUT,
+ },
+ job_class="org.apache.beam.examples.WordCount",
+ dataflow_config={
+ "check_if_running": CheckJobRunning.WaitForRun,
+ "location": LOCATION,
+ "poll_sleep": 10,
+ "append_job_name": False,
+ },
+ deferrable=True,
+ )
+ # [END howto_operator_start_java_job_jar_on_gcs_deferrable]
+
delete_bucket = GCSDeleteBucketOperator(
task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE
)
- create_bucket >> download_file >> [start_java_job_local, start_java_job] >> delete_bucket
+ (
+ # TEST SETUP
+ create_bucket
+ >> download_file
+ # TEST BODY
+ >> [start_java_job_local, start_java_job, start_java_deferrable]
+ # TEST TEARDOWN
+ >> delete_bucket
+ )
from tests.system.utils.watcher import watcher