Posted to commits@airflow.apache.org by po...@apache.org on 2021/03/03 09:34:50 UTC
[airflow] 13/38: Add Apache Beam operators (#12814)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit fef2a36e1f4e2798a6becdfbb47d3018ea79b612
Author: Tobiasz Kędzierski <to...@polidea.com>
AuthorDate: Wed Feb 3 21:34:01 2021 +0100
Add Apache Beam operators (#12814)
(cherry picked from commit 1872d8719d24f94aeb1dcba9694837070b9884ca)
---
CONTRIBUTING.rst | 17 +-
INSTALL | 14 +-
.../apache/beam/BACKPORT_PROVIDER_README.md | 99 +++
airflow/providers/apache/beam/CHANGELOG.rst | 25 +
airflow/providers/apache/beam/README.md | 97 +++
airflow/providers/apache/beam/__init__.py | 17 +
.../providers/apache/beam/example_dags/__init__.py | 17 +
.../apache/beam/example_dags/example_beam.py | 315 +++++++++
airflow/providers/apache/beam/hooks/__init__.py | 17 +
airflow/providers/apache/beam/hooks/beam.py | 289 ++++++++
.../providers/apache/beam/operators/__init__.py | 17 +
airflow/providers/apache/beam/operators/beam.py | 446 ++++++++++++
airflow/providers/apache/beam/provider.yaml | 45 ++
airflow/providers/dependencies.json | 4 +
airflow/providers/google/cloud/hooks/dataflow.py | 330 ++++-----
.../providers/google/cloud/operators/dataflow.py | 331 +++++++--
.../copy_provider_package_sources.py | 62 ++
dev/provider_packages/prepare_provider_packages.py | 4 +-
.../apache-airflow-providers-apache-beam/index.rst | 36 +
.../operators.rst | 116 ++++
docs/apache-airflow/extra-packages-ref.rst | 2 +
docs/spelling_wordlist.txt | 2 +
.../run_install_and_test_provider_packages.sh | 2 +-
setup.py | 1 +
tests/core/test_providers_manager.py | 1 +
tests/providers/apache/beam/__init__.py | 16 +
tests/providers/apache/beam/hooks/__init__.py | 16 +
tests/providers/apache/beam/hooks/test_beam.py | 271 ++++++++
tests/providers/apache/beam/operators/__init__.py | 16 +
tests/providers/apache/beam/operators/test_beam.py | 274 ++++++++
.../apache/beam/operators/test_beam_system.py | 47 ++
.../providers/google/cloud/hooks/test_dataflow.py | 760 ++++++++++++---------
.../google/cloud/operators/test_dataflow.py | 223 ++++--
.../google/cloud/operators/test_mlengine_utils.py | 30 +-
34 files changed, 3263 insertions(+), 696 deletions(-)
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 6d0e224..0a6f381 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -572,13 +572,13 @@ This is the full list of those extras:
.. START EXTRAS HERE
-all, all_dbs, amazon, apache.atlas, apache.cassandra, apache.druid, apache.hdfs, apache.hive,
-apache.kylin, apache.livy, apache.pig, apache.pinot, apache.spark, apache.sqoop, apache.webhdfs,
-async, atlas, aws, azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes, crypto, dask,
-databricks, datadog, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker,
-druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github_enterprise, google, google_auth,
-grpc, hashicorp, hdfs, hive, http, imap, jdbc, jenkins, jira, kerberos, kubernetes, ldap,
-microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas,
+all, all_dbs, amazon, apache.atlas, apache.beam, apache.cassandra, apache.druid, apache.hdfs,
+apache.hive, apache.kylin, apache.livy, apache.pig, apache.pinot, apache.spark, apache.sqoop,
+apache.webhdfs, async, atlas, aws, azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes,
+crypto, dask, databricks, datadog, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc,
+docker, druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github_enterprise, google,
+google_auth, grpc, hashicorp, hdfs, hive, http, imap, jdbc, jenkins, jira, kerberos, kubernetes,
+ldap, microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas,
opsgenie, oracle, pagerduty, papermill, password, pinot, plexus, postgres, presto, qds, qubole,
rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry, sftp, singularity, slack,
snowflake, spark, sqlite, ssh, statsd, tableau, telegram, vertica, virtualenv, webhdfs, winrm,
@@ -641,12 +641,13 @@ Here is the list of packages and their extras:
Package Extras
========================== ===========================
amazon apache.hive,google,imap,mongo,mysql,postgres,ssh
+apache.beam google
apache.druid apache.hive
apache.hive amazon,microsoft.mssql,mysql,presto,samba,vertica
apache.livy http
dingding http
discord http
-google amazon,apache.cassandra,cncf.kubernetes,facebook,microsoft.azure,microsoft.mssql,mysql,postgres,presto,salesforce,sftp,ssh
+google amazon,apache.beam,apache.cassandra,cncf.kubernetes,facebook,microsoft.azure,microsoft.mssql,mysql,postgres,presto,salesforce,sftp,ssh
hashicorp google
microsoft.azure google,oracle
microsoft.mssql odbc
diff --git a/INSTALL b/INSTALL
index e1ef456..d175aa1 100644
--- a/INSTALL
+++ b/INSTALL
@@ -97,13 +97,13 @@ The list of available extras:
# START EXTRAS HERE
-all, all_dbs, amazon, apache.atlas, apache.cassandra, apache.druid, apache.hdfs, apache.hive,
-apache.kylin, apache.livy, apache.pig, apache.pinot, apache.spark, apache.sqoop, apache.webhdfs,
-async, atlas, aws, azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes, crypto, dask,
-databricks, datadog, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker,
-druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github_enterprise, google, google_auth,
-grpc, hashicorp, hdfs, hive, http, imap, jdbc, jenkins, jira, kerberos, kubernetes, ldap,
-microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas,
+all, all_dbs, amazon, apache.atlas, apache.beam, apache.cassandra, apache.druid, apache.hdfs,
+apache.hive, apache.kylin, apache.livy, apache.pig, apache.pinot, apache.spark, apache.sqoop,
+apache.webhdfs, async, atlas, aws, azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes,
+crypto, dask, databricks, datadog, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc,
+docker, druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github_enterprise, google,
+google_auth, grpc, hashicorp, hdfs, hive, http, imap, jdbc, jenkins, jira, kerberos, kubernetes,
+ldap, microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas,
opsgenie, oracle, pagerduty, papermill, password, pinot, plexus, postgres, presto, qds, qubole,
rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry, sftp, singularity, slack,
snowflake, spark, sqlite, ssh, statsd, tableau, telegram, vertica, virtualenv, webhdfs, winrm,
diff --git a/airflow/providers/apache/beam/BACKPORT_PROVIDER_README.md b/airflow/providers/apache/beam/BACKPORT_PROVIDER_README.md
new file mode 100644
index 0000000..d0908b6
--- /dev/null
+++ b/airflow/providers/apache/beam/BACKPORT_PROVIDER_README.md
@@ -0,0 +1,99 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+ -->
+
+
+# Package apache-airflow-backport-providers-apache-beam
+
+Release:
+
+**Table of contents**
+
+- [Backport package](#backport-package)
+- [Installation](#installation)
+- [PIP requirements](#pip-requirements)
+- [Cross provider package dependencies](#cross-provider-package-dependencies)
+- [Provider class summary](#provider-classes-summary)
+ - [Operators](#operators)
+ - [Moved operators](#moved-operators)
+ - [Transfer operators](#transfer-operators)
+ - [Moved transfer operators](#moved-transfer-operators)
+ - [Hooks](#hooks)
+ - [Moved hooks](#moved-hooks)
+- [Releases](#releases)
+ - [Release](#release)
+
+## Backport package
+
+This is a backport provider package for the `apache.beam` provider. All classes for this provider package
+are in the `airflow.providers.apache.beam` Python package.
+
+**Only Python 3.6+ is supported for this backport package.**
+
+While Airflow 1.10.* continues to support Python 2.7+, you need to upgrade Python to 3.6+ if you
+want to use this backport package.
+
+
+## Installation
+
+You can install this package on top of an existing airflow 1.10.* installation via
+`pip install apache-airflow-backport-providers-apache-beam`
+
+## Cross provider package dependencies
+
+These are dependencies that might be needed in order to use all the features of the package.
+You need to install the specified backport provider packages in order to use them.
+
+You can install such cross-provider dependencies when installing from PyPI. For example:
+
+```bash
+pip install apache-airflow-backport-providers-apache-beam[google]
+```
+
+| Dependent package | Extra |
+|:----------------------------------------------------------------------------------------------------------|:------------|
+| [apache-airflow-backport-providers-google](https://pypi.org/project/apache-airflow-backport-providers-google) | google      |
+
+
+# Provider classes summary
+
+In Airflow 2.0, all operators, transfers, hooks, sensors, secrets for the `apache.beam` provider
+are in the `airflow.providers.apache.beam` package. You can read more about the naming conventions used
+in [Naming conventions for provider packages](https://github.com/apache/airflow/blob/master/CONTRIBUTING.rst#naming-conventions-for-provider-packages)
+
+
+## Operators
+
+### New operators
+
+| New Airflow 2.0 operators: `airflow.providers.apache.beam` package |
+|:-----------------------------------------------------------------------------------------------------------------------------------------------|
+| [operators.beam.BeamRunJavaPipelineOperator](https://github.com/apache/airflow/blob/master/airflow/providers/apache/beam/operators/beam.py) |
+| [operators.beam.BeamRunPythonPipelineOperator](https://github.com/apache/airflow/blob/master/airflow/providers/apache/beam/operators/beam.py) |
+
+
+## Hooks
+
+### New hooks
+
+| New Airflow 2.0 hooks: `airflow.providers.apache.beam` package |
+|:-----------------------------------------------------------------------------------------------------------------|
+| [hooks.beam.BeamHook](https://github.com/apache/airflow/blob/master/airflow/providers/apache/beam/hooks/beam.py) |
+
+
+## Releases
diff --git a/airflow/providers/apache/beam/CHANGELOG.rst b/airflow/providers/apache/beam/CHANGELOG.rst
new file mode 100644
index 0000000..cef7dda
--- /dev/null
+++ b/airflow/providers/apache/beam/CHANGELOG.rst
@@ -0,0 +1,25 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+Changelog
+---------
+
+1.0.0
+.....
+
+Initial version of the provider.
diff --git a/airflow/providers/apache/beam/README.md b/airflow/providers/apache/beam/README.md
new file mode 100644
index 0000000..3aa0ead
--- /dev/null
+++ b/airflow/providers/apache/beam/README.md
@@ -0,0 +1,97 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+ -->
+
+
+# Package apache-airflow-providers-apache-beam
+
+Release: 0.0.1
+
+**Table of contents**
+
+- [Provider package](#provider-package)
+- [Installation](#installation)
+- [PIP requirements](#pip-requirements)
+- [Cross provider package dependencies](#cross-provider-package-dependencies)
+- [Provider class summary](#provider-classes-summary)
+ - [Operators](#operators)
+ - [Transfer operators](#transfer-operators)
+ - [Hooks](#hooks)
+- [Releases](#releases)
+
+## Provider package
+
+This is a provider package for the `apache.beam` provider. All classes for this provider package
+are in the `airflow.providers.apache.beam` Python package.
+
+## Installation
+
+NOTE!
+
+In November 2020, a new version of pip (20.3) was released with a new, 2020 resolver. This resolver
+does not yet work with Apache Airflow and might lead to errors in installation, depending on your
+choice of extras. In order to install Airflow you need to either downgrade pip to version 20.2.4
+`pip install --upgrade pip==20.2.4` or, in case you use pip 20.3, add the option
+`--use-deprecated legacy-resolver` to your pip install command.
+
+You can install this package on top of an existing airflow 2.* installation via
+`pip install apache-airflow-providers-apache-beam`
+
+## Cross provider package dependencies
+
+These are dependencies that might be needed in order to use all the features of the package.
+You need to install the specified provider packages in order to use them.
+
+You can install such cross-provider dependencies when installing from PyPI. For example:
+
+```bash
+pip install apache-airflow-providers-apache-beam[google]
+```
+
+| Dependent package | Extra |
+|:--------------------------------------------------------------------------------------------|:------------|
+| [apache-airflow-providers-google](https://pypi.org/project/apache-airflow-providers-google) | google |
+
+
+# Provider classes summary
+
+In Airflow 2.0, all operators, transfers, hooks, sensors, secrets for the `apache.beam` provider
+are in the `airflow.providers.apache.beam` package. You can read more about the naming conventions used
+in [Naming conventions for provider packages](https://github.com/apache/airflow/blob/master/CONTRIBUTING.rst#naming-conventions-for-provider-packages)
+
+
+## Operators
+
+### New operators
+
+| New Airflow 2.0 operators: `airflow.providers.apache.beam` package |
+|:-----------------------------------------------------------------------------------------------------------------------------------------------|
+| [operators.beam.BeamRunJavaPipelineOperator](https://github.com/apache/airflow/blob/master/airflow/providers/apache/beam/operators/beam.py) |
+| [operators.beam.BeamRunPythonPipelineOperator](https://github.com/apache/airflow/blob/master/airflow/providers/apache/beam/operators/beam.py) |
+
+
+## Hooks
+
+### New hooks
+
+| New Airflow 2.0 hooks: `airflow.providers.apache.beam` package |
+|:-----------------------------------------------------------------------------------------------------------------|
+| [hooks.beam.BeamHook](https://github.com/apache/airflow/blob/master/airflow/providers/apache/beam/hooks/beam.py) |
+
+
+## Releases
diff --git a/airflow/providers/apache/beam/__init__.py b/airflow/providers/apache/beam/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/airflow/providers/apache/beam/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/apache/beam/example_dags/__init__.py b/airflow/providers/apache/beam/example_dags/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/airflow/providers/apache/beam/example_dags/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/apache/beam/example_dags/example_beam.py b/airflow/providers/apache/beam/example_dags/example_beam.py
new file mode 100644
index 0000000..d20c4ce
--- /dev/null
+++ b/airflow/providers/apache/beam/example_dags/example_beam.py
@@ -0,0 +1,315 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG for Apache Beam operators
+"""
+import os
+from urllib.parse import urlparse
+
+from airflow import models
+from airflow.providers.apache.beam.operators.beam import (
+ BeamRunJavaPipelineOperator,
+ BeamRunPythonPipelineOperator,
+)
+from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
+from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration
+from airflow.providers.google.cloud.sensors.dataflow import DataflowJobStatusSensor
+from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator
+from airflow.utils.dates import days_ago
+
+GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project')
+GCS_INPUT = os.environ.get('APACHE_BEAM_PYTHON', 'gs://apache-beam-samples/shakespeare/kinglear.txt')
+GCS_TMP = os.environ.get('APACHE_BEAM_GCS_TMP', 'gs://test-dataflow-example/temp/')
+GCS_STAGING = os.environ.get('APACHE_BEAM_GCS_STAGING', 'gs://test-dataflow-example/staging/')
+GCS_OUTPUT = os.environ.get('APACHE_BEAM_GCS_OUTPUT', 'gs://test-dataflow-example/output')
+GCS_PYTHON = os.environ.get('APACHE_BEAM_PYTHON', 'gs://test-dataflow-example/wordcount_debugging.py')
+GCS_PYTHON_DATAFLOW_ASYNC = os.environ.get(
+ 'APACHE_BEAM_PYTHON_DATAFLOW_ASYNC', 'gs://test-dataflow-example/wordcount_debugging.py'
+)
+
+GCS_JAR_DIRECT_RUNNER = os.environ.get(
+ 'APACHE_BEAM_DIRECT_RUNNER_JAR',
+ 'gs://test-dataflow-example/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-DirectRunner.jar',
+)
+GCS_JAR_DATAFLOW_RUNNER = os.environ.get(
+ 'APACHE_BEAM_DATAFLOW_RUNNER_JAR', 'gs://test-dataflow-example/word-count-beam-bundled-0.1.jar'
+)
+GCS_JAR_SPARK_RUNNER = os.environ.get(
+ 'APACHE_BEAM_SPARK_RUNNER_JAR',
+ 'gs://test-dataflow-example/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-SparkRunner.jar',
+)
+GCS_JAR_FLINK_RUNNER = os.environ.get(
+ 'APACHE_BEAM_FLINK_RUNNER_JAR',
+ 'gs://test-dataflow-example/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-FlinkRunner.jar',
+)
+
+GCS_JAR_DIRECT_RUNNER_PARTS = urlparse(GCS_JAR_DIRECT_RUNNER)
+GCS_JAR_DIRECT_RUNNER_BUCKET_NAME = GCS_JAR_DIRECT_RUNNER_PARTS.netloc
+GCS_JAR_DIRECT_RUNNER_OBJECT_NAME = GCS_JAR_DIRECT_RUNNER_PARTS.path[1:]
+GCS_JAR_DATAFLOW_RUNNER_PARTS = urlparse(GCS_JAR_DATAFLOW_RUNNER)
+GCS_JAR_DATAFLOW_RUNNER_BUCKET_NAME = GCS_JAR_DATAFLOW_RUNNER_PARTS.netloc
+GCS_JAR_DATAFLOW_RUNNER_OBJECT_NAME = GCS_JAR_DATAFLOW_RUNNER_PARTS.path[1:]
+GCS_JAR_SPARK_RUNNER_PARTS = urlparse(GCS_JAR_SPARK_RUNNER)
+GCS_JAR_SPARK_RUNNER_BUCKET_NAME = GCS_JAR_SPARK_RUNNER_PARTS.netloc
+GCS_JAR_SPARK_RUNNER_OBJECT_NAME = GCS_JAR_SPARK_RUNNER_PARTS.path[1:]
+GCS_JAR_FLINK_RUNNER_PARTS = urlparse(GCS_JAR_FLINK_RUNNER)
+GCS_JAR_FLINK_RUNNER_BUCKET_NAME = GCS_JAR_FLINK_RUNNER_PARTS.netloc
+GCS_JAR_FLINK_RUNNER_OBJECT_NAME = GCS_JAR_FLINK_RUNNER_PARTS.path[1:]
+
+
+default_args = {
+ 'default_pipeline_options': {
+ 'output': '/tmp/example_beam',
+ },
+ "trigger_rule": "all_done",
+}
+
+
+with models.DAG(
+ "example_beam_native_java_direct_runner",
+ schedule_interval=None, # Override to match your needs
+ start_date=days_ago(1),
+ tags=['example'],
+) as dag_native_java_direct_runner:
+
+ # [START howto_operator_start_java_direct_runner_pipeline]
+ jar_to_local_direct_runner = GCSToLocalFilesystemOperator(
+ task_id="jar_to_local_direct_runner",
+ bucket=GCS_JAR_DIRECT_RUNNER_BUCKET_NAME,
+ object_name=GCS_JAR_DIRECT_RUNNER_OBJECT_NAME,
+ filename="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar",
+ )
+
+ start_java_pipeline_direct_runner = BeamRunJavaPipelineOperator(
+ task_id="start_java_pipeline_direct_runner",
+ jar="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar",
+ pipeline_options={
+ 'output': '/tmp/start_java_pipeline_direct_runner',
+ 'inputFile': GCS_INPUT,
+ },
+ job_class='org.apache.beam.examples.WordCount',
+ )
+
+ jar_to_local_direct_runner >> start_java_pipeline_direct_runner
+ # [END howto_operator_start_java_direct_runner_pipeline]
+
+with models.DAG(
+ "example_beam_native_java_dataflow_runner",
+ schedule_interval=None, # Override to match your needs
+ start_date=days_ago(1),
+ tags=['example'],
+) as dag_native_java_dataflow_runner:
+ # [START howto_operator_start_java_dataflow_runner_pipeline]
+ jar_to_local_dataflow_runner = GCSToLocalFilesystemOperator(
+ task_id="jar_to_local_dataflow_runner",
+ bucket=GCS_JAR_DATAFLOW_RUNNER_BUCKET_NAME,
+ object_name=GCS_JAR_DATAFLOW_RUNNER_OBJECT_NAME,
+ filename="/tmp/beam_wordcount_dataflow_runner_{{ ds_nodash }}.jar",
+ )
+
+ start_java_pipeline_dataflow = BeamRunJavaPipelineOperator(
+ task_id="start_java_pipeline_dataflow",
+ runner="DataflowRunner",
+ jar="/tmp/beam_wordcount_dataflow_runner_{{ ds_nodash }}.jar",
+ pipeline_options={
+ 'tempLocation': GCS_TMP,
+ 'stagingLocation': GCS_STAGING,
+ 'output': GCS_OUTPUT,
+ },
+ job_class='org.apache.beam.examples.WordCount',
+ dataflow_config={"job_name": "{{task.task_id}}", "location": "us-central1"},
+ )
+
+ jar_to_local_dataflow_runner >> start_java_pipeline_dataflow
+ # [END howto_operator_start_java_dataflow_runner_pipeline]
+
+with models.DAG(
+ "example_beam_native_java_spark_runner",
+ schedule_interval=None, # Override to match your needs
+ start_date=days_ago(1),
+ tags=['example'],
+) as dag_native_java_spark_runner:
+
+ jar_to_local_spark_runner = GCSToLocalFilesystemOperator(
+ task_id="jar_to_local_spark_runner",
+ bucket=GCS_JAR_SPARK_RUNNER_BUCKET_NAME,
+ object_name=GCS_JAR_SPARK_RUNNER_OBJECT_NAME,
+ filename="/tmp/beam_wordcount_spark_runner_{{ ds_nodash }}.jar",
+ )
+
+ start_java_pipeline_spark_runner = BeamRunJavaPipelineOperator(
+ task_id="start_java_pipeline_spark_runner",
+ runner="SparkRunner",
+ jar="/tmp/beam_wordcount_spark_runner_{{ ds_nodash }}.jar",
+ pipeline_options={
+ 'output': '/tmp/start_java_pipeline_spark_runner',
+ 'inputFile': GCS_INPUT,
+ },
+ job_class='org.apache.beam.examples.WordCount',
+ )
+
+ jar_to_local_spark_runner >> start_java_pipeline_spark_runner
+
+with models.DAG(
+ "example_beam_native_java_flink_runner",
+ schedule_interval=None, # Override to match your needs
+ start_date=days_ago(1),
+ tags=['example'],
+) as dag_native_java_flink_runner:
+
+ jar_to_local_flink_runner = GCSToLocalFilesystemOperator(
+ task_id="jar_to_local_flink_runner",
+ bucket=GCS_JAR_FLINK_RUNNER_BUCKET_NAME,
+ object_name=GCS_JAR_FLINK_RUNNER_OBJECT_NAME,
+ filename="/tmp/beam_wordcount_flink_runner_{{ ds_nodash }}.jar",
+ )
+
+ start_java_pipeline_flink_runner = BeamRunJavaPipelineOperator(
+ task_id="start_java_pipeline_flink_runner",
+ runner="FlinkRunner",
+ jar="/tmp/beam_wordcount_flink_runner_{{ ds_nodash }}.jar",
+ pipeline_options={
+ 'output': '/tmp/start_java_pipeline_flink_runner',
+ 'inputFile': GCS_INPUT,
+ },
+ job_class='org.apache.beam.examples.WordCount',
+ )
+
+ jar_to_local_flink_runner >> start_java_pipeline_flink_runner
+
+
+with models.DAG(
+ "example_beam_native_python",
+ default_args=default_args,
+ start_date=days_ago(1),
+ schedule_interval=None, # Override to match your needs
+ tags=['example'],
+) as dag_native_python:
+
+ # [START howto_operator_start_python_direct_runner_pipeline_local_file]
+ start_python_pipeline_local_direct_runner = BeamRunPythonPipelineOperator(
+ task_id="start_python_pipeline_local_direct_runner",
+ py_file='apache_beam.examples.wordcount',
+ py_options=['-m'],
+ py_requirements=['apache-beam[gcp]==2.26.0'],
+ py_interpreter='python3',
+ py_system_site_packages=False,
+ )
+ # [END howto_operator_start_python_direct_runner_pipeline_local_file]
+
+ # [START howto_operator_start_python_direct_runner_pipeline_gcs_file]
+ start_python_pipeline_direct_runner = BeamRunPythonPipelineOperator(
+ task_id="start_python_pipeline_direct_runner",
+ py_file=GCS_PYTHON,
+ py_options=[],
+ pipeline_options={"output": GCS_OUTPUT},
+ py_requirements=['apache-beam[gcp]==2.26.0'],
+ py_interpreter='python3',
+ py_system_site_packages=False,
+ )
+ # [END howto_operator_start_python_direct_runner_pipeline_gcs_file]
+
+ # [START howto_operator_start_python_dataflow_runner_pipeline_gcs_file]
+ start_python_pipeline_dataflow_runner = BeamRunPythonPipelineOperator(
+ task_id="start_python_pipeline_dataflow_runner",
+ runner="DataflowRunner",
+ py_file=GCS_PYTHON,
+ pipeline_options={
+ 'tempLocation': GCS_TMP,
+ 'stagingLocation': GCS_STAGING,
+ 'output': GCS_OUTPUT,
+ },
+ py_options=[],
+ py_requirements=['apache-beam[gcp]==2.26.0'],
+ py_interpreter='python3',
+ py_system_site_packages=False,
+ dataflow_config=DataflowConfiguration(
+ job_name='{{task.task_id}}', project_id=GCP_PROJECT_ID, location="us-central1"
+ ),
+ )
+ # [END howto_operator_start_python_dataflow_runner_pipeline_gcs_file]
+
+ start_python_pipeline_local_spark_runner = BeamRunPythonPipelineOperator(
+ task_id="start_python_pipeline_local_spark_runner",
+ py_file='apache_beam.examples.wordcount',
+ runner="SparkRunner",
+ py_options=['-m'],
+ py_requirements=['apache-beam[gcp]==2.26.0'],
+ py_interpreter='python3',
+ py_system_site_packages=False,
+ )
+
+ start_python_pipeline_local_flink_runner = BeamRunPythonPipelineOperator(
+ task_id="start_python_pipeline_local_flink_runner",
+ py_file='apache_beam.examples.wordcount',
+ runner="FlinkRunner",
+ py_options=['-m'],
+ pipeline_options={
+ 'output': '/tmp/start_python_pipeline_local_flink_runner',
+ },
+ py_requirements=['apache-beam[gcp]==2.26.0'],
+ py_interpreter='python3',
+ py_system_site_packages=False,
+ )
+
+ [
+ start_python_pipeline_local_direct_runner,
+ start_python_pipeline_direct_runner,
+ ] >> start_python_pipeline_local_flink_runner >> start_python_pipeline_local_spark_runner
+
+
+with models.DAG(
+ "example_beam_native_python_dataflow_async",
+ default_args=default_args,
+ start_date=days_ago(1),
+ schedule_interval=None, # Override to match your needs
+ tags=['example'],
+) as dag_native_python_dataflow_async:
+ # [START howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file]
+ start_python_job_dataflow_runner_async = BeamRunPythonPipelineOperator(
+ task_id="start_python_job_dataflow_runner_async",
+ runner="DataflowRunner",
+ py_file=GCS_PYTHON_DATAFLOW_ASYNC,
+ pipeline_options={
+ 'tempLocation': GCS_TMP,
+ 'stagingLocation': GCS_STAGING,
+ 'output': GCS_OUTPUT,
+ },
+ py_options=[],
+ py_requirements=['apache-beam[gcp]==2.26.0'],
+ py_interpreter='python3',
+ py_system_site_packages=False,
+ dataflow_config=DataflowConfiguration(
+ job_name='{{task.task_id}}',
+ project_id=GCP_PROJECT_ID,
+ location="us-central1",
+ wait_until_finished=False,
+ ),
+ )
+
+ wait_for_python_job_dataflow_runner_async_done = DataflowJobStatusSensor(
+ task_id="wait-for-python-job-async-done",
+ job_id="{{task_instance.xcom_pull('start_python_job_dataflow_runner_async')['dataflow_job_id']}}",
+ expected_statuses={DataflowJobStatus.JOB_STATE_DONE},
+ project_id=GCP_PROJECT_ID,
+ location='us-central1',
+ )
+
+ start_python_job_dataflow_runner_async >> wait_for_python_job_dataflow_runner_async_done
+ # [END howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file]
diff --git a/airflow/providers/apache/beam/hooks/__init__.py b/airflow/providers/apache/beam/hooks/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/airflow/providers/apache/beam/hooks/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/apache/beam/hooks/beam.py b/airflow/providers/apache/beam/hooks/beam.py
new file mode 100644
index 0000000..8e188b0
--- /dev/null
+++ b/airflow/providers/apache/beam/hooks/beam.py
@@ -0,0 +1,289 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains a Apache Beam Hook."""
+import json
+import select
+import shlex
+import subprocess
+import textwrap
+from tempfile import TemporaryDirectory
+from typing import Callable, List, Optional
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.python_virtualenv import prepare_virtualenv
+
+
+class BeamRunnerType:
+ """
+ Helper class for listing runner types.
+ For more information about runners see:
+ https://beam.apache.org/documentation/
+ """
+
+ DataflowRunner = "DataflowRunner"
+ DirectRunner = "DirectRunner"
+ SparkRunner = "SparkRunner"
+ FlinkRunner = "FlinkRunner"
+ SamzaRunner = "SamzaRunner"
+ NemoRunner = "NemoRunner"
+ JetRunner = "JetRunner"
+ Twister2Runner = "Twister2Runner"
+
+
+def beam_options_to_args(options: dict) -> List[str]:
+ """
+ Returns formatted pipeline options from a dictionary of arguments.
+
+ The logic of this method should be compatible with Apache Beam:
+ https://github.com/apache/beam/blob/b56740f0e8cd80c2873412847d0b336837429fb9/sdks/python/
+ apache_beam/options/pipeline_options.py#L230-L251
+
+ :param options: Dictionary with options
+ :type options: dict
+ :return: List of arguments
+ :rtype: List[str]
+ """
+ if not options:
+ return []
+
+ args: List[str] = []
+ for attr, value in options.items():
+ if value is None or (isinstance(value, bool) and value):
+ args.append(f"--{attr}")
+ elif isinstance(value, list):
+ args.extend([f"--{attr}={v}" for v in value])
+ else:
+ args.append(f"--{attr}={value}")
+ return args
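+
+# A minimal illustration of the mapping performed by ``beam_options_to_args`` (not part of the
+# hook's API; the input dictionary below is made up for this example):
+#   beam_options_to_args({"project": "my-project", "streaming": True, "update": None, "labels": ["a=1", "b=2"]})
+#   -> ["--project=my-project", "--streaming", "--update", "--labels=a=1", "--labels=b=2"]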
+
+
+class BeamCommandRunner(LoggingMixin):
+ """
+ Class responsible for running a pipeline command in a subprocess
+
+ :param cmd: Parts of the command to be run in subprocess
+ :type cmd: List[str]
+ :param process_line_callback: Optional callback which can be used to process
+ stdout and stderr to detect job id
+ :type process_line_callback: Optional[Callable[[str], None]]
+ """
+
+ def __init__(
+ self,
+ cmd: List[str],
+ process_line_callback: Optional[Callable[[str], None]] = None,
+ ) -> None:
+ super().__init__()
+ self.log.info("Running command: %s", " ".join(shlex.quote(c) for c in cmd))
+ self.process_line_callback = process_line_callback
+ self.job_id: Optional[str] = None
+ self._proc = subprocess.Popen(
+ cmd,
+ shell=False,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ close_fds=True,
+ )
+
+ def _process_fd(self, fd):
+ """
+ Prints output to logs.
+
+ :param fd: File descriptor.
+ """
+ if fd not in (self._proc.stdout, self._proc.stderr):
+ raise Exception("No data in stderr or in stdout.")
+
+ fd_to_log = {self._proc.stderr: self.log.warning, self._proc.stdout: self.log.info}
+ func_log = fd_to_log[fd]
+
+ while True:
+ line = fd.readline().decode()
+ if not line:
+ return
+ if self.process_line_callback:
+ self.process_line_callback(line)
+ func_log(line.rstrip("\n"))
+
+ def wait_for_done(self) -> None:
+ """Waits for Apache Beam pipeline to complete."""
+ self.log.info("Start waiting for Apache Beam process to complete.")
+ reads = [self._proc.stderr, self._proc.stdout]
+ while True:
+ # Wait for at least one available fd.
+ readable_fds, _, _ = select.select(reads, [], [], 5)
+ if readable_fds is None:
+ self.log.info("Waiting for Apache Beam process to complete.")
+ continue
+
+ for readable_fd in readable_fds:
+ self._process_fd(readable_fd)
+
+ if self._proc.poll() is not None:
+ break
+
+ # Corner case: check if more output was created between the last read and the process termination
+ for readable_fd in reads:
+ self._process_fd(readable_fd)
+
+ self.log.info("Process exited with return code: %s", self._proc.returncode)
+
+ if self._proc.returncode != 0:
+ raise AirflowException(f"Apache Beam process failed with return code {self._proc.returncode}")
+
+
+class BeamHook(BaseHook):
+ """
+ Hook for Apache Beam.
+
+ All the methods in the hook where project_id is used must be called with
+ keyword arguments rather than positional.
+
+ :param runner: Runner type
+ :type runner: str
+ """
+
+ def __init__(
+ self,
+ runner: str,
+ ) -> None:
+ self.runner = runner
+ super().__init__()
+
+ def _start_pipeline(
+ self,
+ variables: dict,
+ command_prefix: List[str],
+ process_line_callback: Optional[Callable[[str], None]] = None,
+ ) -> None:
+ cmd = command_prefix + [
+ f"--runner={self.runner}",
+ ]
+ if variables:
+ cmd.extend(beam_options_to_args(variables))
+ cmd_runner = BeamCommandRunner(
+ cmd=cmd,
+ process_line_callback=process_line_callback,
+ )
+ cmd_runner.wait_for_done()
+
+ def start_python_pipeline( # pylint: disable=too-many-arguments
+ self,
+ variables: dict,
+ py_file: str,
+ py_options: List[str],
+ py_interpreter: str = "python3",
+ py_requirements: Optional[List[str]] = None,
+ py_system_site_packages: bool = False,
+ process_line_callback: Optional[Callable[[str], None]] = None,
+ ):
+ """
+ Starts Apache Beam python pipeline.
+
+ :param variables: Variables passed to the pipeline.
+ :type variables: Dict
+ :param py_options: Additional options.
+ :type py_options: List[str]
+ :param py_interpreter: Python version of the Apache Beam pipeline.
+ If None, this defaults to python3.
+ To track python versions supported by beam and related
+ issues check: https://issues.apache.org/jira/browse/BEAM-1251
+ :type py_interpreter: str
+ :param py_requirements: Additional python package(s) to install.
+ If a value is passed to this parameter, a new virtual environment will be created with
+ additional packages installed.
+
+ You could also install the apache-beam package if it is not installed on your system or you want
+ to use a different version.
+ :type py_requirements: List[str]
+ :param py_system_site_packages: Whether to include system_site_packages in your virtualenv.
+ See virtualenv documentation for more information.
+
+ This option is only relevant if the ``py_requirements`` parameter is not None.
+ :type py_system_site_packages: bool
+ :param process_line_callback: (optional) Callback that can be used to process
+ stdout and stderr to detect the job id.
+ :type process_line_callback: Optional[Callable[[str], None]]
+ """
+ if "labels" in variables:
+ variables["labels"] = [f"{key}={value}" for key, value in variables["labels"].items()]
+
+ if py_requirements is not None:
+ if not py_requirements and not py_system_site_packages:
+ warning_invalid_environment = textwrap.dedent(
+ """\
+ Invalid method invocation. You have disabled inclusion of system packages and provided an
+ empty list of packages to install, so it is not possible to create a valid virtual environment.
+ In the virtual environment, apache-beam package must be installed for your job to be \
+ executed. To fix this problem:
+ * install apache-beam on the system, then set parameter py_system_site_packages to True,
+ * add apache-beam to the list of required packages in parameter py_requirements.
+ """
+ )
+ raise AirflowException(warning_invalid_environment)
+
+ with TemporaryDirectory(prefix="apache-beam-venv") as tmp_dir:
+ py_interpreter = prepare_virtualenv(
+ venv_directory=tmp_dir,
+ python_bin=py_interpreter,
+ system_site_packages=py_system_site_packages,
+ requirements=py_requirements,
+ )
+ command_prefix = [py_interpreter] + py_options + [py_file]
+
+ self._start_pipeline(
+ variables=variables,
+ command_prefix=command_prefix,
+ process_line_callback=process_line_callback,
+ )
+ else:
+ command_prefix = [py_interpreter] + py_options + [py_file]
+
+ self._start_pipeline(
+ variables=variables,
+ command_prefix=command_prefix,
+ process_line_callback=process_line_callback,
+ )
+
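+ # Minimal usage sketch of the method above (illustrative only; the output path and the use of
+ # the bundled wordcount module are assumptions taken from the example DAG in this commit):
+ #   hook = BeamHook(runner=BeamRunnerType.DirectRunner)
+ #   hook.start_python_pipeline(
+ #       variables={"output": "/tmp/wordcount"},
+ #       py_file="apache_beam.examples.wordcount",
+ #       py_options=["-m"],
+ #   )
+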
+ def start_java_pipeline(
+ self,
+ variables: dict,
+ jar: str,
+ job_class: Optional[str] = None,
+ process_line_callback: Optional[Callable[[str], None]] = None,
+ ) -> None:
+ """
+ Starts Apache Beam Java pipeline.
+
+ :param variables: Variables passed to the job.
+ :type variables: dict
+ :param jar: Name of the jar for the pipeline
+ :type jar: str
+ :param job_class: Name of the java class for the pipeline.
+ :type job_class: str
+ """
+ if "labels" in variables:
+ variables["labels"] = json.dumps(variables["labels"], separators=(",", ":"))
+
+ command_prefix = ["java", "-cp", jar, job_class] if job_class else ["java", "-jar", jar]
+ self._start_pipeline(
+ variables=variables,
+ command_prefix=command_prefix,
+ process_line_callback=process_line_callback,
+ )
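+
+# Minimal usage sketch for the Java path (illustrative only; the jar path and job class below
+# are assumptions based on the example DAG added in this commit):
+#   hook = BeamHook(runner=BeamRunnerType.DirectRunner)
+#   hook.start_java_pipeline(
+#       variables={"output": "/tmp/wordcount", "inputFile": "gs://apache-beam-samples/shakespeare/kinglear.txt"},
+#       jar="/tmp/beam_wordcount_direct_runner.jar",
+#       job_class="org.apache.beam.examples.WordCount",
+#   )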
diff --git a/airflow/providers/apache/beam/operators/__init__.py b/airflow/providers/apache/beam/operators/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/airflow/providers/apache/beam/operators/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py
new file mode 100644
index 0000000..849298e
--- /dev/null
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -0,0 +1,446 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains Apache Beam operators."""
+from contextlib import ExitStack
+from typing import Callable, List, Optional, Union
+
+from airflow.models import BaseOperator
+from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType
+from airflow.providers.google.cloud.hooks.dataflow import (
+ DataflowHook,
+ process_line_and_extract_dataflow_job_id_callback,
+)
+from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.cloud.operators.dataflow import CheckJobRunning, DataflowConfiguration
+from airflow.utils.decorators import apply_defaults
+from airflow.utils.helpers import convert_camel_to_snake
+from airflow.version import version
+
+
+class BeamRunPythonPipelineOperator(BaseOperator):
+ """
+ Launches Apache Beam pipelines written in Python. Note that both
+ ``default_pipeline_options`` and ``pipeline_options`` will be merged to specify pipeline
+ execution parameters, and ``default_pipeline_options`` is expected to store
+ high-level options, for instance, project and zone information, which
+ apply to all beam operators in the DAG.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:BeamRunPythonPipelineOperator`
+
+ .. seealso::
+ For more detail on Apache Beam have a look at the reference:
+ https://beam.apache.org/documentation/
+
+ :param py_file: Reference to the python Apache Beam pipeline file.py, e.g.,
+ /some/local/file/path/to/your/python/pipeline/file. (templated)
+ :type py_file: str
+ :param runner: Runner on which the pipeline will be run. By default "DirectRunner" is used.
+ Other possible options: DataflowRunner, SparkRunner, FlinkRunner.
+ See: :class:`~providers.apache.beam.hooks.beam.BeamRunnerType`
+ See: https://beam.apache.org/documentation/runners/capability-matrix/
+
+ If you use the Dataflow runner, check the dedicated operator:
+ :class:`~providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator`
+ :type runner: str
+ :param py_options: Additional python options, e.g., ["-m", "-v"].
+ :type py_options: list[str]
+ :param default_pipeline_options: Map of default pipeline options.
+ :type default_pipeline_options: dict
+ :param pipeline_options: Map of pipeline options. It must be a dictionary whose keys are option names.
+ The value can contain different types:
+
+ * If the value is None, the single option - ``--key`` (without value) will be added.
+ * If the value is False, this option will be skipped
+ * If the value is True, the single option - ``--key`` (without value) will be added.
+ * If the value is a list, the option will be repeated for each item in the list.
+ If the value is ``['A', 'B']`` and the key is ``key`` then the ``--key=A --key=B`` options
+ will be passed.
+ * Other value types will be replaced with the Python textual representation.
+
+ When defining labels (``labels`` option), you can also provide a dictionary.
+ :type pipeline_options: dict
+ :param py_interpreter: Python version of the beam pipeline.
+ If None, this defaults to python3.
+ To track python versions supported by beam and related
+ issues check: https://issues.apache.org/jira/browse/BEAM-1251
+ :type py_interpreter: str
+ :param py_requirements: Additional python package(s) to install.
+ If a value is passed to this parameter, a new virtual environment will be created with
+ additional packages installed.
+
+ You could also install the apache_beam package if it is not installed on your system or you want
+ to use a different version.
+ :type py_requirements: List[str]
+ :param py_system_site_packages: Whether to include system_site_packages in your virtualenv.
+ See virtualenv documentation for more information.
+
+ This option is only relevant if the ``py_requirements`` parameter is not None.
+ :param gcp_conn_id: Optional.
+ The connection ID to use when connecting to Google Cloud Storage if the Python file is on GCS.
+ :type gcp_conn_id: str
+ :param delegate_to: Optional.
+ The account to impersonate using domain-wide delegation of authority,
+ if any. For this to work, the service account making the request must have
+ domain-wide delegation enabled.
+ :type delegate_to: str
+ :param dataflow_config: Dataflow configuration, used when runner type is set to DataflowRunner
+ :type dataflow_config: Union[dict, providers.google.cloud.operators.dataflow.DataflowConfiguration]
+ """
+
+ template_fields = ["py_file", "runner", "pipeline_options", "default_pipeline_options", "dataflow_config"]
+ template_fields_renderers = {'dataflow_config': 'json', 'pipeline_options': 'json'}
+
+ @apply_defaults
+ def __init__(
+ self,
+ *,
+ py_file: str,
+ runner: str = "DirectRunner",
+ default_pipeline_options: Optional[dict] = None,
+ pipeline_options: Optional[dict] = None,
+ py_interpreter: str = "python3",
+ py_options: Optional[List[str]] = None,
+ py_requirements: Optional[List[str]] = None,
+ py_system_site_packages: bool = False,
+ gcp_conn_id: str = "google_cloud_default",
+ delegate_to: Optional[str] = None,
+ dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.py_file = py_file
+ self.runner = runner
+ self.py_options = py_options or []
+ self.default_pipeline_options = default_pipeline_options or {}
+ self.pipeline_options = pipeline_options or {}
+ self.pipeline_options.setdefault("labels", {}).update(
+ {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
+ )
+ self.py_interpreter = py_interpreter
+ self.py_requirements = py_requirements
+ self.py_system_site_packages = py_system_site_packages
+ self.gcp_conn_id = gcp_conn_id
+ self.delegate_to = delegate_to
+ self.dataflow_config = dataflow_config or {}
+ self.beam_hook: Optional[BeamHook] = None
+ self.dataflow_hook: Optional[DataflowHook] = None
+ self.dataflow_job_id: Optional[str] = None
+
+ if self.dataflow_config and self.runner.lower() != BeamRunnerType.DataflowRunner.lower():
+ self.log.warning(
+ "dataflow_config is defined but runner is different than DataflowRunner (%s)", self.runner
+ )
+
+ def execute(self, context):
+ """Execute the Apache Beam Pipeline."""
+ self.beam_hook = BeamHook(runner=self.runner)
+ pipeline_options = self.default_pipeline_options.copy()
+ process_line_callback: Optional[Callable] = None
+ is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()
+
+ if isinstance(self.dataflow_config, dict):
+ self.dataflow_config = DataflowConfiguration(**self.dataflow_config)
+
+ if is_dataflow:
+ self.dataflow_hook = DataflowHook(
+ gcp_conn_id=self.dataflow_config.gcp_conn_id or self.gcp_conn_id,
+ delegate_to=self.dataflow_config.delegate_to or self.delegate_to,
+ poll_sleep=self.dataflow_config.poll_sleep,
+ impersonation_chain=self.dataflow_config.impersonation_chain,
+ drain_pipeline=self.dataflow_config.drain_pipeline,
+ cancel_timeout=self.dataflow_config.cancel_timeout,
+ wait_until_finished=self.dataflow_config.wait_until_finished,
+ )
+ self.dataflow_config.project_id = self.dataflow_config.project_id or self.dataflow_hook.project_id
+
+ dataflow_job_name = DataflowHook.build_dataflow_job_name(
+ self.dataflow_config.job_name, self.dataflow_config.append_job_name
+ )
+ pipeline_options["job_name"] = dataflow_job_name
+ pipeline_options["project"] = self.dataflow_config.project_id
+ pipeline_options["region"] = self.dataflow_config.location
+ pipeline_options.setdefault("labels", {}).update(
+ {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
+ )
+
+ def set_current_dataflow_job_id(job_id):
+ self.dataflow_job_id = job_id
+
+ process_line_callback = process_line_and_extract_dataflow_job_id_callback(
+ on_new_job_id_callback=set_current_dataflow_job_id
+ )
+
+ pipeline_options.update(self.pipeline_options)
+
+ # Convert argument names from lowerCamelCase to snake case.
+ formatted_pipeline_options = {
+ convert_camel_to_snake(key): pipeline_options[key] for key in pipeline_options
+ }
+
+ with ExitStack() as exit_stack:
+ if self.py_file.lower().startswith("gs://"):
+ gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
+ tmp_gcs_file = exit_stack.enter_context( # pylint: disable=no-member
+ gcs_hook.provide_file(object_url=self.py_file)
+ )
+ self.py_file = tmp_gcs_file.name
+
+ self.beam_hook.start_python_pipeline(
+ variables=formatted_pipeline_options,
+ py_file=self.py_file,
+ py_options=self.py_options,
+ py_interpreter=self.py_interpreter,
+ py_requirements=self.py_requirements,
+ py_system_site_packages=self.py_system_site_packages,
+ process_line_callback=process_line_callback,
+ )
+
+ if is_dataflow:
+ self.dataflow_hook.wait_for_done( # pylint: disable=no-value-for-parameter
+ job_name=dataflow_job_name,
+ location=self.dataflow_config.location,
+ job_id=self.dataflow_job_id,
+ multiple_jobs=False,
+ )
+
+ return {"dataflow_job_id": self.dataflow_job_id}
+
+ def on_kill(self) -> None:
+ if self.dataflow_hook and self.dataflow_job_id:
+ self.log.info('Dataflow job with id: `%s` was requested to be cancelled.', self.dataflow_job_id)
+ self.dataflow_hook.cancel_job(
+ job_id=self.dataflow_job_id,
+ project_id=self.dataflow_config.project_id,
+ )
+
+
+# pylint: disable=too-many-instance-attributes
+class BeamRunJavaPipelineOperator(BaseOperator):
+ """
+ Launches Apache Beam pipelines written in Java.
+
+ Note that both
+ ``default_pipeline_options`` and ``pipeline_options`` will be merged to specify pipeline
+ execution parameters, and ``default_pipeline_options`` is expected to store
+ high-level pipeline_options, for instance, project and zone information, which
+ apply to all Apache Beam operators in the DAG.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:BeamRunJavaPipelineOperator`
+
+ .. seealso::
+ For more detail on Apache Beam have a look at the reference:
+ https://beam.apache.org/documentation/
+
+ You need to pass the path to your jar file as a file reference with the ``jar``
+ parameter; the jar needs to be a self-executing jar (see documentation here:
+ https://beam.apache.org/documentation/runners/dataflow/#self-executing-jar).
+ Use ``pipeline_options`` to pass pipeline_options on to your job.
+
+ :param jar: The reference to a self executing Apache Beam jar (templated).
+ :type jar: str
+ :param runner: Runner on which the pipeline will be run. By default "DirectRunner" is used.
+ See:
+ https://beam.apache.org/documentation/runners/capability-matrix/
+ If you use the Dataflow runner, check the dedicated operator:
+ :class:`~providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator`
+ :type runner: str
+ :param job_class: The name of the Apache Beam pipeline class to be executed; it
+ is often not the main class configured in the pipeline jar file.
+ :type job_class: str
+ :param default_pipeline_options: Map of default job pipeline_options.
+ :type default_pipeline_options: dict
+ :param pipeline_options: Map of job-specific pipeline_options. It must be a dictionary whose keys are option names.
+ The value can contain different types:
+
+ * If the value is None, the single option - ``--key`` (without value) will be added.
+ * If the value is False, this option will be skipped
+ * If the value is True, the single option - ``--key`` (without value) will be added.
+ * If the value is a list, the option will be repeated for each item in the list.
+ If the value is ``['A', 'B']`` and the key is ``key`` then the ``--key=A --key=B`` options
+ will be passed.
+ * Other value types will be replaced with the Python textual representation.
+
+ When defining labels (``labels`` option), you can also provide a dictionary.
+ :type pipeline_options: dict
+ :param gcp_conn_id: The connection ID to use when connecting to Google Cloud Storage if the jar is on GCS
+ :type gcp_conn_id: str
+ :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+ if any. For this to work, the service account making the request must have
+ domain-wide delegation enabled.
+ :type delegate_to: str
+ :param dataflow_config: Dataflow configuration, used when runner type is set to DataflowRunner
+ :type dataflow_config: Union[dict, providers.google.cloud.operators.dataflow.DataflowConfiguration]
+ """
+
+ template_fields = [
+ "jar",
+ "runner",
+ "job_class",
+ "pipeline_options",
+ "default_pipeline_options",
+ "dataflow_config",
+ ]
+ template_fields_renderers = {'dataflow_config': 'json', 'pipeline_options': 'json'}
+ ui_color = "#0273d4"
+
+ @apply_defaults
+ def __init__(
+ self,
+ *,
+ jar: str,
+ runner: str = "DirectRunner",
+ job_class: Optional[str] = None,
+ default_pipeline_options: Optional[dict] = None,
+ pipeline_options: Optional[dict] = None,
+ gcp_conn_id: str = "google_cloud_default",
+ delegate_to: Optional[str] = None,
+ dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.jar = jar
+ self.runner = runner
+ self.default_pipeline_options = default_pipeline_options or {}
+ self.pipeline_options = pipeline_options or {}
+ self.job_class = job_class
+ self.dataflow_config = dataflow_config or {}
+ self.gcp_conn_id = gcp_conn_id
+ self.delegate_to = delegate_to
+ self.dataflow_job_id = None
+ self.dataflow_hook: Optional[DataflowHook] = None
+ self.beam_hook: Optional[BeamHook] = None
+ self._dataflow_job_name: Optional[str] = None
+
+ if self.dataflow_config and self.runner.lower() != BeamRunnerType.DataflowRunner.lower():
+ self.log.warning(
+ "dataflow_config is defined but runner is different than DataflowRunner (%s)", self.runner
+ )
+
+ def execute(self, context):
+ """Execute the Apache Beam Pipeline."""
+ self.beam_hook = BeamHook(runner=self.runner)
+ pipeline_options = self.default_pipeline_options.copy()
+ process_line_callback: Optional[Callable] = None
+ is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()
+
+ if isinstance(self.dataflow_config, dict):
+ self.dataflow_config = DataflowConfiguration(**self.dataflow_config)
+
+ if is_dataflow:
+ self.dataflow_hook = DataflowHook(
+ gcp_conn_id=self.dataflow_config.gcp_conn_id or self.gcp_conn_id,
+ delegate_to=self.dataflow_config.delegate_to or self.delegate_to,
+ poll_sleep=self.dataflow_config.poll_sleep,
+ impersonation_chain=self.dataflow_config.impersonation_chain,
+ drain_pipeline=self.dataflow_config.drain_pipeline,
+ cancel_timeout=self.dataflow_config.cancel_timeout,
+ wait_until_finished=self.dataflow_config.wait_until_finished,
+ )
+ self.dataflow_config.project_id = self.dataflow_config.project_id or self.dataflow_hook.project_id
+
+ self._dataflow_job_name = DataflowHook.build_dataflow_job_name(
+ self.dataflow_config.job_name, self.dataflow_config.append_job_name
+ )
+ pipeline_options["jobName"] = self.dataflow_config.job_name
+ pipeline_options["project"] = self.dataflow_config.project_id
+ pipeline_options["region"] = self.dataflow_config.location
+ pipeline_options.setdefault("labels", {}).update(
+ {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
+ )
+
+ def set_current_dataflow_job_id(job_id):
+ self.dataflow_job_id = job_id
+
+ process_line_callback = process_line_and_extract_dataflow_job_id_callback(
+ on_new_job_id_callback=set_current_dataflow_job_id
+ )
+
+ pipeline_options.update(self.pipeline_options)
+
+ with ExitStack() as exit_stack:
+ if self.jar.lower().startswith("gs://"):
+ gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
+ tmp_gcs_file = exit_stack.enter_context( # pylint: disable=no-member
+ gcs_hook.provide_file(object_url=self.jar)
+ )
+ self.jar = tmp_gcs_file.name
+
+ if is_dataflow:
+ is_running = False
+ if self.dataflow_config.check_if_running != CheckJobRunning.IgnoreJob:
+ is_running = (
+ # The reason for disable=no-value-for-parameter is that the project_id parameter is
+ # required but is not passed here; moreover, it cannot be passed here.
+ # This method is wrapped by the @_fallback_to_project_id_from_variables decorator, which
+ # falls back to the project_id value from variables and raises an error if project_id is
+ # defined both in variables and as a parameter (here it is already defined in variables).
+ self.dataflow_hook.is_job_dataflow_running( # pylint: disable=no-value-for-parameter
+ name=self.dataflow_config.job_name,
+ variables=pipeline_options,
+ )
+ )
+ while is_running and self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun:
+ # The reason for disable=no-value-for-parameter is that the project_id parameter is
+ # required but is not passed here; moreover, it cannot be passed here.
+ # This method is wrapped by the @_fallback_to_project_id_from_variables decorator, which
+ # falls back to the project_id value from variables and raises an error if project_id is
+ # defined both in variables and as a parameter (here it is already defined in variables).
+ # pylint: disable=no-value-for-parameter
+ is_running = self.dataflow_hook.is_job_dataflow_running(
+ name=self.dataflow_config.job_name,
+ variables=pipeline_options,
+ )
+ if not is_running:
+ pipeline_options["jobName"] = self._dataflow_job_name
+ self.beam_hook.start_java_pipeline(
+ variables=pipeline_options,
+ jar=self.jar,
+ job_class=self.job_class,
+ process_line_callback=process_line_callback,
+ )
+ self.dataflow_hook.wait_for_done(
+ job_name=self._dataflow_job_name,
+ location=self.dataflow_config.location,
+ job_id=self.dataflow_job_id,
+ multiple_jobs=self.dataflow_config.multiple_jobs,
+ project_id=self.dataflow_config.project_id,
+ )
+
+ else:
+ self.beam_hook.start_java_pipeline(
+ variables=pipeline_options,
+ jar=self.jar,
+ job_class=self.job_class,
+ process_line_callback=process_line_callback,
+ )
+
+ return {"dataflow_job_id": self.dataflow_job_id}
+
+ def on_kill(self) -> None:
+ if self.dataflow_hook and self.dataflow_job_id:
+ self.log.info('Dataflow job with id: `%s` was requested to be cancelled.', self.dataflow_job_id)
+ self.dataflow_hook.cancel_job(
+ job_id=self.dataflow_job_id,
+ project_id=self.dataflow_config.project_id,
+ )
diff --git a/airflow/providers/apache/beam/provider.yaml b/airflow/providers/apache/beam/provider.yaml
new file mode 100644
index 0000000..4325265
--- /dev/null
+++ b/airflow/providers/apache/beam/provider.yaml
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+---
+package-name: apache-airflow-providers-apache-beam
+name: Apache Beam
+description: |
+ `Apache Beam <https://beam.apache.org/>`__.
+
+versions:
+ - 0.0.1
+
+integrations:
+ - integration-name: Apache Beam
+ external-doc-url: https://beam.apache.org/
+ how-to-guide:
+ - /docs/apache-airflow-providers-apache-beam/operators.rst
+ tags: [apache]
+
+operators:
+ - integration-name: Apache Beam
+ python-modules:
+ - airflow.providers.apache.beam.operators.beam
+
+hooks:
+ - integration-name: Apache Beam
+ python-modules:
+ - airflow.providers.apache.beam.hooks.beam
+
+hook-class-names:
+ - airflow.providers.apache.beam.hooks.beam.BeamHook
diff --git a/airflow/providers/dependencies.json b/airflow/providers/dependencies.json
index 748b1a5..836020c 100644
--- a/airflow/providers/dependencies.json
+++ b/airflow/providers/dependencies.json
@@ -8,6 +8,9 @@
"postgres",
"ssh"
],
+ "apache.beam": [
+ "google"
+ ],
"apache.druid": [
"apache.hive"
],
@@ -30,6 +33,7 @@
],
"google": [
"amazon",
+ "apache.beam",
"apache.cassandra",
"cncf.kubernetes",
"facebook",
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index 0a665d4..0ad0262 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -19,23 +19,20 @@
import functools
import json
import re
-import select
import shlex
import subprocess
-import textwrap
import time
import uuid
import warnings
from copy import deepcopy
-from tempfile import TemporaryDirectory
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Set, TypeVar, Union, cast
from googleapiclient.discovery import build
from airflow.exceptions import AirflowException
+from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.python_virtualenv import prepare_virtualenv
from airflow.utils.timeout import timeout
# This is the default location
@@ -50,6 +47,35 @@ JOB_ID_PATTERN = re.compile(
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
+def process_line_and_extract_dataflow_job_id_callback(
+ on_new_job_id_callback: Optional[Callable[[str], None]]
+) -> Callable[[str], None]:
+ """
+ Returns a callback which triggers the function passed as `on_new_job_id_callback` when a Dataflow job ID is found.
+ To be used for `process_line_callback` in
+ :py:class:`~airflow.providers.apache.beam.hooks.beam.BeamCommandRunner`
+
+ :param on_new_job_id_callback: Callback called when the job ID is known
+ :type on_new_job_id_callback: callback
+ """
+
+ def _process_line_and_extract_job_id(
+ line: str,
+ ) -> None:
+ # Job id info: https://goo.gl/SE29y9.
+ matched_job = JOB_ID_PATTERN.search(line)
+ if matched_job:
+ job_id = matched_job.group("job_id_java") or matched_job.group("job_id_python")
+ if on_new_job_id_callback:
+ on_new_job_id_callback(job_id)
+
+ def wrap(line: str):
+ return _process_line_and_extract_job_id(line)
+
+ return wrap
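+
+ # A usage sketch (names below are illustrative): the returned callback is passed as
+ # ``process_line_callback`` to ``BeamHook.start_java_pipeline`` or ``start_python_pipeline``, e.g.:
+ #
+ # callback = process_line_and_extract_dataflow_job_id_callback(
+ # on_new_job_id_callback=lambda job_id: log.info("Found Dataflow job id: %s", job_id)
+ # )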
+
+
def _fallback_variable_parameter(parameter_name: str, variable_key_name: str) -> Callable[[T], T]:
def _wrapper(func: T) -> T:
"""
@@ -484,98 +510,6 @@ class _DataflowJobsController(LoggingMixin):
self.log.info("No jobs to cancel")
-class _DataflowRunner(LoggingMixin):
- def __init__(
- self,
- cmd: List[str],
- on_new_job_id_callback: Optional[Callable[[str], None]] = None,
- ) -> None:
- super().__init__()
- self.log.info("Running command: %s", " ".join(shlex.quote(c) for c in cmd))
- self.on_new_job_id_callback = on_new_job_id_callback
- self.job_id: Optional[str] = None
- self._proc = subprocess.Popen(
- cmd,
- shell=False,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- close_fds=True,
- )
-
- def _process_fd(self, fd):
- """
- Prints output to logs and lookup for job ID in each line.
-
- :param fd: File descriptor.
- """
- if fd == self._proc.stderr:
- while True:
- line = self._proc.stderr.readline().decode()
- if not line:
- return
- self._process_line_and_extract_job_id(line)
- self.log.warning(line.rstrip("\n"))
-
- if fd == self._proc.stdout:
- while True:
- line = self._proc.stdout.readline().decode()
- if not line:
- return
- self._process_line_and_extract_job_id(line)
- self.log.info(line.rstrip("\n"))
-
- raise Exception("No data in stderr or in stdout.")
-
- def _process_line_and_extract_job_id(self, line: str) -> None:
- """
- Extracts job_id.
-
- :param line: URL from which job_id has to be extracted
- :type line: str
- """
- # Job id info: https://goo.gl/SE29y9.
- matched_job = JOB_ID_PATTERN.search(line)
- if matched_job:
- job_id = matched_job.group("job_id_java") or matched_job.group("job_id_python")
- self.log.info("Found Job ID: %s", job_id)
- self.job_id = job_id
- if self.on_new_job_id_callback:
- self.on_new_job_id_callback(job_id)
-
- def wait_for_done(self) -> Optional[str]:
- """
- Waits for Dataflow job to complete.
-
- :return: Job id
- :rtype: Optional[str]
- """
- self.log.info("Start waiting for DataFlow process to complete.")
- self.job_id = None
- reads = [self._proc.stderr, self._proc.stdout]
- while True:
- # Wait for at least one available fd.
- readable_fds, _, _ = select.select(reads, [], [], 5)
- if readable_fds is None:
- self.log.info("Waiting for DataFlow process to complete.")
- continue
-
- for readable_fd in readable_fds:
- self._process_fd(readable_fd)
-
- if self._proc.poll() is not None:
- break
-
- # Corner case: check if more output was created between the last read and the process termination
- for readable_fd in reads:
- self._process_fd(readable_fd)
-
- self.log.info("Process exited with return code: %s", self._proc.returncode)
-
- if self._proc.returncode != 0:
- raise Exception(f"DataFlow failed with return code {self._proc.returncode}")
- return self.job_id
-
-
class DataflowHook(GoogleBaseHook):
"""
Hook for Google Dataflow.
@@ -598,6 +532,8 @@ class DataflowHook(GoogleBaseHook):
self.drain_pipeline = drain_pipeline
self.cancel_timeout = cancel_timeout
self.wait_until_finished = wait_until_finished
+ self.job_id: Optional[str] = None
+ self.beam_hook = BeamHook(BeamRunnerType.DataflowRunner)
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
@@ -609,40 +545,6 @@ class DataflowHook(GoogleBaseHook):
http_authorized = self._authorize()
return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False)
- @GoogleBaseHook.provide_gcp_credential_file
- def _start_dataflow(
- self,
- variables: dict,
- name: str,
- command_prefix: List[str],
- project_id: str,
- multiple_jobs: bool = False,
- on_new_job_id_callback: Optional[Callable[[str], None]] = None,
- location: str = DEFAULT_DATAFLOW_LOCATION,
- ) -> None:
- cmd = command_prefix + [
- "--runner=DataflowRunner",
- f"--project={project_id}",
- ]
- if variables:
- cmd.extend(self._options_to_args(variables))
- runner = _DataflowRunner(cmd=cmd, on_new_job_id_callback=on_new_job_id_callback)
- job_id = runner.wait_for_done()
- job_controller = _DataflowJobsController(
- dataflow=self.get_conn(),
- project_number=project_id,
- name=name,
- location=location,
- poll_sleep=self.poll_sleep,
- job_id=job_id,
- num_retries=self.num_retries,
- multiple_jobs=multiple_jobs,
- drain_pipeline=self.drain_pipeline,
- cancel_timeout=self.cancel_timeout,
- wait_until_finished=self.wait_until_finished,
- )
- job_controller.wait_for_done()
-
@_fallback_to_location_from_variables
@_fallback_to_project_id_from_variables
@GoogleBaseHook.fallback_to_default_project_id
@@ -680,22 +582,36 @@ class DataflowHook(GoogleBaseHook):
:param location: Job location.
:type location: str
"""
- name = self._build_dataflow_job_name(job_name, append_job_name)
+ warnings.warn(
+ """"This method is deprecated.
+ Please use `airflow.providers.apache.beam.hooks.beam.start.start_java_pipeline`
+ to start pipeline and `providers.google.cloud.hooks.dataflow.DataflowHook.wait_for_done`
+ to wait for the required pipeline state.
+ """,
+ DeprecationWarning,
+ stacklevel=3,
+ )
+
+ name = self.build_dataflow_job_name(job_name, append_job_name)
+
variables["jobName"] = name
variables["region"] = location
+ variables["project"] = project_id
if "labels" in variables:
variables["labels"] = json.dumps(variables["labels"], separators=(",", ":"))
- command_prefix = ["java", "-cp", jar, job_class] if job_class else ["java", "-jar", jar]
- self._start_dataflow(
+ self.beam_hook.start_java_pipeline(
variables=variables,
- name=name,
- command_prefix=command_prefix,
- project_id=project_id,
- multiple_jobs=multiple_jobs,
- on_new_job_id_callback=on_new_job_id_callback,
+ jar=jar,
+ job_class=job_class,
+ process_line_callback=process_line_and_extract_dataflow_job_id_callback(on_new_job_id_callback),
+ )
+ self.wait_for_done( # pylint: disable=no-value-for-parameter
+ job_name=name,
location=location,
+ job_id=self.job_id,
+ multiple_jobs=multiple_jobs,
)
@_fallback_to_location_from_variables
@@ -748,7 +664,7 @@ class DataflowHook(GoogleBaseHook):
:type environment: Optional[dict]
"""
- name = self._build_dataflow_job_name(job_name, append_job_name)
+ name = self.build_dataflow_job_name(job_name, append_job_name)
environment = environment or {}
# available keys for runtime environment are listed here:
@@ -921,58 +837,40 @@ class DataflowHook(GoogleBaseHook):
:param location: Job location.
:type location: str
"""
- name = self._build_dataflow_job_name(job_name, append_job_name)
+ warnings.warn(
+ """This method is deprecated.
+ Please use `airflow.providers.apache.beam.hooks.beam.BeamHook.start_python_pipeline`
+ to start pipeline and `providers.google.cloud.hooks.dataflow.DataflowHook.wait_for_done`
+ to wait for the required pipeline state.
+ """,
+ DeprecationWarning,
+ stacklevel=3,
+ )
+
+ name = self.build_dataflow_job_name(job_name, append_job_name)
variables["job_name"] = name
variables["region"] = location
+ variables["project"] = project_id
- if "labels" in variables:
- variables["labels"] = [f"{key}={value}" for key, value in variables["labels"].items()]
-
- if py_requirements is not None:
- if not py_requirements and not py_system_site_packages:
- warning_invalid_environment = textwrap.dedent(
- """\
- Invalid method invocation. You have disabled inclusion of system packages and empty list
- required for installation, so it is not possible to create a valid virtual environment.
- In the virtual environment, apache-beam package must be installed for your job to be \
- executed. To fix this problem:
- * install apache-beam on the system, then set parameter py_system_site_packages to True,
- * add apache-beam to the list of required packages in parameter py_requirements.
- """
- )
- raise AirflowException(warning_invalid_environment)
-
- with TemporaryDirectory(prefix="dataflow-venv") as tmp_dir:
- py_interpreter = prepare_virtualenv(
- venv_directory=tmp_dir,
- python_bin=py_interpreter,
- system_site_packages=py_system_site_packages,
- requirements=py_requirements,
- )
- command_prefix = [py_interpreter] + py_options + [dataflow]
-
- self._start_dataflow(
- variables=variables,
- name=name,
- command_prefix=command_prefix,
- project_id=project_id,
- on_new_job_id_callback=on_new_job_id_callback,
- location=location,
- )
- else:
- command_prefix = [py_interpreter] + py_options + [dataflow]
-
- self._start_dataflow(
- variables=variables,
- name=name,
- command_prefix=command_prefix,
- project_id=project_id,
- on_new_job_id_callback=on_new_job_id_callback,
- location=location,
- )
+ self.beam_hook.start_python_pipeline(
+ variables=variables,
+ py_file=dataflow,
+ py_options=py_options,
+ py_interpreter=py_interpreter,
+ py_requirements=py_requirements,
+ py_system_site_packages=py_system_site_packages,
+ process_line_callback=process_line_and_extract_dataflow_job_id_callback(on_new_job_id_callback),
+ )
+
+ self.wait_for_done( # pylint: disable=no-value-for-parameter
+ job_name=name,
+ location=location,
+ job_id=self.job_id,
+ )
@staticmethod
- def _build_dataflow_job_name(job_name: str, append_job_name: bool = True) -> str:
+ def build_dataflow_job_name(job_name: str, append_job_name: bool = True) -> str:
+ """Builds Dataflow job name."""
base_job_name = str(job_name).replace("_", "-")
if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", base_job_name):
@@ -989,23 +887,6 @@ class DataflowHook(GoogleBaseHook):
return safe_job_name
- @staticmethod
- def _options_to_args(variables: dict) -> List[str]:
- if not variables:
- return []
- # The logic of this method should be compatible with Apache Beam:
- # https://github.com/apache/beam/blob/b56740f0e8cd80c2873412847d0b336837429fb9/sdks/python/
- # apache_beam/options/pipeline_options.py#L230-L251
- args: List[str] = []
- for attr, value in variables.items():
- if value is None or (isinstance(value, bool) and value):
- args.append(f"--{attr}")
- elif isinstance(value, list):
- args.extend([f"--{attr}={v}" for v in value])
- else:
- args.append(f"--{attr}={value}")
- return args
-
@_fallback_to_location_from_variables
@_fallback_to_project_id_from_variables
@GoogleBaseHook.fallback_to_default_project_id
@@ -1125,7 +1006,7 @@ class DataflowHook(GoogleBaseHook):
"--format=value(job.id)",
f"--job-name={job_name}",
f"--region={location}",
- *(self._options_to_args(options)),
+ *(beam_options_to_args(options)),
]
self.log.info("Executing command: %s", " ".join([shlex.quote(c) for c in cmd]))
with self.provide_authorized_gcloud():
@@ -1266,3 +1147,44 @@ class DataflowHook(GoogleBaseHook):
location=location,
)
return jobs_controller.fetch_job_autoscaling_events_by_id(job_id)
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def wait_for_done(
+ self,
+ job_name: str,
+ location: str,
+ project_id: str,
+ job_id: Optional[str] = None,
+ multiple_jobs: bool = False,
+ ) -> None:
+ """
+ Wait for Dataflow job.
+
+ :param job_name: The 'jobName' to use when executing the DataFlow job
+ (templated). This ends up being set in the pipeline options, so any entry
+ with key ``'jobName'`` in ``options`` will be overwritten.
+ :type job_name: str
+ :param location: location the job is running
+ :type location: str
+ :param project_id: Optional, the Google Cloud project ID in which to start a job.
+ If set to None or missing, the default project_id from the Google Cloud connection is used.
+ :type project_id: str
+ :param job_id: a Dataflow job ID
+ :type job_id: str
+ :param multiple_jobs: If pipeline creates multiple jobs then monitor all jobs
+ :type multiple_jobs: boolean
+ """
+ job_controller = _DataflowJobsController(
+ dataflow=self.get_conn(),
+ project_number=project_id,
+ name=job_name,
+ location=location,
+ poll_sleep=self.poll_sleep,
+ job_id=job_id or self.job_id,
+ num_retries=self.num_retries,
+ multiple_jobs=multiple_jobs,
+ drain_pipeline=self.drain_pipeline,
+ cancel_timeout=self.cancel_timeout,
+ wait_until_finished=self.wait_until_finished,
+ )
+ job_controller.wait_for_done()
diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index 49863dc..f977704 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -16,15 +16,20 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains Google Dataflow operators."""
-
import copy
import re
+import warnings
from contextlib import ExitStack
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Union
from airflow.models import BaseOperator
-from airflow.providers.google.cloud.hooks.dataflow import DEFAULT_DATAFLOW_LOCATION, DataflowHook
+from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType
+from airflow.providers.google.cloud.hooks.dataflow import (
+ DEFAULT_DATAFLOW_LOCATION,
+ DataflowHook,
+ process_line_and_extract_dataflow_job_id_callback,
+)
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.utils.decorators import apply_defaults
from airflow.version import version
@@ -43,12 +48,137 @@ class CheckJobRunning(Enum):
WaitForRun = 3
+class DataflowConfiguration:
+ """Dataflow configuration that can be passed to
+ :py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator` and
+ :py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator`.
+
+ :param job_name: The 'jobName' to use when executing the DataFlow job
+ (templated). This ends up being set in the pipeline options, so any entry
+ with key ``'jobName'`` or ``'job_name'`` in ``options`` will be overwritten.
+ :type job_name: str
+ :param append_job_name: True if unique suffix has to be appended to job name.
+ :type append_job_name: bool
+ :param project_id: Optional, the Google Cloud project ID in which to start a job.
+ If set to None or missing, the default project_id from the Google Cloud connection is used.
+ :type project_id: str
+ :param location: Job location.
+ :type location: str
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :type gcp_conn_id: str
+ :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+ if any. For this to work, the service account making the request must have
+ domain-wide delegation enabled.
+ :type delegate_to: str
+ :param poll_sleep: The time in seconds to sleep between polling Google
+ Cloud Platform for the dataflow job status while the job is in the
+ JOB_STATE_RUNNING state.
+ :type poll_sleep: int
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ :type impersonation_chain: Union[str, Sequence[str]]
+ :param drain_pipeline: Optional, set to True if you want to stop a streaming job by draining it
+ instead of canceling it when the task instance is killed. See:
+ https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
+ :type drain_pipeline: bool
+ :param cancel_timeout: How long (in seconds) operator should wait for the pipeline to be
+ successfully cancelled when task is being killed.
+ :type cancel_timeout: Optional[int]
+ :param wait_until_finished: (Optional)
+ If True, wait for the end of pipeline execution before exiting.
+ If False, only submits job.
+ If None, default behavior.
+
+ The default behavior depends on the type of pipeline:
+
+ * for the streaming pipeline, wait for jobs to start,
+ * for the batch pipeline, wait for the jobs to complete.
+
+ .. warning::
+
+ You cannot call the ``PipelineResult.wait_until_finish`` method in your pipeline code for the
+ operator to work properly, i.e. you must use asynchronous execution. Otherwise, your pipeline will
+ always wait until finished. For more information, look at:
+ `Asynchronous execution
+ <https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#python_10>`__
+
+ The process of starting the Dataflow job in Airflow consists of two steps:
+
+ * running a subprocess and reading the stderr/stdout logs for the job id.
+ * waiting, in a loop, for the job with the ID from the previous step to finish.
+ This loop checks the status of the job.
+
+ Step two is started just after step one has finished, so if you have ``wait_until_finish`` in your
+ pipeline code, step two will not start until the process stops. When this process stops,
+ step two will run, but it will only execute one iteration as the job will already be in a terminal state.
+
+ If your pipeline code does not call the ``wait_until_finish`` method but you pass ``wait_until_finished=True``
+ to the operator, the second loop will wait for the job's terminal state.
+
+ If your pipeline code does not call the ``wait_until_finish`` method and you pass ``wait_until_finished=False``
+ to the operator, the second loop will check once whether the job is in a terminal state and exit the loop.
+ :type wait_until_finished: Optional[bool]
+ :param multiple_jobs: If pipeline creates multiple jobs then monitor all jobs. Supported only by
+ :py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
+ :type multiple_jobs: boolean
+ :param check_if_running: Before running the job, validate that a previous run is not in progress.
+ IgnoreJob = do not check if running.
+ FinishIfRunning = if the job is running, finish the task with no further action.
+ WaitForRun = wait until the previous job finishes and then run the new job.
+ Supported only by:
+ :py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
+ :type check_if_running: CheckJobRunning
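+
+ A minimal construction sketch (the project and region values below are placeholders); the resulting
+ object is passed to a Beam operator via its ``dataflow_config`` argument:
+
+ .. code-block:: python
+
+ dataflow_config = DataflowConfiguration(
+ job_name="{{task.task_id}}",
+ project_id="my-project",
+ location="us-central1",
+ wait_until_finished=False,
+ )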
+ """
+
+ template_fields = ["job_name", "location"]
+
+ def __init__(
+ self,
+ *,
+ job_name: Optional[str] = "{{task.task_id}}",
+ append_job_name: bool = True,
+ project_id: Optional[str] = None,
+ location: Optional[str] = DEFAULT_DATAFLOW_LOCATION,
+ gcp_conn_id: str = "google_cloud_default",
+ delegate_to: Optional[str] = None,
+ poll_sleep: int = 10,
+ impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ drain_pipeline: bool = False,
+ cancel_timeout: Optional[int] = 5 * 60,
+ wait_until_finished: Optional[bool] = None,
+ multiple_jobs: Optional[bool] = None,
+ check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun,
+ ) -> None:
+ self.job_name = job_name
+ self.append_job_name = append_job_name
+ self.project_id = project_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+ self.delegate_to = delegate_to
+ self.poll_sleep = poll_sleep
+ self.impersonation_chain = impersonation_chain
+ self.drain_pipeline = drain_pipeline
+ self.cancel_timeout = cancel_timeout
+ self.wait_until_finished = wait_until_finished
+ self.multiple_jobs = multiple_jobs
+ self.check_if_running = check_if_running
+
+
# pylint: disable=too-many-instance-attributes
class DataflowCreateJavaJobOperator(BaseOperator):
"""
Start a Java Cloud DataFlow batch job. The parameters of the operation
will be passed to the job.
+ This class is deprecated.
+ Please use `providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`.
+
**Example**: ::
default_args = {
@@ -235,6 +365,14 @@ class DataflowCreateJavaJobOperator(BaseOperator):
wait_until_finished: Optional[bool] = None,
**kwargs,
) -> None:
+ # TODO: Remove one day
+ warnings.warn(
+ "The `{cls}` operator is deprecated, please use "
+ "`providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator` instead."
+ "".format(cls=self.__class__.__name__),
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__(**kwargs)
dataflow_default_options = dataflow_default_options or {}
@@ -257,62 +395,83 @@ class DataflowCreateJavaJobOperator(BaseOperator):
self.cancel_timeout = cancel_timeout
self.wait_until_finished = wait_until_finished
self.job_id = None
- self.hook = None
+ self.beam_hook: Optional[BeamHook] = None
+ self.dataflow_hook: Optional[DataflowHook] = None
def execute(self, context):
- self.hook = DataflowHook(
+ """Execute the Apache Beam Pipeline."""
+ self.beam_hook = BeamHook(runner=BeamRunnerType.DataflowRunner)
+ self.dataflow_hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
poll_sleep=self.poll_sleep,
cancel_timeout=self.cancel_timeout,
wait_until_finished=self.wait_until_finished,
)
- dataflow_options = copy.copy(self.dataflow_default_options)
- dataflow_options.update(self.options)
- is_running = False
- if self.check_if_running != CheckJobRunning.IgnoreJob:
- is_running = self.hook.is_job_dataflow_running( # type: ignore[attr-defined]
- name=self.job_name,
- variables=dataflow_options,
- project_id=self.project_id,
- location=self.location,
- )
- while is_running and self.check_if_running == CheckJobRunning.WaitForRun:
- is_running = self.hook.is_job_dataflow_running( # type: ignore[attr-defined]
- name=self.job_name,
- variables=dataflow_options,
- project_id=self.project_id,
- location=self.location,
- )
+ job_name = self.dataflow_hook.build_dataflow_job_name(job_name=self.job_name)
+ pipeline_options = copy.deepcopy(self.dataflow_default_options)
+
+ pipeline_options["jobName"] = self.job_name
+ pipeline_options["project"] = self.project_id or self.dataflow_hook.project_id
+ pipeline_options["region"] = self.location
+ pipeline_options.update(self.options)
+ pipeline_options.setdefault("labels", {}).update(
+ {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
+ )
+ pipeline_options.update(self.options)
- if not is_running:
- with ExitStack() as exit_stack:
- if self.jar.lower().startswith("gs://"):
- gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
- tmp_gcs_file = exit_stack.enter_context( # pylint: disable=no-member
- gcs_hook.provide_file(object_url=self.jar)
- )
- self.jar = tmp_gcs_file.name
-
- def set_current_job_id(job_id):
- self.job_id = job_id
-
- self.hook.start_java_dataflow( # type: ignore[attr-defined]
- job_name=self.job_name,
- variables=dataflow_options,
- jar=self.jar,
- job_class=self.job_class,
- append_job_name=True,
- multiple_jobs=self.multiple_jobs,
- on_new_job_id_callback=set_current_job_id,
- project_id=self.project_id,
- location=self.location,
+ def set_current_job_id(job_id):
+ self.job_id = job_id
+
+ process_line_callback = process_line_and_extract_dataflow_job_id_callback(
+ on_new_job_id_callback=set_current_job_id
+ )
+
+ with ExitStack() as exit_stack:
+ if self.jar.lower().startswith("gs://"):
+ gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
+ tmp_gcs_file = exit_stack.enter_context( # pylint: disable=no-member
+ gcs_hook.provide_file(object_url=self.jar)
)
+ self.jar = tmp_gcs_file.name
+
+ is_running = False
+ if self.check_if_running != CheckJobRunning.IgnoreJob:
+ is_running = (
+ self.dataflow_hook.is_job_dataflow_running( # pylint: disable=no-value-for-parameter
+ name=self.job_name,
+ variables=pipeline_options,
+ )
+ )
+ while is_running and self.check_if_running == CheckJobRunning.WaitForRun:
+ # pylint: disable=no-value-for-parameter
+ is_running = self.dataflow_hook.is_job_dataflow_running(
+ name=self.job_name,
+ variables=pipeline_options,
+ )
+ if not is_running:
+ pipeline_options["jobName"] = job_name
+ self.beam_hook.start_java_pipeline(
+ variables=pipeline_options,
+ jar=self.jar,
+ job_class=self.job_class,
+ process_line_callback=process_line_callback,
+ )
+ self.dataflow_hook.wait_for_done( # pylint: disable=no-value-for-parameter
+ job_name=job_name,
+ location=self.location,
+ job_id=self.job_id,
+ multiple_jobs=self.multiple_jobs,
+ )
+
+ return {"job_id": self.job_id}
def on_kill(self) -> None:
self.log.info("On kill.")
if self.job_id:
- self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
+ self.dataflow_hook.cancel_job(
+ job_id=self.job_id, project_id=self.project_id or self.dataflow_hook.project_id
+ )
# pylint: disable=too-many-instance-attributes
@@ -760,6 +919,9 @@ class DataflowCreatePythonJobOperator(BaseOperator):
high-level options, for instances, project and zone information, which
apply to all dataflow operators in the DAG.
+ This class is deprecated.
+ Please use `providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator`.
+
.. seealso::
For more detail on job submission have a look at the reference:
https://cloud.google.com/dataflow/pipelines/specifying-exec-params
@@ -886,7 +1048,14 @@ class DataflowCreatePythonJobOperator(BaseOperator):
wait_until_finished: Optional[bool] = None,
**kwargs,
) -> None:
-
+ # TODO: Remove one day
+ warnings.warn(
+ "The `{cls}` operator is deprecated, please use "
+ "`providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator` instead."
+ "".format(cls=self.__class__.__name__),
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__(**kwargs)
self.py_file = py_file
@@ -909,10 +1078,40 @@ class DataflowCreatePythonJobOperator(BaseOperator):
self.cancel_timeout = cancel_timeout
self.wait_until_finished = wait_until_finished
self.job_id = None
- self.hook: Optional[DataflowHook] = None
+ self.beam_hook: Optional[BeamHook] = None
+ self.dataflow_hook: Optional[DataflowHook] = None
def execute(self, context):
"""Execute the python dataflow job."""
+ self.beam_hook = BeamHook(runner=BeamRunnerType.DataflowRunner)
+ self.dataflow_hook = DataflowHook(
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ poll_sleep=self.poll_sleep,
+ impersonation_chain=None,
+ drain_pipeline=self.drain_pipeline,
+ cancel_timeout=self.cancel_timeout,
+ wait_until_finished=self.wait_until_finished,
+ )
+
+ job_name = self.dataflow_hook.build_dataflow_job_name(job_name=self.job_name)
+ pipeline_options = self.dataflow_default_options.copy()
+ pipeline_options["job_name"] = job_name
+ pipeline_options["project"] = self.project_id or self.dataflow_hook.project_id
+ pipeline_options["region"] = self.location
+ pipeline_options.update(self.options)
+
+ # Convert argument names from lowerCamelCase to snake case.
+ camel_to_snake = lambda name: re.sub(r"[A-Z]", lambda x: "_" + x.group(0).lower(), name)
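+ # e.g. (illustrative) "jobName" -> "job_name", "tempLocation" -> "temp_location"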
+ formatted_pipeline_options = {camel_to_snake(key): pipeline_options[key] for key in pipeline_options}
+
+ def set_current_job_id(job_id):
+ self.job_id = job_id
+
+ process_line_callback = process_line_and_extract_dataflow_job_id_callback(
+ on_new_job_id_callback=set_current_job_id
+ )
+
with ExitStack() as exit_stack:
if self.py_file.lower().startswith("gs://"):
gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
@@ -921,38 +1120,28 @@ class DataflowCreatePythonJobOperator(BaseOperator):
)
self.py_file = tmp_gcs_file.name
- self.hook = DataflowHook(
- gcp_conn_id=self.gcp_conn_id,
- delegate_to=self.delegate_to,
- poll_sleep=self.poll_sleep,
- drain_pipeline=self.drain_pipeline,
- cancel_timeout=self.cancel_timeout,
- wait_until_finished=self.wait_until_finished,
- )
- dataflow_options = self.dataflow_default_options.copy()
- dataflow_options.update(self.options)
- # Convert argument names from lowerCamelCase to snake case.
- camel_to_snake = lambda name: re.sub(r"[A-Z]", lambda x: "_" + x.group(0).lower(), name)
- formatted_options = {camel_to_snake(key): dataflow_options[key] for key in dataflow_options}
-
- def set_current_job_id(job_id):
- self.job_id = job_id
-
- self.hook.start_python_dataflow( # type: ignore[attr-defined]
- job_name=self.job_name,
- variables=formatted_options,
- dataflow=self.py_file,
+ self.beam_hook.start_python_pipeline(
+ variables=formatted_pipeline_options,
+ py_file=self.py_file,
py_options=self.py_options,
py_interpreter=self.py_interpreter,
py_requirements=self.py_requirements,
py_system_site_packages=self.py_system_site_packages,
- on_new_job_id_callback=set_current_job_id,
- project_id=self.project_id,
+ process_line_callback=process_line_callback,
+ )
+
+ self.dataflow_hook.wait_for_done( # pylint: disable=no-value-for-parameter
+ job_name=job_name,
location=self.location,
+ job_id=self.job_id,
+ multiple_jobs=False,
)
- return {"job_id": self.job_id}
+
+ return {"job_id": self.job_id}
def on_kill(self) -> None:
self.log.info("On kill.")
if self.job_id:
- self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
+ self.dataflow_hook.cancel_job(
+ job_id=self.job_id, project_id=self.project_id or self.dataflow_hook.project_id
+ )
diff --git a/dev/provider_packages/copy_provider_package_sources.py b/dev/provider_packages/copy_provider_package_sources.py
index 1d10747..c7f75f5 100755
--- a/dev/provider_packages/copy_provider_package_sources.py
+++ b/dev/provider_packages/copy_provider_package_sources.py
@@ -703,6 +703,67 @@ class RefactorBackportPackages:
.rename("airflow.models.baseoperator")
)
+ def refactor_apache_beam_package(self):
+ r"""
+ Fixes to "apache_beam" providers package.
+
+ Copies some of the classes used from core Airflow to "common.utils" package of the
+ the provider and renames imports to use them from there. Note that in this case we also rename
+ the imports in the copied files.
+
+ For example we copy python_virtualenv.py, process_utils.py and change import as in example diff:
+
+ .. code-block:: diff
+
+ --- ./airflow/providers/apache/beam/common/utils/python_virtualenv.py
+ +++ ./airflow/providers/apache/beam/common/utils/python_virtualenv.py
+ @@ -21,7 +21,7 @@
+ \"\"\"
+ from typing import List, Optional
+
+ -from airflow.utils.process_utils import execute_in_subprocess
+ +from airflow.providers.apache.beam.common.utils.process_utils import execute_in_subprocess
+
+
+ def _generate_virtualenv_cmd(tmp_dir: str, python_bin: str, system_site_packages: bool)
+
+ """
+
+ def apache_beam_package_filter(node: LN, capture: Capture, filename: Filename) -> bool:
+ return filename.startswith("./airflow/providers/apache/beam")
+
+ os.makedirs(
+ os.path.join(get_target_providers_package_folder("apache.beam"), "common", "utils"), exist_ok=True
+ )
+ copyfile(
+ os.path.join(get_source_airflow_folder(), "airflow", "utils", "__init__.py"),
+ os.path.join(
+ get_target_providers_package_folder("apache.beam"), "common", "utils", "__init__.py"
+ ),
+ )
+ copyfile(
+ os.path.join(get_source_airflow_folder(), "airflow", "utils", "python_virtualenv.py"),
+ os.path.join(
+ get_target_providers_package_folder("apache.beam"), "common", "utils", "python_virtualenv.py"
+ ),
+ )
+ copyfile(
+ os.path.join(get_source_airflow_folder(), "airflow", "utils", "process_utils.py"),
+ os.path.join(
+ get_target_providers_package_folder("apache.beam"), "common", "utils", "process_utils.py"
+ ),
+ )
+ (
+ self.qry.select_module("airflow.utils.python_virtualenv")
+ .filter(callback=apache_beam_package_filter)
+ .rename("airflow.providers.apache.beam.common.utils.python_virtualenv")
+ )
+ (
+ self.qry.select_module("airflow.utils.process_utils")
+ .filter(callback=apache_beam_package_filter)
+ .rename("airflow.providers.apache.beam.common.utils.process_utils")
+ )
+
def refactor_odbc_package(self):
"""
Fixes to "odbc" providers package.
@@ -760,6 +821,7 @@ class RefactorBackportPackages:
self.rename_deprecated_modules()
self.refactor_amazon_package()
self.refactor_google_package()
+ self.refactor_apache_beam_package()
self.refactor_elasticsearch_package()
self.refactor_odbc_package()
self.remove_tags()
diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py
index 322a57f..3cfc39f 100755
--- a/dev/provider_packages/prepare_provider_packages.py
+++ b/dev/provider_packages/prepare_provider_packages.py
@@ -790,8 +790,10 @@ def convert_git_changes_to_table(
f"`{message_without_backticks}`" if markdown else f"``{message_without_backticks}``",
)
)
- table = tabulate(table_data, headers=headers, tablefmt="pipe" if markdown else "rst")
header = ""
+ if not table_data:
+ return header
+ table = tabulate(table_data, headers=headers, tablefmt="pipe" if markdown else "rst")
if not markdown:
header += f"\n\n{print_version}\n" + "." * len(print_version) + "\n\n"
release_date = table_data[0][1]
diff --git a/docs/apache-airflow-providers-apache-beam/index.rst b/docs/apache-airflow-providers-apache-beam/index.rst
new file mode 100644
index 0000000..30718f9
--- /dev/null
+++ b/docs/apache-airflow-providers-apache-beam/index.rst
@@ -0,0 +1,36 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+``apache-airflow-providers-apache-beam``
+========================================
+
+Content
+-------
+
+.. toctree::
+ :maxdepth: 1
+ :caption: References
+
+ Python API <_api/airflow/providers/apache/beam/index>
+ PyPI Repository <https://pypi.org/project/apache-airflow-providers-apache-beam/>
+ Example DAGs <https://github.com/apache/airflow/tree/master/airflow/providers/apache/beam/example_dags>
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Guides
+
+ Operators <operators>
diff --git a/docs/apache-airflow-providers-apache-beam/operators.rst b/docs/apache-airflow-providers-apache-beam/operators.rst
new file mode 100644
index 0000000..3c1b2bd
--- /dev/null
+++ b/docs/apache-airflow-providers-apache-beam/operators.rst
@@ -0,0 +1,116 @@
+
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Apache Beam Operators
+=====================
+
+`Apache Beam <https://beam.apache.org/>`__ is an open source, unified model for defining both batch and
+streaming data-parallel processing pipelines. Using one of the open source Beam SDKs, you build a program
+that defines the pipeline. The pipeline is then executed by one of Beam’s supported distributed processing
+back-ends, which include Apache Flink, Apache Spark, and Google Cloud Dataflow.
+
+
+.. _howto/operator:BeamRunPythonPipelineOperator:
+
+Run Python Pipelines in Apache Beam
+===================================
+
+The ``py_file`` argument must be specified for
+:class:`~airflow.providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator`
+as it contains the pipeline to be executed by Beam. The Python file can either be available on GCS (Airflow
+will download it) or on the local filesystem (provide the absolute path to it).
+
+The ``py_interpreter`` argument specifies the Python version to be used when executing the pipeline; the default
+is ``python3``. If your Airflow instance is running on Python 2, specify ``python2`` and ensure your ``py_file`` is
+written for Python 2. For best results, use Python 3.
+
+If the ``py_requirements`` argument is specified, a temporary Python virtual environment with the specified
+requirements will be created and the pipeline will run within it.
+
+The ``py_system_site_packages`` argument specifies whether all the Python packages from your Airflow instance
+will be accessible within the virtual environment (if ``py_requirements`` is specified). Avoid enabling it
+unless the Dataflow job requires it.
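+
+A minimal inline sketch with the default ``DirectRunner`` (the module name, output path, and requirement
+pin below are illustrative; the example DAG snippets referenced on this page show complete pipelines):
+
+.. code-block:: python
+
+ from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator
+
+ start_python_pipeline = BeamRunPythonPipelineOperator(
+ task_id="start_python_pipeline_local_direct_runner",
+ py_file="apache_beam.examples.wordcount",
+ py_options=["-m"],
+ pipeline_options={"output": "/tmp/start_python_pipeline"},
+ py_requirements=["apache-beam[gcp]==2.26.0"],
+ py_interpreter="python3",
+ py_system_site_packages=False,
+ )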
+
+Python Pipelines with DirectRunner
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. exampleinclude:: /../../airflow/providers/apache/beam/example_dags/example_beam.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_start_python_direct_runner_pipeline_local_file]
+ :end-before: [END howto_operator_start_python_direct_runner_pipeline_local_file]
+
+.. exampleinclude:: /../../airflow/providers/apache/beam/example_dags/example_beam.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_start_python_direct_runner_pipeline_gcs_file]
+ :end-before: [END howto_operator_start_python_direct_runner_pipeline_gcs_file]
+
+Python Pipelines with DataflowRunner
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. exampleinclude:: /../../airflow/providers/apache/beam/example_dags/example_beam.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_start_python_dataflow_runner_pipeline_gcs_file]
+ :end-before: [END howto_operator_start_python_dataflow_runner_pipeline_gcs_file]
+
+.. exampleinclude:: /../../airflow/providers/apache/beam/example_dags/example_beam.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file]
+ :end-before: [END howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file]
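+
+When targeting Dataflow, the runner and an optional ``DataflowConfiguration`` are passed explicitly;
+a minimal sketch (bucket, project, and region values below are placeholders):
+
+.. code-block:: python
+
+ from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator
+ from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration
+
+ start_python_dataflow = BeamRunPythonPipelineOperator(
+ task_id="start_python_pipeline_dataflow_runner",
+ runner="DataflowRunner",
+ py_file="gs://my-bucket/pipelines/wordcount.py",
+ pipeline_options={"temp_location": "gs://my-bucket/tmp/"},
+ py_requirements=["apache-beam[gcp]==2.26.0"],
+ dataflow_config=DataflowConfiguration(
+ job_name="{{task.task_id}}", project_id="my-project", location="us-central1"
+ ),
+ )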
+
+.. _howto/operator:BeamRunJavaPipelineOperator:
+
+Run Java Pipelines in Apache Beam
+=================================
+
+For Java pipelines, the ``jar`` argument must be specified for
+:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
+as it contains the pipeline to be executed by Apache Beam. The JAR can either be available on GCS (Airflow
+will download it) or on the local filesystem (provide the absolute path to it).
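+
+A minimal inline sketch (the jar path, job class, and output path below are placeholders; the example
+DAG snippets that follow show complete pipelines):
+
+.. code-block:: python
+
+ from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator
+
+ start_java_pipeline = BeamRunJavaPipelineOperator(
+ task_id="start_java_pipeline_direct_runner",
+ jar="/files/beam-examples-bundled.jar",
+ job_class="org.apache.beam.examples.WordCount",
+ pipeline_options={"output": "/tmp/start_java_pipeline"},
+ )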
+
+Java Pipelines with DirectRunner
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. exampleinclude:: /../../airflow/providers/apache/beam/example_dags/example_beam.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_start_java_direct_runner_pipeline]
+ :end-before: [END howto_operator_start_java_direct_runner_pipeline]
+
+Java Pipelines with DataflowRunner
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. exampleinclude:: /../../airflow/providers/apache/beam/example_dags/example_beam.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_start_java_dataflow_runner_pipeline]
+ :end-before: [END howto_operator_start_java_dataflow_runner_pipeline]
+
+Reference
+^^^^^^^^^
+
+For further information, look at:
+
+* `Apache Beam Documentation <https://beam.apache.org/documentation/>`__
+* `Google Cloud API Documentation <https://cloud.google.com/dataflow/docs/apis>`__
+* `Product Documentation <https://cloud.google.com/dataflow/docs/>`__
+* `Dataflow Monitoring Interface <https://cloud.google.com/dataflow/docs/guides/using-monitoring-intf/>`__
+* `Dataflow Command-line Interface <https://cloud.google.com/dataflow/docs/guides/using-command-line-intf/>`__
diff --git a/docs/apache-airflow/extra-packages-ref.rst b/docs/apache-airflow/extra-packages-ref.rst
index b2549ae..5221beb 100644
--- a/docs/apache-airflow/extra-packages-ref.rst
+++ b/docs/apache-airflow/extra-packages-ref.rst
@@ -107,6 +107,8 @@ custom bash/python providers).
+=====================+=====================================================+================================================+
| apache.atlas | ``pip install 'apache-airflow[apache.atlas]'`` | Apache Atlas |
+---------------------+-----------------------------------------------------+------------------------------------------------+
+| apache.beam | ``pip install 'apache-airflow[apache.beam]'`` | Apache Beam operators & hooks |
++---------------------+-----------------------------------------------------+------------------------------------------------+
| apache.cassandra | ``pip install 'apache-airflow[apache.cassandra]'`` | Cassandra related operators & hooks |
+---------------------+-----------------------------------------------------+------------------------------------------------+
| apache.druid | ``pip install 'apache-airflow[apache.druid]'`` | Druid related operators & hooks |
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index db4342a..f8f8f83 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -141,6 +141,7 @@ Fileshares
Filesystem
Firehose
Firestore
+Flink
FluentD
Fokko
Formaturas
@@ -325,6 +326,7 @@ Seki
Sendgrid
Siddharth
SlackHook
+Spark
SparkPi
SparkR
SparkSQL
diff --git a/scripts/in_container/run_install_and_test_provider_packages.sh b/scripts/in_container/run_install_and_test_provider_packages.sh
index 969fa29..9b951c7 100755
--- a/scripts/in_container/run_install_and_test_provider_packages.sh
+++ b/scripts/in_container/run_install_and_test_provider_packages.sh
@@ -95,7 +95,7 @@ function discover_all_provider_packages() {
# Columns is to force it wider, so it doesn't wrap at 80 characters
COLUMNS=180 airflow providers list
- local expected_number_of_providers=62
+ local expected_number_of_providers=63
local actual_number_of_providers
actual_providers=$(airflow providers list --output yaml | grep package_name)
actual_number_of_providers=$(wc -l <<<"$actual_providers")
diff --git a/setup.py b/setup.py
index 210b12f..50f6a2f 100644
--- a/setup.py
+++ b/setup.py
@@ -523,6 +523,7 @@ devel_hadoop = devel_minreq + hdfs + hive + kerberos + presto + webhdfs
# Dict of all providers which are part of the Apache Airflow repository together with their requirements
PROVIDERS_REQUIREMENTS: Dict[str, List[str]] = {
'amazon': amazon,
+ 'apache.beam': apache_beam,
'apache.cassandra': cassandra,
'apache.druid': druid,
'apache.hdfs': hdfs,
diff --git a/tests/core/test_providers_manager.py b/tests/core/test_providers_manager.py
index 7d80c58..39ee588 100644
--- a/tests/core/test_providers_manager.py
+++ b/tests/core/test_providers_manager.py
@@ -22,6 +22,7 @@ from airflow.providers_manager import ProvidersManager
ALL_PROVIDERS = [
'apache-airflow-providers-amazon',
+ 'apache-airflow-providers-apache-beam',
'apache-airflow-providers-apache-cassandra',
'apache-airflow-providers-apache-druid',
'apache-airflow-providers-apache-hdfs',
diff --git a/tests/providers/apache/beam/__init__.py b/tests/providers/apache/beam/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/tests/providers/apache/beam/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/apache/beam/hooks/__init__.py b/tests/providers/apache/beam/hooks/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/tests/providers/apache/beam/hooks/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/apache/beam/hooks/test_beam.py b/tests/providers/apache/beam/hooks/test_beam.py
new file mode 100644
index 0000000..d0d713e
--- /dev/null
+++ b/tests/providers/apache/beam/hooks/test_beam.py
@@ -0,0 +1,271 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import copy
+import subprocess
+import unittest
+from unittest import mock
+from unittest.mock import MagicMock
+
+from parameterized import parameterized
+
+from airflow.exceptions import AirflowException
+from airflow.providers.apache.beam.hooks.beam import BeamCommandRunner, BeamHook, beam_options_to_args
+
+PY_FILE = 'apache_beam.examples.wordcount'
+JAR_FILE = 'unitest.jar'
+JOB_CLASS = 'com.example.UnitTest'
+PY_OPTIONS = ['-m']
+TEST_JOB_ID = 'test-job-id'
+
+DEFAULT_RUNNER = "DirectRunner"
+BEAM_STRING = 'airflow.providers.apache.beam.hooks.beam.{}'
+BEAM_VARIABLES_PY = {'output': 'gs://test/output', 'labels': {'foo': 'bar'}}
+BEAM_VARIABLES_JAVA = {
+ 'output': 'gs://test/output',
+ 'labels': {'foo': 'bar'},
+}
+
+APACHE_BEAM_V_2_14_0_JAVA_SDK_LOG = f"""\
+Dataflow SDK version: 2.14.0
+Jun 15, 2020 2:57:28 PM org.apache.beam.runners.dataflow.DataflowRunner run
+INFO: To access the Dataflow monitoring console, please navigate to https://console.cloud.google.com/dataflow\
+/jobsDetail/locations/europe-west3/jobs/{TEST_JOB_ID}?project=XXX
+Submitted job: {TEST_JOB_ID}
+Jun 15, 2020 2:57:28 PM org.apache.beam.runners.dataflow.DataflowRunner run
+INFO: To cancel the job using the 'gcloud' tool, run:
+> gcloud dataflow jobs --project=XXX cancel --region=europe-west3 {TEST_JOB_ID}
+"""
+
+
+class TestBeamHook(unittest.TestCase):
+ @mock.patch(BEAM_STRING.format('BeamCommandRunner'))
+ def test_start_python_pipeline(self, mock_runner):
+ hook = BeamHook(runner=DEFAULT_RUNNER)
+ wait_for_done = mock_runner.return_value.wait_for_done
+ process_line_callback = MagicMock()
+
+ hook.start_python_pipeline( # pylint: disable=no-value-for-parameter
+ variables=copy.deepcopy(BEAM_VARIABLES_PY),
+ py_file=PY_FILE,
+ py_options=PY_OPTIONS,
+ process_line_callback=process_line_callback,
+ )
+
+ expected_cmd = [
+ "python3",
+ '-m',
+ PY_FILE,
+ f'--runner={DEFAULT_RUNNER}',
+ '--output=gs://test/output',
+ '--labels=foo=bar',
+ ]
+ mock_runner.assert_called_once_with(cmd=expected_cmd, process_line_callback=process_line_callback)
+ wait_for_done.assert_called_once_with()
+
+ @parameterized.expand(
+ [
+ ('default_to_python3', 'python3'),
+ ('major_version_2', 'python2'),
+ ('major_version_3', 'python3'),
+ ('minor_version', 'python3.6'),
+ ]
+ )
+ @mock.patch(BEAM_STRING.format('BeamCommandRunner'))
+ def test_start_python_pipeline_with_custom_interpreter(self, _, py_interpreter, mock_runner):
+ hook = BeamHook(runner=DEFAULT_RUNNER)
+ wait_for_done = mock_runner.return_value.wait_for_done
+ process_line_callback = MagicMock()
+
+ hook.start_python_pipeline( # pylint: disable=no-value-for-parameter
+ variables=copy.deepcopy(BEAM_VARIABLES_PY),
+ py_file=PY_FILE,
+ py_options=PY_OPTIONS,
+ py_interpreter=py_interpreter,
+ process_line_callback=process_line_callback,
+ )
+
+ expected_cmd = [
+ py_interpreter,
+ '-m',
+ PY_FILE,
+ f'--runner={DEFAULT_RUNNER}',
+ '--output=gs://test/output',
+ '--labels=foo=bar',
+ ]
+ mock_runner.assert_called_once_with(cmd=expected_cmd, process_line_callback=process_line_callback)
+ wait_for_done.assert_called_once_with()
+
+ @parameterized.expand(
+ [
+ (['foo-bar'], False),
+ (['foo-bar'], True),
+ ([], True),
+ ]
+ )
+ @mock.patch(BEAM_STRING.format('prepare_virtualenv'))
+ @mock.patch(BEAM_STRING.format('BeamCommandRunner'))
+ def test_start_python_pipeline_with_non_empty_py_requirements_and_without_system_packages(
+ self, current_py_requirements, current_py_system_site_packages, mock_runner, mock_virtualenv
+ ):
+ hook = BeamHook(runner=DEFAULT_RUNNER)
+ wait_for_done = mock_runner.return_value.wait_for_done
+ mock_virtualenv.return_value = '/dummy_dir/bin/python'
+ process_line_callback = MagicMock()
+
+ hook.start_python_pipeline( # pylint: disable=no-value-for-parameter
+ variables=copy.deepcopy(BEAM_VARIABLES_PY),
+ py_file=PY_FILE,
+ py_options=PY_OPTIONS,
+ py_requirements=current_py_requirements,
+ py_system_site_packages=current_py_system_site_packages,
+ process_line_callback=process_line_callback,
+ )
+
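+ # With py_requirements provided, the pipeline is expected to run with the
+ # interpreter of the virtualenv returned by prepare_virtualenv.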
+ expected_cmd = [
+ '/dummy_dir/bin/python',
+ '-m',
+ PY_FILE,
+ f'--runner={DEFAULT_RUNNER}',
+ '--output=gs://test/output',
+ '--labels=foo=bar',
+ ]
+ mock_runner.assert_called_once_with(cmd=expected_cmd, process_line_callback=process_line_callback)
+ wait_for_done.assert_called_once_with()
+ mock_virtualenv.assert_called_once_with(
+ venv_directory=mock.ANY,
+ python_bin="python3",
+ system_site_packages=current_py_system_site_packages,
+ requirements=current_py_requirements,
+ )
+
+ @mock.patch(BEAM_STRING.format('BeamCommandRunner'))
+ def test_start_python_pipeline_with_empty_py_requirements_and_without_system_packages(self, mock_runner):
+ hook = BeamHook(runner=DEFAULT_RUNNER)
+ wait_for_done = mock_runner.return_value.wait_for_done
+ process_line_callback = MagicMock()
+
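+ # An empty py_requirements list without system site packages is rejected,
+ # so no runner is started.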
+ with self.assertRaisesRegex(AirflowException, "Invalid method invocation."):
+ hook.start_python_pipeline( # pylint: disable=no-value-for-parameter
+ variables=copy.deepcopy(BEAM_VARIABLES_PY),
+ py_file=PY_FILE,
+ py_options=PY_OPTIONS,
+ py_requirements=[],
+ process_line_callback=process_line_callback,
+ )
+
+ mock_runner.assert_not_called()
+ wait_for_done.assert_not_called()
+
+ @mock.patch(BEAM_STRING.format('BeamCommandRunner'))
+ def test_start_java_pipeline(self, mock_runner):
+ hook = BeamHook(runner=DEFAULT_RUNNER)
+ wait_for_done = mock_runner.return_value.wait_for_done
+ process_line_callback = MagicMock()
+
+ hook.start_java_pipeline( # pylint: disable=no-value-for-parameter
+ jar=JAR_FILE,
+ variables=copy.deepcopy(BEAM_VARIABLES_JAVA),
+ process_line_callback=process_line_callback,
+ )
+
+ expected_cmd = [
+ 'java',
+ '-jar',
+ JAR_FILE,
+ f'--runner={DEFAULT_RUNNER}',
+ '--output=gs://test/output',
+ '--labels={"foo":"bar"}',
+ ]
+ mock_runner.assert_called_once_with(cmd=expected_cmd, process_line_callback=process_line_callback)
+ wait_for_done.assert_called_once_with()
+
+ @mock.patch(BEAM_STRING.format('BeamCommandRunner'))
+ def test_start_java_pipeline_with_job_class(self, mock_runner):
+ hook = BeamHook(runner=DEFAULT_RUNNER)
+ wait_for_done = mock_runner.return_value.wait_for_done
+ process_line_callback = MagicMock()
+
+ hook.start_java_pipeline( # pylint: disable=no-value-for-parameter
+ jar=JAR_FILE,
+ variables=copy.deepcopy(BEAM_VARIABLES_JAVA),
+ job_class=JOB_CLASS,
+ process_line_callback=process_line_callback,
+ )
+
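+ # With an explicit job_class the jar goes on the classpath (-cp) and the
+ # class is invoked directly instead of using java -jar.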
+ expected_cmd = [
+ 'java',
+ '-cp',
+ JAR_FILE,
+ JOB_CLASS,
+ f'--runner={DEFAULT_RUNNER}',
+ '--output=gs://test/output',
+ '--labels={"foo":"bar"}',
+ ]
+ mock_runner.assert_called_once_with(cmd=expected_cmd, process_line_callback=process_line_callback)
+ wait_for_done.assert_called_once_with()
+
+
+class TestBeamRunner(unittest.TestCase):
+ @mock.patch('airflow.providers.apache.beam.hooks.beam.BeamCommandRunner.log')
+ @mock.patch('subprocess.Popen')
+ @mock.patch('select.select')
+ def test_beam_wait_for_done_logging(self, mock_select, mock_popen, mock_logging):
+ cmd = ['test', 'cmd']
+ mock_logging.info = MagicMock()
+ mock_logging.warning = MagicMock()
+ mock_proc = MagicMock()
+ mock_proc.stderr = MagicMock()
+ mock_proc.stderr.readlines = MagicMock(return_value=['test\n', 'error\n'])
+ mock_stderr_fd = MagicMock()
+ mock_proc.stderr.fileno = MagicMock(return_value=mock_stderr_fd)
+ mock_proc_poll = MagicMock()
+ mock_select.return_value = [[mock_stderr_fd]]
+
+ def poll_resp_error():
+ mock_proc.return_code = 1
+ return True
+
+ mock_proc_poll.side_effect = [None, poll_resp_error]
+ mock_proc.poll = mock_proc_poll
+ mock_popen.return_value = mock_proc
+ beam = BeamCommandRunner(cmd)
+ mock_logging.info.assert_called_once_with('Running command: %s', " ".join(cmd))
+ mock_popen.assert_called_once_with(
+ cmd,
+ shell=False,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ close_fds=True,
+ )
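+ # The mocked process does not finish cleanly, so wait_for_done() is expected to raise.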
+ self.assertRaises(Exception, beam.wait_for_done)
+
+
+class TestBeamOptionsToArgs(unittest.TestCase):
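+ # beam_options_to_args flattens a dict of pipeline options into CLI arguments:
+ # None and True become bare --key flags, other values become --key=value,
+ # and list values repeat the flag once per element.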
+ @parameterized.expand(
+ [
+ ({"key": "val"}, ["--key=val"]),
+ ({"key": None}, ["--key"]),
+ ({"key": True}, ["--key"]),
+ ({"key": False}, ["--key=False"]),
+ ({"key": ["a", "b", "c"]}, ["--key=a", "--key=b", "--key=c"]),
+ ]
+ )
+ def test_beam_options_to_args(self, options, expected_args):
+ args = beam_options_to_args(options)
+ assert args == expected_args
diff --git a/tests/providers/apache/beam/operators/__init__.py b/tests/providers/apache/beam/operators/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/tests/providers/apache/beam/operators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py
new file mode 100644
index 0000000..c31ff33
--- /dev/null
+++ b/tests/providers/apache/beam/operators/test_beam.py
@@ -0,0 +1,274 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import unittest
+from unittest import mock
+
+from airflow.providers.apache.beam.operators.beam import (
+ BeamRunJavaPipelineOperator,
+ BeamRunPythonPipelineOperator,
+)
+from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration
+from airflow.version import version
+
+TASK_ID = 'test-beam-operator'
+DEFAULT_RUNNER = "DirectRunner"
+JOB_NAME = 'test-dataflow-pipeline-name'
+JOB_ID = 'test-dataflow-pipeline-id'
+JAR_FILE = 'gs://my-bucket/example/test.jar'
+JOB_CLASS = 'com.test.NotMain'
+PY_FILE = 'gs://my-bucket/my-object.py'
+PY_INTERPRETER = 'python3'
+PY_OPTIONS = ['-m']
+DEFAULT_OPTIONS_PYTHON = DEFAULT_OPTIONS_JAVA = {
+ 'project': 'test',
+ 'stagingLocation': 'gs://test/staging',
+}
+ADDITIONAL_OPTIONS = {'output': 'gs://test/output', 'labels': {'foo': 'bar'}}
+TEST_VERSION = f"v{version.replace('.', '-').replace('+', '-')}"
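+# Pipeline options as expected after the operator appends the airflow-version label.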
+EXPECTED_ADDITIONAL_OPTIONS = {
+ 'output': 'gs://test/output',
+ 'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
+}
+
+
+class TestBeamRunPythonPipelineOperator(unittest.TestCase):
+ def setUp(self):
+ self.operator = BeamRunPythonPipelineOperator(
+ task_id=TASK_ID,
+ py_file=PY_FILE,
+ py_options=PY_OPTIONS,
+ default_pipeline_options=DEFAULT_OPTIONS_PYTHON,
+ pipeline_options=ADDITIONAL_OPTIONS,
+ )
+
+ def test_init(self):
+ """Test BeamRunPythonPipelineOperator instance is properly initialized."""
+ self.assertEqual(self.operator.task_id, TASK_ID)
+ self.assertEqual(self.operator.py_file, PY_FILE)
+ self.assertEqual(self.operator.runner, DEFAULT_RUNNER)
+ self.assertEqual(self.operator.py_options, PY_OPTIONS)
+ self.assertEqual(self.operator.py_interpreter, PY_INTERPRETER)
+ self.assertEqual(self.operator.default_pipeline_options, DEFAULT_OPTIONS_PYTHON)
+ self.assertEqual(self.operator.pipeline_options, EXPECTED_ADDITIONAL_OPTIONS)
+
+ @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+ def test_exec_direct_runner(self, gcs_hook, beam_hook_mock):
+ """Test BeamHook is created and the right args are passed to
+ start_python_workflow.
+ """
+ start_python_hook = beam_hook_mock.return_value.start_python_pipeline
+ gcs_provide_file = gcs_hook.return_value.provide_file
+ self.operator.execute(None)
+ beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
+ expected_options = {
+ 'project': 'test',
+ 'staging_location': 'gs://test/staging',
+ 'output': 'gs://test/output',
+ 'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
+ }
+ gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
+ start_python_hook.assert_called_once_with(
+ variables=expected_options,
+ py_file=gcs_provide_file.return_value.__enter__.return_value.name,
+ py_options=PY_OPTIONS,
+ py_interpreter=PY_INTERPRETER,
+ py_requirements=None,
+ py_system_site_packages=False,
+ process_line_callback=None,
+ )
+
+ @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock):
+ """Test DataflowHook is created and the right args are passed to
+ start_python_dataflow.
+ """
+ dataflow_config = DataflowConfiguration()
+ self.operator.runner = "DataflowRunner"
+ self.operator.dataflow_config = dataflow_config
+ gcs_provide_file = gcs_hook.return_value.provide_file
+ self.operator.execute(None)
+ job_name = dataflow_hook_mock.build_dataflow_job_name.return_value
+ dataflow_hook_mock.assert_called_once_with(
+ gcp_conn_id=dataflow_config.gcp_conn_id,
+ delegate_to=dataflow_config.delegate_to,
+ poll_sleep=dataflow_config.poll_sleep,
+ impersonation_chain=dataflow_config.impersonation_chain,
+ drain_pipeline=dataflow_config.drain_pipeline,
+ cancel_timeout=dataflow_config.cancel_timeout,
+ wait_until_finished=dataflow_config.wait_until_finished,
+ )
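+ # On the DataflowRunner the operator is expected to add Dataflow-specific
+ # options: project, a generated job_name and the region.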
+ expected_options = {
+ 'project': dataflow_hook_mock.return_value.project_id,
+ 'job_name': job_name,
+ 'staging_location': 'gs://test/staging',
+ 'output': 'gs://test/output',
+ 'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
+ 'region': 'us-central1',
+ }
+ gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
+ beam_hook_mock.return_value.start_python_pipeline.assert_called_once_with(
+ variables=expected_options,
+ py_file=gcs_provide_file.return_value.__enter__.return_value.name,
+ py_options=PY_OPTIONS,
+ py_interpreter=PY_INTERPRETER,
+ py_requirements=None,
+ py_system_site_packages=False,
+ process_line_callback=mock.ANY,
+ )
+ dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+ job_id=self.operator.dataflow_job_id,
+ job_name=job_name,
+ location='us-central1',
+ multiple_jobs=False,
+ )
+
+ @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
+ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __):
+ self.operator.runner = "DataflowRunner"
+ dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job
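+ # When running on Dataflow, killing the task should cancel the remote job.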
+ self.operator.execute(None)
+ self.operator.dataflow_job_id = JOB_ID
+ self.operator.on_kill()
+ dataflow_cancel_job.assert_called_once_with(
+ job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id
+ )
+
+ @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+ def test_on_kill_direct_runner(self, _, dataflow_mock, __):
+ dataflow_cancel_job = dataflow_mock.return_value.cancel_job
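+ # On the DirectRunner there is no remote Dataflow job, so nothing should be cancelled.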
+ self.operator.execute(None)
+ self.operator.on_kill()
+ dataflow_cancel_job.assert_not_called()
+
+
+class TestBeamRunJavaPipelineOperator(unittest.TestCase):
+ def setUp(self):
+ self.operator = BeamRunJavaPipelineOperator(
+ task_id=TASK_ID,
+ jar=JAR_FILE,
+ job_class=JOB_CLASS,
+ default_pipeline_options=DEFAULT_OPTIONS_JAVA,
+ pipeline_options=ADDITIONAL_OPTIONS,
+ )
+
+ def test_init(self):
+ """Test BeamRunJavaPipelineOperator instance is properly initialized."""
+ self.assertEqual(self.operator.task_id, TASK_ID)
+ self.assertEqual(self.operator.runner, DEFAULT_RUNNER)
+ self.assertEqual(self.operator.default_pipeline_options, DEFAULT_OPTIONS_JAVA)
+ self.assertEqual(self.operator.job_class, JOB_CLASS)
+ self.assertEqual(self.operator.jar, JAR_FILE)
+ self.assertEqual(self.operator.pipeline_options, ADDITIONAL_OPTIONS)
+
+ @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+ def test_exec_direct_runner(self, gcs_hook, beam_hook_mock):
+ """Test BeamHook is created and the right args are passed to
+ start_java_workflow.
+ """
+ start_java_hook = beam_hook_mock.return_value.start_java_pipeline
+ gcs_provide_file = gcs_hook.return_value.provide_file
+ self.operator.execute(None)
+
+ beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
+ gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
+ start_java_hook.assert_called_once_with(
+ variables={**DEFAULT_OPTIONS_JAVA, **ADDITIONAL_OPTIONS},
+ jar=gcs_provide_file.return_value.__enter__.return_value.name,
+ job_class=JOB_CLASS,
+ process_line_callback=None,
+ )
+
+ @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock):
+ """Test DataflowHook is created and the right args are passed to
+ start_java_dataflow.
+ """
+ dataflow_config = DataflowConfiguration()
+ self.operator.runner = "DataflowRunner"
+ self.operator.dataflow_config = dataflow_config
+ gcs_provide_file = gcs_hook.return_value.provide_file
+ dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False
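+ # Simulate that no job with this name is already running so the pipeline is started.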
+ self.operator.execute(None)
+ job_name = dataflow_hook_mock.build_dataflow_job_name.return_value
+ self.assertEqual(job_name, self.operator._dataflow_job_name)
+ dataflow_hook_mock.assert_called_once_with(
+ gcp_conn_id=dataflow_config.gcp_conn_id,
+ delegate_to=dataflow_config.delegate_to,
+ poll_sleep=dataflow_config.poll_sleep,
+ impersonation_chain=dataflow_config.impersonation_chain,
+ drain_pipeline=dataflow_config.drain_pipeline,
+ cancel_timeout=dataflow_config.cancel_timeout,
+ wait_until_finished=dataflow_config.wait_until_finished,
+ )
+ gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
+
+ expected_options = {
+ 'project': dataflow_hook_mock.return_value.project_id,
+ 'jobName': job_name,
+ 'stagingLocation': 'gs://test/staging',
+ 'region': 'us-central1',
+ 'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
+ 'output': 'gs://test/output',
+ }
+
+ beam_hook_mock.return_value.start_java_pipeline.assert_called_once_with(
+ variables=expected_options,
+ jar=gcs_provide_file.return_value.__enter__.return_value.name,
+ job_class=JOB_CLASS,
+ process_line_callback=mock.ANY,
+ )
+ dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+ job_id=self.operator.dataflow_job_id,
+ job_name=job_name,
+ location='us-central1',
+ multiple_jobs=dataflow_config.multiple_jobs,
+ project_id=dataflow_hook_mock.return_value.project_id,
+ )
+
+ @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
+ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __):
+ self.operator.runner = "DataflowRunner"
+ dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False
+ dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job
+ self.operator.execute(None)
+ self.operator.dataflow_job_id = JOB_ID
+ self.operator.on_kill()
+ dataflow_cancel_job.assert_called_once_with(
+ job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id
+ )
+
+ @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
+ @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+ def test_on_kill_direct_runner(self, _, dataflow_mock, __):
+ dataflow_cancel_job = dataflow_mock.return_value.cancel_job
+ self.operator.execute(None)
+ self.operator.on_kill()
+ dataflow_cancel_job.assert_not_called()
diff --git a/tests/providers/apache/beam/operators/test_beam_system.py b/tests/providers/apache/beam/operators/test_beam_system.py
new file mode 100644
index 0000000..0798f35
--- /dev/null
+++ b/tests/providers/apache/beam/operators/test_beam_system.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+
+import pytest
+
+from tests.test_utils import AIRFLOW_MAIN_FOLDER
+from tests.test_utils.system_tests_class import SystemTest
+
+BEAM_DAG_FOLDER = os.path.join(AIRFLOW_MAIN_FOLDER, "airflow", "providers", "apache", "beam", "example_dags")
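+# Each test below runs one of the provider's example DAGs as a system test.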
+
+
+@pytest.mark.system("apache.beam")
+class BeamExampleDagsSystemTest(SystemTest):
+ def test_run_example_dag_beam_python(self):
+ self.run_dag('example_beam_native_python', BEAM_DAG_FOLDER)
+
+ def test_run_example_dag_beam_python_dataflow_async(self):
+ self.run_dag('example_beam_native_python_dataflow_async', BEAM_DAG_FOLDER)
+
+ def test_run_example_dag_beam_java_direct_runner(self):
+ self.run_dag('example_beam_native_java_direct_runner', BEAM_DAG_FOLDER)
+
+ def test_run_example_dag_beam_java_dataflow_runner(self):
+ self.run_dag('example_beam_native_java_dataflow_runner', BEAM_DAG_FOLDER)
+
+ def test_run_example_dag_beam_java_spark_runner(self):
+ self.run_dag('example_beam_native_java_spark_runner', BEAM_DAG_FOLDER)
+
+ def test_run_example_dag_beam_java_flink_runner(self):
+ self.run_dag('example_beam_native_java_flink_runner', BEAM_DAG_FOLDER)
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index 5297b30..c0da030 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -30,16 +30,20 @@ import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
+from airflow.providers.apache.beam.hooks.beam import BeamCommandRunner, BeamHook
from airflow.providers.google.cloud.hooks.dataflow import (
DEFAULT_DATAFLOW_LOCATION,
DataflowHook,
DataflowJobStatus,
DataflowJobType,
_DataflowJobsController,
- _DataflowRunner,
_fallback_to_project_id_from_variables,
+ process_line_and_extract_dataflow_job_id_callback,
)
+DEFAULT_RUNNER = "DirectRunner"
+BEAM_STRING = 'airflow.providers.apache.beam.hooks.beam.{}'
+
TASK_ID = 'test-dataflow-operator'
JOB_NAME = 'test-dataflow-pipeline'
MOCK_UUID = UUID('cf4a56d2-8101-4217-b027-2af6216feb48')
@@ -183,6 +187,7 @@ class TestDataflowHook(unittest.TestCase):
def setUp(self):
with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init):
self.dataflow_hook = DataflowHook(gcp_conn_id='test')
+ self.dataflow_hook.beam_hook = MagicMock()
@mock.patch("airflow.providers.google.cloud.hooks.dataflow.DataflowHook._authorize")
@mock.patch("airflow.providers.google.cloud.hooks.dataflow.build")
@@ -194,186 +199,229 @@ class TestDataflowHook(unittest.TestCase):
assert mock_build.return_value == result
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
- @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
- def test_start_python_dataflow(self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
+ @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+ @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
+ def test_start_python_dataflow(self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid):
+ mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
mock_uuid.return_value = MOCK_UUID
- mock_conn.return_value = None
- dataflow_instance = mock_dataflow.return_value
- dataflow_instance.wait_for_done.return_value = None
- dataflowjob_instance = mock_dataflowjob.return_value
- dataflowjob_instance.wait_for_done.return_value = None
- self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
- job_name=JOB_NAME,
- variables=DATAFLOW_VARIABLES_PY,
- dataflow=PY_FILE,
+ on_new_job_id_callback = MagicMock()
+ py_requirements = ["pands", "numpy"]
+ job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+ with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+ self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
+ job_name=JOB_NAME,
+ variables=DATAFLOW_VARIABLES_PY,
+ dataflow=PY_FILE,
+ py_options=PY_OPTIONS,
+ py_interpreter=DEFAULT_PY_INTERPRETER,
+ py_requirements=py_requirements,
+ on_new_job_id_callback=on_new_job_id_callback,
+ )
+
+ expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+ expected_variables["job_name"] = job_name
+ expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
+
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+ mock_beam_start_python_pipeline.assert_called_once_with(
+ variables=expected_variables,
+ py_file=PY_FILE,
+ py_interpreter=DEFAULT_PY_INTERPRETER,
py_options=PY_OPTIONS,
+ py_requirements=py_requirements,
+ py_system_site_packages=False,
+ process_line_callback=mock_callback_on_job_id.return_value,
+ )
+
+ mock_dataflow_wait_for_done.assert_called_once_with(
+ job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION
)
- expected_cmd = [
- "python3",
- '-m',
- PY_FILE,
- '--region=us-central1',
- '--runner=DataflowRunner',
- '--project=test',
- '--labels=foo=bar',
- '--staging_location=gs://test/staging',
- f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
- ]
- assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
- @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+ @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+ @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
def test_start_python_dataflow_with_custom_region_as_variable(
- self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+ self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
):
+ mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
mock_uuid.return_value = MOCK_UUID
- mock_conn.return_value = None
- dataflow_instance = mock_dataflow.return_value
- dataflow_instance.wait_for_done.return_value = None
- dataflowjob_instance = mock_dataflowjob.return_value
- dataflowjob_instance.wait_for_done.return_value = None
- variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
- variables['region'] = TEST_LOCATION
- self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
- job_name=JOB_NAME,
- variables=variables,
- dataflow=PY_FILE,
+ on_new_job_id_callback = MagicMock()
+ py_requirements = ["pands", "numpy"]
+ job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+ passed_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+ passed_variables["region"] = TEST_LOCATION
+
+ with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+ self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
+ job_name=JOB_NAME,
+ variables=passed_variables,
+ dataflow=PY_FILE,
+ py_options=PY_OPTIONS,
+ py_interpreter=DEFAULT_PY_INTERPRETER,
+ py_requirements=py_requirements,
+ on_new_job_id_callback=on_new_job_id_callback,
+ )
+
+ expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+ expected_variables["job_name"] = job_name
+ expected_variables["region"] = TEST_LOCATION
+
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+ mock_beam_start_python_pipeline.assert_called_once_with(
+ variables=expected_variables,
+ py_file=PY_FILE,
+ py_interpreter=DEFAULT_PY_INTERPRETER,
py_options=PY_OPTIONS,
+ py_requirements=py_requirements,
+ py_system_site_packages=False,
+ process_line_callback=mock_callback_on_job_id.return_value,
+ )
+
+ mock_dataflow_wait_for_done.assert_called_once_with(
+ job_id=mock.ANY, job_name=job_name, location=TEST_LOCATION
)
- expected_cmd = [
- "python3",
- '-m',
- PY_FILE,
- f'--region={TEST_LOCATION}',
- '--runner=DataflowRunner',
- '--project=test',
- '--labels=foo=bar',
- '--staging_location=gs://test/staging',
- f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
- ]
- assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
- @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+ @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+ @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
def test_start_python_dataflow_with_custom_region_as_parameter(
- self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+ self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
):
+ mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
mock_uuid.return_value = MOCK_UUID
- mock_conn.return_value = None
- dataflow_instance = mock_dataflow.return_value
- dataflow_instance.wait_for_done.return_value = None
- dataflowjob_instance = mock_dataflowjob.return_value
- dataflowjob_instance.wait_for_done.return_value = None
- self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
- job_name=JOB_NAME,
- variables=DATAFLOW_VARIABLES_PY,
- dataflow=PY_FILE,
+ on_new_job_id_callback = MagicMock()
+ py_requirements = ["pands", "numpy"]
+ job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+ passed_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+
+ with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+ self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
+ job_name=JOB_NAME,
+ variables=passed_variables,
+ dataflow=PY_FILE,
+ py_options=PY_OPTIONS,
+ py_interpreter=DEFAULT_PY_INTERPRETER,
+ py_requirements=py_requirements,
+ on_new_job_id_callback=on_new_job_id_callback,
+ location=TEST_LOCATION,
+ )
+
+ expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+ expected_variables["job_name"] = job_name
+ expected_variables["region"] = TEST_LOCATION
+
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+ mock_beam_start_python_pipeline.assert_called_once_with(
+ variables=expected_variables,
+ py_file=PY_FILE,
+ py_interpreter=DEFAULT_PY_INTERPRETER,
py_options=PY_OPTIONS,
- location=TEST_LOCATION,
+ py_requirements=py_requirements,
+ py_system_site_packages=False,
+ process_line_callback=mock_callback_on_job_id.return_value,
+ )
+
+ mock_dataflow_wait_for_done.assert_called_once_with(
+ job_id=mock.ANY, job_name=job_name, location=TEST_LOCATION
)
- expected_cmd = [
- "python3",
- '-m',
- PY_FILE,
- f'--region={TEST_LOCATION}',
- '--runner=DataflowRunner',
- '--project=test',
- '--labels=foo=bar',
- '--staging_location=gs://test/staging',
- f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
- ]
- assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
- @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+ @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+ @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
def test_start_python_dataflow_with_multiple_extra_packages(
- self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+ self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
):
+ mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
mock_uuid.return_value = MOCK_UUID
- mock_conn.return_value = None
- dataflow_instance = mock_dataflow.return_value
- dataflow_instance.wait_for_done.return_value = None
- dataflowjob_instance = mock_dataflowjob.return_value
- dataflowjob_instance.wait_for_done.return_value = None
- variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_PY)
- variables['extra-package'] = ['a.whl', 'b.whl']
+ on_new_job_id_callback = MagicMock()
+ py_requirements = ["pands", "numpy"]
+ job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
- self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
- job_name=JOB_NAME,
- variables=variables,
- dataflow=PY_FILE,
+ passed_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+ passed_variables['extra-package'] = ['a.whl', 'b.whl']
+
+ with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+ self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
+ job_name=JOB_NAME,
+ variables=passed_variables,
+ dataflow=PY_FILE,
+ py_options=PY_OPTIONS,
+ py_interpreter=DEFAULT_PY_INTERPRETER,
+ py_requirements=py_requirements,
+ on_new_job_id_callback=on_new_job_id_callback,
+ )
+
+ expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+ expected_variables["job_name"] = job_name
+ expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
+ expected_variables['extra-package'] = ['a.whl', 'b.whl']
+
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+ mock_beam_start_python_pipeline.assert_called_once_with(
+ variables=expected_variables,
+ py_file=PY_FILE,
+ py_interpreter=DEFAULT_PY_INTERPRETER,
py_options=PY_OPTIONS,
+ py_requirements=py_requirements,
+ py_system_site_packages=False,
+ process_line_callback=mock_callback_on_job_id.return_value,
+ )
+
+ mock_dataflow_wait_for_done.assert_called_once_with(
+ job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION
)
- expected_cmd = [
- "python3",
- '-m',
- PY_FILE,
- '--extra-package=a.whl',
- '--extra-package=b.whl',
- '--region=us-central1',
- '--runner=DataflowRunner',
- '--project=test',
- '--labels=foo=bar',
- '--staging_location=gs://test/staging',
- f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
- ]
- assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@parameterized.expand(
[
- ('default_to_python3', 'python3'),
- ('major_version_2', 'python2'),
- ('major_version_3', 'python3'),
- ('minor_version', 'python3.6'),
+ ('python3',),
+ ('python2',),
+ ('python3',),
+ ('python3.6',),
]
)
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
- @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+ @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+ @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
def test_start_python_dataflow_with_custom_interpreter(
- self,
- name,
- py_interpreter,
- mock_conn,
- mock_dataflow,
- mock_dataflowjob,
- mock_uuid,
+ self, py_interpreter, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
):
- del name # unused variable
+ mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
mock_uuid.return_value = MOCK_UUID
- mock_conn.return_value = None
- dataflow_instance = mock_dataflow.return_value
- dataflow_instance.wait_for_done.return_value = None
- dataflowjob_instance = mock_dataflowjob.return_value
- dataflowjob_instance.wait_for_done.return_value = None
- self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
- job_name=JOB_NAME,
- variables=DATAFLOW_VARIABLES_PY,
- dataflow=PY_FILE,
- py_options=PY_OPTIONS,
+ on_new_job_id_callback = MagicMock()
+ job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+ with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+ self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
+ job_name=JOB_NAME,
+ variables=DATAFLOW_VARIABLES_PY,
+ dataflow=PY_FILE,
+ py_options=PY_OPTIONS,
+ py_interpreter=py_interpreter,
+ py_requirements=None,
+ on_new_job_id_callback=on_new_job_id_callback,
+ )
+
+ expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+ expected_variables["job_name"] = job_name
+ expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
+
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+ mock_beam_start_python_pipeline.assert_called_once_with(
+ variables=expected_variables,
+ py_file=PY_FILE,
py_interpreter=py_interpreter,
+ py_options=PY_OPTIONS,
+ py_requirements=None,
+ py_system_site_packages=False,
+ process_line_callback=mock_callback_on_job_id.return_value,
+ )
+
+ mock_dataflow_wait_for_done.assert_called_once_with(
+ job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION
)
- expected_cmd = [
- py_interpreter,
- '-m',
- PY_FILE,
- '--region=us-central1',
- '--runner=DataflowRunner',
- '--project=test',
- '--labels=foo=bar',
- '--staging_location=gs://test/staging',
- f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
- ]
- assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@parameterized.expand(
[
@@ -382,225 +430,229 @@ class TestDataflowHook(unittest.TestCase):
([], True),
]
)
- @mock.patch(DATAFLOW_STRING.format('prepare_virtualenv'))
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
- @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+ @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+ @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
def test_start_python_dataflow_with_non_empty_py_requirements_and_without_system_packages(
self,
current_py_requirements,
current_py_system_site_packages,
- mock_conn,
- mock_dataflow,
- mock_dataflowjob,
+ mock_callback_on_job_id,
+ mock_dataflow_wait_for_done,
mock_uuid,
- mock_virtualenv,
):
+ mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
mock_uuid.return_value = MOCK_UUID
- mock_conn.return_value = None
- dataflow_instance = mock_dataflow.return_value
- dataflow_instance.wait_for_done.return_value = None
- dataflowjob_instance = mock_dataflowjob.return_value
- dataflowjob_instance.wait_for_done.return_value = None
- mock_virtualenv.return_value = '/dummy_dir/bin/python'
- self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
- job_name=JOB_NAME,
- variables=DATAFLOW_VARIABLES_PY,
- dataflow=PY_FILE,
+ on_new_job_id_callback = MagicMock()
+ job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+ with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+ self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
+ job_name=JOB_NAME,
+ variables=DATAFLOW_VARIABLES_PY,
+ dataflow=PY_FILE,
+ py_options=PY_OPTIONS,
+ py_interpreter=DEFAULT_PY_INTERPRETER,
+ py_requirements=current_py_requirements,
+ py_system_site_packages=current_py_system_site_packages,
+ on_new_job_id_callback=on_new_job_id_callback,
+ )
+
+ expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+ expected_variables["job_name"] = job_name
+ expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
+
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+ mock_beam_start_python_pipeline.assert_called_once_with(
+ variables=expected_variables,
+ py_file=PY_FILE,
+ py_interpreter=DEFAULT_PY_INTERPRETER,
py_options=PY_OPTIONS,
py_requirements=current_py_requirements,
py_system_site_packages=current_py_system_site_packages,
+ process_line_callback=mock_callback_on_job_id.return_value,
+ )
+
+ mock_dataflow_wait_for_done.assert_called_once_with(
+ job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION
)
- expected_cmd = [
- '/dummy_dir/bin/python',
- '-m',
- PY_FILE,
- '--region=us-central1',
- '--runner=DataflowRunner',
- '--project=test',
- '--labels=foo=bar',
- '--staging_location=gs://test/staging',
- f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
- ]
- assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
- @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+ @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
def test_start_python_dataflow_with_empty_py_requirements_and_without_system_packages(
- self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+ self, mock_dataflow_wait_for_done, mock_uuid
):
+ self.dataflow_hook.beam_hook = BeamHook(runner="DataflowRunner")
mock_uuid.return_value = MOCK_UUID
- mock_conn.return_value = None
- dataflow_instance = mock_dataflow.return_value
- dataflow_instance.wait_for_done.return_value = None
- dataflowjob_instance = mock_dataflowjob.return_value
- dataflowjob_instance.wait_for_done.return_value = None
- with pytest.raises(AirflowException, match="Invalid method invocation."):
+ on_new_job_id_callback = MagicMock()
+
+ with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"), self.assertRaisesRegex(
+ AirflowException, "Invalid method invocation."
+ ):
self.dataflow_hook.start_python_dataflow( # pylint: disable=no-value-for-parameter
job_name=JOB_NAME,
variables=DATAFLOW_VARIABLES_PY,
dataflow=PY_FILE,
py_options=PY_OPTIONS,
+ py_interpreter=DEFAULT_PY_INTERPRETER,
py_requirements=[],
+ on_new_job_id_callback=on_new_job_id_callback,
)
+ mock_dataflow_wait_for_done.assert_not_called()
+
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
- @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
- def test_start_java_dataflow(self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
+ @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+ @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
+ def test_start_java_dataflow(self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid):
+ mock_beam_start_java_pipeline = self.dataflow_hook.beam_hook.start_java_pipeline
mock_uuid.return_value = MOCK_UUID
- mock_conn.return_value = None
- dataflow_instance = mock_dataflow.return_value
- dataflow_instance.wait_for_done.return_value = None
- dataflowjob_instance = mock_dataflowjob.return_value
- dataflowjob_instance.wait_for_done.return_value = None
- self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter
- job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA, jar=JAR_FILE
- )
- expected_cmd = [
- 'java',
- '-jar',
- JAR_FILE,
- '--region=us-central1',
- '--runner=DataflowRunner',
- '--project=test',
- '--stagingLocation=gs://test/staging',
- '--labels={"foo":"bar"}',
- f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
- ]
- assert sorted(expected_cmd) == sorted(mock_dataflow.call_args[1]["cmd"])
+ on_new_job_id_callback = MagicMock()
+ job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+ with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+ self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter
+ job_name=JOB_NAME,
+ variables=DATAFLOW_VARIABLES_JAVA,
+ jar=JAR_FILE,
+ job_class=JOB_CLASS,
+ on_new_job_id_callback=on_new_job_id_callback,
+ )
+
+ expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
+ expected_variables["jobName"] = job_name
+ expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
+ expected_variables["labels"] = '{"foo":"bar"}'
+
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+ mock_beam_start_java_pipeline.assert_called_once_with(
+ variables=expected_variables,
+ jar=JAR_FILE,
+ job_class=JOB_CLASS,
+ process_line_callback=mock_callback_on_job_id.return_value,
+ )
+
+ mock_dataflow_wait_for_done.assert_called_once_with(
+ job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION, multiple_jobs=False
+ )
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
- @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+ @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+ @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
def test_start_java_dataflow_with_multiple_values_in_variables(
- self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+ self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
):
+ mock_beam_start_java_pipeline = self.dataflow_hook.beam_hook.start_java_pipeline
mock_uuid.return_value = MOCK_UUID
- mock_conn.return_value = None
- dataflow_instance = mock_dataflow.return_value
- dataflow_instance.wait_for_done.return_value = None
- dataflowjob_instance = mock_dataflowjob.return_value
- dataflowjob_instance.wait_for_done.return_value = None
- variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
- variables['mock-option'] = ['a.whl', 'b.whl']
-
- self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter
- job_name=JOB_NAME, variables=variables, jar=JAR_FILE
- )
- expected_cmd = [
- 'java',
- '-jar',
- JAR_FILE,
- '--mock-option=a.whl',
- '--mock-option=b.whl',
- '--region=us-central1',
- '--runner=DataflowRunner',
- '--project=test',
- '--stagingLocation=gs://test/staging',
- '--labels={"foo":"bar"}',
- f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
- ]
- assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
+ on_new_job_id_callback = MagicMock()
+ job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+ passed_variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
+ passed_variables['mock-option'] = ['a.whl', 'b.whl']
+
+ with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+ self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter
+ job_name=JOB_NAME,
+ variables=passed_variables,
+ jar=JAR_FILE,
+ job_class=JOB_CLASS,
+ on_new_job_id_callback=on_new_job_id_callback,
+ )
+
+ expected_variables = copy.deepcopy(passed_variables)
+ expected_variables["jobName"] = job_name
+ expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
+ expected_variables["labels"] = '{"foo":"bar"}'
+
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+ mock_beam_start_java_pipeline.assert_called_once_with(
+ variables=expected_variables,
+ jar=JAR_FILE,
+ job_class=JOB_CLASS,
+ process_line_callback=mock_callback_on_job_id.return_value,
+ )
+
+ mock_dataflow_wait_for_done.assert_called_once_with(
+ job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION, multiple_jobs=False
+ )
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
- @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+ @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+ @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
def test_start_java_dataflow_with_custom_region_as_variable(
- self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+ self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
):
+ mock_beam_start_java_pipeline = self.dataflow_hook.beam_hook.start_java_pipeline
mock_uuid.return_value = MOCK_UUID
- mock_conn.return_value = None
- dataflow_instance = mock_dataflow.return_value
- dataflow_instance.wait_for_done.return_value = None
- dataflowjob_instance = mock_dataflowjob.return_value
- dataflowjob_instance.wait_for_done.return_value = None
+ on_new_job_id_callback = MagicMock()
+ job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
- variables = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
- variables['region'] = TEST_LOCATION
-
- self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter
- job_name=JOB_NAME, variables=variables, jar=JAR_FILE
- )
- expected_cmd = [
- 'java',
- '-jar',
- JAR_FILE,
- f'--region={TEST_LOCATION}',
- '--runner=DataflowRunner',
- '--project=test',
- '--stagingLocation=gs://test/staging',
- '--labels={"foo":"bar"}',
- f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
- ]
- assert sorted(expected_cmd) == sorted(mock_dataflow.call_args[1]["cmd"])
+ passed_variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
+ passed_variables['region'] = TEST_LOCATION
+
+ with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+ self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter
+ job_name=JOB_NAME,
+ variables=passed_variables,
+ jar=JAR_FILE,
+ job_class=JOB_CLASS,
+ on_new_job_id_callback=on_new_job_id_callback,
+ )
+
+ expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
+ expected_variables["jobName"] = job_name
+ expected_variables["region"] = TEST_LOCATION
+ expected_variables["labels"] = '{"foo":"bar"}'
+
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+ mock_beam_start_java_pipeline.assert_called_once_with(
+ variables=expected_variables,
+ jar=JAR_FILE,
+ job_class=JOB_CLASS,
+ process_line_callback=mock_callback_on_job_id.return_value,
+ )
+
+ mock_dataflow_wait_for_done.assert_called_once_with(
+ job_id=mock.ANY, job_name=job_name, location=TEST_LOCATION, multiple_jobs=False
+ )
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
- @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+ @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+ @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
def test_start_java_dataflow_with_custom_region_as_parameter(
- self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+ self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
):
+ mock_beam_start_java_pipeline = self.dataflow_hook.beam_hook.start_java_pipeline
mock_uuid.return_value = MOCK_UUID
- mock_conn.return_value = None
- dataflow_instance = mock_dataflow.return_value
- dataflow_instance.wait_for_done.return_value = None
- dataflowjob_instance = mock_dataflowjob.return_value
- dataflowjob_instance.wait_for_done.return_value = None
+ on_new_job_id_callback = MagicMock()
+ job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
- variables = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
- variables['region'] = TEST_LOCATION
-
- self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter
- job_name=JOB_NAME, variables=variables, jar=JAR_FILE
- )
- expected_cmd = [
- 'java',
- '-jar',
- JAR_FILE,
- f'--region={TEST_LOCATION}',
- '--runner=DataflowRunner',
- '--project=test',
- '--stagingLocation=gs://test/staging',
- '--labels={"foo":"bar"}',
- f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
- ]
- assert sorted(expected_cmd) == sorted(mock_dataflow.call_args[1]["cmd"])
+ with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+ self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter
+ job_name=JOB_NAME,
+ variables=DATAFLOW_VARIABLES_JAVA,
+ jar=JAR_FILE,
+ job_class=JOB_CLASS,
+ on_new_job_id_callback=on_new_job_id_callback,
+ location=TEST_LOCATION,
+ )
- @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
- @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
- @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
- def test_start_java_dataflow_with_job_class(self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
- mock_uuid.return_value = MOCK_UUID
- mock_conn.return_value = None
- dataflow_instance = mock_dataflow.return_value
- dataflow_instance.wait_for_done.return_value = None
- dataflowjob_instance = mock_dataflowjob.return_value
- dataflowjob_instance.wait_for_done.return_value = None
- self.dataflow_hook.start_java_dataflow( # pylint: disable=no-value-for-parameter
- job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA, jar=JAR_FILE, job_class=JOB_CLASS
- )
- expected_cmd = [
- 'java',
- '-cp',
- JAR_FILE,
- JOB_CLASS,
- '--region=us-central1',
- '--runner=DataflowRunner',
- '--project=test',
- '--stagingLocation=gs://test/staging',
- '--labels={"foo":"bar"}',
- f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
- ]
- assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
+ expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
+ expected_variables["jobName"] = job_name
+ expected_variables["region"] = TEST_LOCATION
+ expected_variables["labels"] = '{"foo":"bar"}'
+
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+ mock_beam_start_java_pipeline.assert_called_once_with(
+ variables=expected_variables,
+ jar=JAR_FILE,
+ job_class=JOB_CLASS,
+ process_line_callback=mock_callback_on_job_id.return_value,
+ )
+
+ mock_dataflow_wait_for_done.assert_called_once_with(
+ job_id=mock.ANY, job_name=job_name, location=TEST_LOCATION, multiple_jobs=False
+ )
@parameterized.expand(
[
@@ -616,17 +668,20 @@ class TestDataflowHook(unittest.TestCase):
)
@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
def test_valid_dataflow_job_name(self, expected_result, job_name, append_job_name, mock_uuid4):
- job_name = self.dataflow_hook._build_dataflow_job_name(
+ job_name = self.dataflow_hook.build_dataflow_job_name(
job_name=job_name, append_job_name=append_job_name
)
- assert expected_result == job_name
+ self.assertEqual(expected_result, job_name)
@parameterized.expand([("1dfjob@",), ("dfjob@",), ("df^jo",)])
def test_build_dataflow_job_name_with_invalid_value(self, job_name):
- with pytest.raises(ValueError):
- self.dataflow_hook._build_dataflow_job_name(job_name=job_name, append_job_name=False)
+ self.assertRaises(
+ ValueError, self.dataflow_hook.build_dataflow_job_name, job_name=job_name, append_job_name=False
+ )
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
def test_get_job(self, mock_conn, mock_dataflowjob):
@@ -641,6 +696,7 @@ class TestDataflowHook(unittest.TestCase):
)
method_fetch_job_by_id.assert_called_once_with(TEST_JOB_ID)
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
def test_fetch_job_metrics_by_id(self, mock_conn, mock_dataflowjob):
@@ -706,6 +762,34 @@ class TestDataflowHook(unittest.TestCase):
)
method_fetch_job_autoscaling_events_by_id.assert_called_once_with(TEST_JOB_ID)
+ @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
+ @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+ def test_wait_for_done(self, mock_conn, mock_dataflowjob):
+ method_wait_for_done = mock_dataflowjob.return_value.wait_for_done
+
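+ # wait_for_done should build a _DataflowJobsController from the hook's settings
+ # and delegate to its wait_for_done().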
+ self.dataflow_hook.wait_for_done(
+ job_name="JOB_NAME",
+ project_id=TEST_PROJECT_ID,
+ job_id=TEST_JOB_ID,
+ location=TEST_LOCATION,
+ multiple_jobs=False,
+ )
+ mock_conn.assert_called_once()
+ mock_dataflowjob.assert_called_once_with(
+ dataflow=mock_conn.return_value,
+ project_number=TEST_PROJECT_ID,
+ name="JOB_NAME",
+ location=TEST_LOCATION,
+ poll_sleep=self.dataflow_hook.poll_sleep,
+ job_id=TEST_JOB_ID,
+ num_retries=self.dataflow_hook.num_retries,
+ multiple_jobs=False,
+ drain_pipeline=self.dataflow_hook.drain_pipeline,
+ cancel_timeout=self.dataflow_hook.cancel_timeout,
+ wait_until_finished=self.dataflow_hook.wait_until_finished,
+ )
+ method_wait_for_done.assert_called_once_with()
+
class TestDataflowTemplateHook(unittest.TestCase):
def setUp(self):
@@ -1691,13 +1775,32 @@ class TestDataflow(unittest.TestCase):
def test_data_flow_valid_job_id(self, log):
echos = ";".join([f"echo {shlex.quote(line)}" for line in log.split("\n")])
cmd = ["bash", "-c", echos]
- assert _DataflowRunner(cmd).wait_for_done() == TEST_JOB_ID
+ found_job_id = None
+
+ def callback(job_id):
+ nonlocal found_job_id
+ found_job_id = job_id
+
+ BeamCommandRunner(
+ cmd, process_line_callback=process_line_and_extract_dataflow_job_id_callback(callback)
+ ).wait_for_done()
+ self.assertEqual(found_job_id, TEST_JOB_ID)
def test_data_flow_missing_job_id(self):
cmd = ['echo', 'unit testing']
- assert _DataflowRunner(cmd).wait_for_done() is None
+ found_job_id = None
+
+ def callback(job_id):
+ nonlocal found_job_id
+ found_job_id = job_id
+
+ BeamCommandRunner(
+ cmd, process_line_callback=process_line_and_extract_dataflow_job_id_callback(callback)
+ ).wait_for_done()
+
+ self.assertEqual(found_job_id, None)
- @mock.patch('airflow.providers.google.cloud.hooks.dataflow._DataflowRunner.log')
+ @mock.patch('airflow.providers.apache.beam.hooks.beam.BeamCommandRunner.log')
@mock.patch('subprocess.Popen')
@mock.patch('select.select')
def test_dataflow_wait_for_done_logging(self, mock_select, mock_popen, mock_logging):
@@ -1718,7 +1821,6 @@ class TestDataflow(unittest.TestCase):
mock_proc_poll.side_effect = [None, poll_resp_error]
mock_proc.poll = mock_proc_poll
mock_popen.return_value = mock_proc
- dataflow = _DataflowRunner(['test', 'cmd'])
+ dataflow = BeamCommandRunner(['test', 'cmd'])
mock_logging.info.assert_called_once_with('Running command: %s', 'test cmd')
- with pytest.raises(Exception):
- dataflow.wait_for_done()
+ self.assertRaises(Exception, dataflow.wait_for_done)
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 7e290d7..3018052 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
-
+import copy
import unittest
from copy import deepcopy
from unittest import mock
@@ -115,35 +115,56 @@ class TestDataflowPythonOperator(unittest.TestCase):
assert self.dataflow.dataflow_default_options == DEFAULT_OPTIONS_PYTHON
assert self.dataflow.options == EXPECTED_ADDITIONAL_OPTIONS
+ @mock.patch(
+ 'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
+ )
+ @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
@mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
@mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
- def test_exec(self, gcs_hook, dataflow_mock):
+ def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_on_job_id):
"""Test DataflowHook is created and the right args are passed to
start_python_workflow.
"""
- start_python_hook = dataflow_mock.return_value.start_python_dataflow
+ start_python_mock = beam_hook_mock.return_value.start_python_pipeline
gcs_provide_file = gcs_hook.return_value.provide_file
+ job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
self.dataflow.execute(None)
- assert dataflow_mock.called
+ beam_hook_mock.assert_called_once_with(runner="DataflowRunner")
+ self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow'))
+ gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback=mock.ANY)
+ dataflow_hook_mock.assert_called_once_with(
+ gcp_conn_id="google_cloud_default",
+ delegate_to=mock.ANY,
+ poll_sleep=POLL_SLEEP,
+ impersonation_chain=None,
+ drain_pipeline=False,
+ cancel_timeout=mock.ANY,
+ wait_until_finished=None,
+ )
expected_options = {
- 'project': 'test',
- 'staging_location': 'gs://test/staging',
+ "project": dataflow_hook_mock.return_value.project_id,
+ "staging_location": 'gs://test/staging',
+ "job_name": job_name,
+ "region": TEST_LOCATION,
'output': 'gs://test/output',
- 'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
+ 'labels': {'foo': 'bar', 'airflow-version': 'v2-1-0-dev0'},
}
- gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
- start_python_hook.assert_called_once_with(
- job_name=JOB_NAME,
+ start_python_mock.assert_called_once_with(
variables=expected_options,
- dataflow=mock.ANY,
+ py_file=gcs_provide_file.return_value.__enter__.return_value.name,
py_options=PY_OPTIONS,
py_interpreter=PY_INTERPRETER,
py_requirements=None,
py_system_site_packages=False,
- on_new_job_id_callback=mock.ANY,
- project_id=None,
+ process_line_callback=mock_callback_on_job_id.return_value,
+ )
+ dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+ job_id=mock.ANY,
+ job_name=job_name,
location=TEST_LOCATION,
+ multiple_jobs=False,
)
assert self.dataflow.py_file.startswith('/tmp/dataflow')
@@ -172,110 +193,182 @@ class TestDataflowJavaOperator(unittest.TestCase):
assert self.dataflow.options == EXPECTED_ADDITIONAL_OPTIONS
assert self.dataflow.check_if_running == CheckJobRunning.WaitForRun
+ @mock.patch(
+ 'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
+ )
+ @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
@mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
@mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
- def test_exec(self, gcs_hook, dataflow_mock):
+ def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_on_job_id):
"""Test DataflowHook is created and the right args are passed to
start_java_workflow.
"""
- start_java_hook = dataflow_mock.return_value.start_java_dataflow
+ start_java_mock = beam_hook_mock.return_value.start_java_pipeline
gcs_provide_file = gcs_hook.return_value.provide_file
+ job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
self.dataflow.check_if_running = CheckJobRunning.IgnoreJob
+
self.dataflow.execute(None)
- assert dataflow_mock.called
+
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback=mock.ANY)
gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
- start_java_hook.assert_called_once_with(
- job_name=JOB_NAME,
- variables=mock.ANY,
- jar=mock.ANY,
+ expected_variables = {
+ 'project': dataflow_hook_mock.return_value.project_id,
+ 'stagingLocation': 'gs://test/staging',
+ 'jobName': job_name,
+ 'region': TEST_LOCATION,
+ 'output': 'gs://test/output',
+ 'labels': {'foo': 'bar', 'airflow-version': 'v2-1-0-dev0'},
+ }
+
+ start_java_mock.assert_called_once_with(
+ variables=expected_variables,
+ jar=gcs_provide_file.return_value.__enter__.return_value.name,
job_class=JOB_CLASS,
- append_job_name=True,
- multiple_jobs=None,
- on_new_job_id_callback=mock.ANY,
- project_id=None,
+ process_line_callback=mock_callback_on_job_id.return_value,
+ )
+ dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+ job_id=mock.ANY,
+ job_name=job_name,
location=TEST_LOCATION,
+ multiple_jobs=None,
)
+ @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
@mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
@mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
- def test_check_job_running_exec(self, gcs_hook, dataflow_mock):
+ def test_check_job_running_exec(self, gcs_hook, dataflow_mock, beam_hook_mock):
"""Test DataflowHook is created and the right args are passed to
start_java_workflow.
"""
dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
dataflow_running.return_value = True
- start_java_hook = dataflow_mock.return_value.start_java_dataflow
+ start_java_hook = beam_hook_mock.return_value.start_java_pipeline
gcs_provide_file = gcs_hook.return_value.provide_file
self.dataflow.check_if_running = True
+
self.dataflow.execute(None)
- assert dataflow_mock.called
- gcs_provide_file.assert_not_called()
+
+ self.assertTrue(dataflow_mock.called)
start_java_hook.assert_not_called()
- dataflow_running.assert_called_once_with(
- name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION
- )
+ gcs_provide_file.assert_called_once()
+ variables = {
+ 'project': dataflow_mock.return_value.project_id,
+ 'stagingLocation': 'gs://test/staging',
+ 'jobName': JOB_NAME,
+ 'region': TEST_LOCATION,
+ 'output': 'gs://test/output',
+ 'labels': {'foo': 'bar', 'airflow-version': 'v2-1-0-dev0'},
+ }
+ dataflow_running.assert_called_once_with(name=JOB_NAME, variables=variables)
+ @mock.patch(
+ 'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
+ )
+ @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
@mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
@mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
- def test_check_job_not_running_exec(self, gcs_hook, dataflow_mock):
+ def test_check_job_not_running_exec(
+ self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_on_job_id
+ ):
"""Test DataflowHook is created and the right args are passed to
start_java_workflow with option to check if job is running
-
"""
- dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
+ is_job_dataflow_running_variables = None
+
+ def set_is_job_dataflow_running_variables(*args, **kwargs):
+ nonlocal is_job_dataflow_running_variables
+ is_job_dataflow_running_variables = copy.deepcopy(kwargs.get("variables"))
+
+ dataflow_running = dataflow_hook_mock.return_value.is_job_dataflow_running
+ dataflow_running.side_effect = set_is_job_dataflow_running_variables
dataflow_running.return_value = False
- start_java_hook = dataflow_mock.return_value.start_java_dataflow
+ start_java_mock = beam_hook_mock.return_value.start_java_pipeline
gcs_provide_file = gcs_hook.return_value.provide_file
self.dataflow.check_if_running = True
+
self.dataflow.execute(None)
- assert dataflow_mock.called
+
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback=mock.ANY)
gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
- start_java_hook.assert_called_once_with(
- job_name=JOB_NAME,
- variables=mock.ANY,
- jar=mock.ANY,
+ expected_variables = {
+ 'project': dataflow_hook_mock.return_value.project_id,
+ 'stagingLocation': 'gs://test/staging',
+ 'jobName': JOB_NAME,
+ 'region': TEST_LOCATION,
+ 'output': 'gs://test/output',
+ 'labels': {'foo': 'bar', 'airflow-version': 'v2-1-0-dev0'},
+ }
+ self.assertEqual(expected_variables, is_job_dataflow_running_variables)
+ job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
+ expected_variables["jobName"] = job_name
+ start_java_mock.assert_called_once_with(
+ variables=expected_variables,
+ jar=gcs_provide_file.return_value.__enter__.return_value.name,
job_class=JOB_CLASS,
- append_job_name=True,
- multiple_jobs=None,
- on_new_job_id_callback=mock.ANY,
- project_id=None,
- location=TEST_LOCATION,
+ process_line_callback=mock_callback_on_job_id.return_value,
)
- dataflow_running.assert_called_once_with(
- name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION
+ dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+ job_id=mock.ANY,
+ job_name=job_name,
+ location=TEST_LOCATION,
+ multiple_jobs=None,
)
+ @mock.patch(
+ 'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
+ )
+ @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
@mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
@mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
- def test_check_multiple_job_exec(self, gcs_hook, dataflow_mock):
+ def test_check_multiple_job_exec(
+ self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_on_job_id
+ ):
"""Test DataflowHook is created and the right args are passed to
- start_java_workflow with option to check multiple jobs
-
+ start_java_workflow with option to check multiple jobs
"""
- dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
+ is_job_dataflow_running_variables = None
+
+ def set_is_job_dataflow_running_variables(*args, **kwargs):
+ nonlocal is_job_dataflow_running_variables
+ is_job_dataflow_running_variables = copy.deepcopy(kwargs.get("variables"))
+
+ dataflow_running = dataflow_hook_mock.return_value.is_job_dataflow_running
+ dataflow_running.side_effect = set_is_job_dataflow_running_variables
dataflow_running.return_value = False
- start_java_hook = dataflow_mock.return_value.start_java_dataflow
+ start_java_mock = beam_hook_mock.return_value.start_java_pipeline
gcs_provide_file = gcs_hook.return_value.provide_file
- self.dataflow.multiple_jobs = True
self.dataflow.check_if_running = True
+ self.dataflow.multiple_jobs = True
+
self.dataflow.execute(None)
- assert dataflow_mock.called
+
+ mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback=mock.ANY)
gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
- start_java_hook.assert_called_once_with(
- job_name=JOB_NAME,
- variables=mock.ANY,
- jar=mock.ANY,
+ expected_variables = {
+ 'project': dataflow_hook_mock.return_value.project_id,
+ 'stagingLocation': 'gs://test/staging',
+ 'jobName': JOB_NAME,
+ 'region': TEST_LOCATION,
+ 'output': 'gs://test/output',
+ 'labels': {'foo': 'bar', 'airflow-version': 'v2-1-0-dev0'},
+ }
+ self.assertEqual(expected_variables, is_job_dataflow_running_variables)
+ job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
+ expected_variables["jobName"] = job_name
+ start_java_mock.assert_called_once_with(
+ variables=expected_variables,
+ jar=gcs_provide_file.return_value.__enter__.return_value.name,
job_class=JOB_CLASS,
- append_job_name=True,
- multiple_jobs=True,
- on_new_job_id_callback=mock.ANY,
- project_id=None,
- location=TEST_LOCATION,
+ process_line_callback=mock_callback_on_job_id.return_value,
)
- dataflow_running.assert_called_once_with(
- name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION
+ dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+ job_id=mock.ANY,
+ job_name=job_name,
+ location=TEST_LOCATION,
+ multiple_jobs=True,
)
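
Taken together, the operator tests above assert a two-step flow: the Beam hook launches the pipeline (with a callback that captures the Dataflow job id from the process output), and the Dataflow hook then waits for that job to finish. A condensed sketch of the flow, assuming the hook methods and argument names asserted in these tests (project, bucket, and job names are illustrative only):

    from airflow.providers.apache.beam.hooks.beam import BeamHook
    from airflow.providers.google.cloud.hooks.dataflow import (
        DataflowHook,
        process_line_and_extract_dataflow_job_id_callback,
    )

    beam_hook = BeamHook(runner="DataflowRunner")
    dataflow_hook = DataflowHook()  # default connection, poll_sleep, cancel_timeout, etc.

    dataflow_job_id = None

    def set_current_job_id(job_id: str) -> None:
        global dataflow_job_id
        dataflow_job_id = job_id

    # Step 1: submit the pipeline; the callback records the job id as soon as
    # the DataflowRunner prints it to the process output.
    beam_hook.start_python_pipeline(
        variables={
            "project": "my-project",
            "region": "us-central1",
            "staging_location": "gs://my-bucket/staging",
            "job_name": "my-job",
        },
        py_file="/tmp/dataflow/my_pipeline.py",
        py_options=[],
        py_interpreter="python3",
        py_requirements=None,
        py_system_site_packages=False,
        process_line_callback=process_line_and_extract_dataflow_job_id_callback(set_current_job_id),
    )

    # Step 2: block until the submitted Dataflow job reaches a terminal state.
    dataflow_hook.wait_for_done(
        job_name="my-job",
        project_id="my-project",
        job_id=dataflow_job_id,
        location="us-central1",
        multiple_jobs=False,
    )

This split keeps pipeline submission runner-agnostic in the Beam provider while leaving job monitoring and cancellation to the Dataflow hook, which is what the updated assertions on start_python_pipeline/start_java_pipeline and wait_for_done verify.
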
diff --git a/tests/providers/google/cloud/operators/test_mlengine_utils.py b/tests/providers/google/cloud/operators/test_mlengine_utils.py
index 539ee60..c46fa62 100644
--- a/tests/providers/google/cloud/operators/test_mlengine_utils.py
+++ b/tests/providers/google/cloud/operators/test_mlengine_utils.py
@@ -106,9 +106,14 @@ class TestCreateEvaluateOps(unittest.TestCase):
)
assert success_message['predictionOutput'] == result
- with patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') as mock_dataflow_hook:
- hook_instance = mock_dataflow_hook.return_value
- hook_instance.start_python_dataflow.return_value = None
+ with patch(
+ 'airflow.providers.google.cloud.operators.dataflow.DataflowHook'
+ ) as mock_dataflow_hook, patch(
+ 'airflow.providers.google.cloud.operators.dataflow.BeamHook'
+ ) as mock_beam_hook:
+ dataflow_hook_instance = mock_dataflow_hook.return_value
+ dataflow_hook_instance.start_python_dataflow.return_value = None
+ beam_hook_instance = mock_beam_hook.return_value
summary.execute(None)
mock_dataflow_hook.assert_called_once_with(
gcp_conn_id='google_cloud_default',
@@ -117,23 +122,28 @@ class TestCreateEvaluateOps(unittest.TestCase):
drain_pipeline=False,
cancel_timeout=600,
wait_until_finished=None,
+ impersonation_chain=None,
)
- hook_instance.start_python_dataflow.assert_called_once_with(
- job_name='{{task.task_id}}',
+ mock_beam_hook.assert_called_once_with(runner="DataflowRunner")
+ beam_hook_instance.start_python_pipeline.assert_called_once_with(
variables={
'prediction_path': 'gs://legal-bucket/fake-output-path',
'labels': {'airflow-version': TEST_VERSION},
'metric_keys': 'err',
'metric_fn_encoded': self.metric_fn_encoded,
+ 'project': 'test-project',
+ 'region': 'us-central1',
+ 'job_name': mock.ANY,
},
- dataflow=mock.ANY,
+ py_file=mock.ANY,
py_options=[],
- py_requirements=['apache-beam[gcp]>=2.14.0'],
py_interpreter='python3',
+ py_requirements=['apache-beam[gcp]>=2.14.0'],
py_system_site_packages=False,
- on_new_job_id_callback=ANY,
- project_id='test-project',
- location='us-central1',
+ process_line_callback=mock.ANY,
+ )
+ dataflow_hook_instance.wait_for_done.assert_called_once_with(
+ job_name=mock.ANY, location='us-central1', job_id=mock.ANY, multiple_jobs=False
)
with patch('airflow.providers.google.cloud.utils.mlengine_operator_utils.GCSHook') as mock_gcs_hook: