Posted to commits@airflow.apache.org by po...@apache.org on 2021/03/03 09:37:26 UTC

[airflow] branch v2-0-test updated (bec2e7e -> 38b3548)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a change to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git.


 discard bec2e7e  fixup! Add Neo4j hook and operator (#13324)
 discard aa94e09  Fix failing docs build on Master (#14465)
 discard a75ab68  Add Azure Data Factory hook (#11015)
 discard 641a7c5  Add Tableau provider separate from Salesforce Provider (#14030)
 discard 34a8732  Pin moto to <2 (#14433)
 discard a55ffd2  Remove testfixtures module that is only used once (#14318)
 discard fb8f454  Limits Sphinx to <3.5.0 (#14238)
 discard bfa04e5  Remove reinstalling azure-storage steps from CI / Breeze (#14102)
 discard a8256ab  Update to Pytest 6.0 (#14065)
 discard 191d69a  Support google-cloud-logging >=2.0.0 (#13801)
 discard 2ee5a82  Support google-cloud-monitoring>=2.0.0 (#13769)
 discard a833671  Refactor DataprocOperators to support google-cloud-dataproc 2.0 (#13256)
 discard f0ae25e  Support google-cloud-tasks>=2.0.0 (#13347)
 discard 743729f  Support google-cloud-automl >=2.1.0 (#13505)
 discard ee4789e  Support google-cloud-datacatalog>=3.0.0 (#13534)
 discard 3e35316  Salesforce provider requires tableau (#13593)
 discard db5019b  Support google-cloud-bigquery-datatransfer>=3.0.0 (#13337)
 discard 8e66138  Add timeout option to gcs hook methods. (#13156)
 discard ed420f3  Support google-cloud-redis>=2.0.0 (#13117)
 discard d954916  Support google-cloud-pubsub>=2.0.0 (#13127)
 discard 6233cd0  Update compatibility with google-cloud-kms>=2.0 (#13124)
 discard c8d9c3a  Support google-cloud-datacatalog>=1.0.0 (#13097)
 discard 0df90cf  Update compatibility with google-cloud-os-login>=2.0.0 (#13126)
 discard 7696716  Add Google Cloud Workflows Operators (#13366)
 discard 42935a4  Upgrade slack_sdk to v3 (#13745)
 discard fef2a36  Add Apache Beam operators (#12814)
 discard 4307263  Fix grammar in production-deployment.rst (#14386)
 discard 475b7ce  Minor doc fixes (#14547)
 discard 7dcd19b  Add Neo4j hook and operator (#13324)
     new 8543471  Add Neo4j hook and operator (#13324)
     new a671f37  Minor doc fixes (#14547)
     new 6e74149  Fix grammar in production-deployment.rst (#14386)
     new c6ccaa5  Add Apache Beam operators (#12814)
     new 3c5173a  Upgrade slack_sdk to v3 (#13745)
     new 35679d2  Add Google Cloud Workflows Operators (#13366)
     new cfd5a48  Update compatibility with google-cloud-os-login>=2.0.0 (#13126)
     new 51d70e3  Support google-cloud-datacatalog>=1.0.0 (#13097)
     new fae32c3  Update compatibility with google-cloud-kms>=2.0 (#13124)
     new 7d49baa  Support google-cloud-pubsub>=2.0.0 (#13127)
     new 3ef6d6f  Support google-cloud-redis>=2.0.0 (#13117)
     new 1a24288  Add timeout option to gcs hook methods. (#13156)
     new 62d985b  Support google-cloud-bigquery-datatransfer>=3.0.0 (#13337)
     new 56cc293  Salesforce provider requires tableau (#13593)
     new 9ed976e  Support google-cloud-datacatalog>=3.0.0 (#13534)
     new 92c356e  Support google-cloud-automl >=2.1.0 (#13505)
     new 7b31def  Support google-cloud-tasks>=2.0.0 (#13347)
     new 21831aa  Refactor DataprocOperators to support google-cloud-dataproc 2.0 (#13256)
     new 988a2a5  Support google-cloud-monitoring>=2.0.0 (#13769)
     new 25f2db1  Support google-cloud-logging >=2.0.0 (#13801)
     new c34898b  Update to Pytest 6.0 (#14065)
     new cac851c  Remove reinstalling azure-storage steps from CI / Breeze (#14102)
     new 6118de4  Limits Sphinx to <3.5.0 (#14238)
     new 12b5dca  Remove testfixtures module that is only used once (#14318)
     new 982a3a2  Pin moto to <2 (#14433)
     new 3f6f6bf  Add Tableau provider separate from Salesforce Provider (#14030)
     new 0d199d3  Add Azure Data Factory hook (#11015)
     new 38b3548  Fix failing docs build on Master (#14465)

This update added new revisions after undoing existing revisions.
That is to say, some revisions that were in the old version of the
branch are not in the new version.  This situation occurs
when a user force-pushes a change and generates a repository
containing something like this:

 * -- * -- B -- O -- O -- O   (bec2e7e)
            \
             N -- N -- N   refs/heads/v2-0-test (38b3548)

You should already have received notification emails for all of the O
revisions, and so the following emails describe only the N revisions
from the common base, B.

Any revisions marked "omit" are not gone; other references still
refer to them.  Any revisions marked "discard" are gone forever.

The 28 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:


[airflow] 26/28: Add Tableau provider separate from Salesforce Provider (#14030)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3f6f6bf0dfff1b035ea1538c244e17b6b519c474
Author: Jyoti Dhiman <36...@users.noreply.github.com>
AuthorDate: Thu Feb 25 17:52:54 2021 +0530

    Add Tableau provider separate from Salesforce Provider (#14030)
    
    Closes #13614
    
    (cherry picked from commit 45e72ca83049a7db526b1f0fbd94c75f5f92cc75)
---
 CONTRIBUTING.rst                                   |   1 +
 airflow/providers/dependencies.json                |   3 +
 airflow/providers/salesforce/CHANGELOG.rst         |  16 ++++
 airflow/providers/salesforce/hooks/tableau.py      | 104 ++-------------------
 .../operators/tableau_refresh_workbook.py          |  88 ++---------------
 airflow/providers/salesforce/provider.yaml         |   6 +-
 .../salesforce/sensors/tableau_job_status.py       |  68 +++-----------
 .../{salesforce => tableau}/CHANGELOG.rst          |   0
 .../provider.yaml => tableau/__init__.py}          |  34 +------
 .../example_dags/__init__.py}                      |  33 -------
 .../example_tableau_refresh_workbook.py            |   4 +-
 .../provider.yaml => tableau/hooks/__init__.py}    |  34 +------
 .../{salesforce => tableau}/hooks/tableau.py       |   0
 .../operators/__init__.py}                         |  33 -------
 .../operators/tableau_refresh_workbook.py          |   4 +-
 .../{salesforce => tableau}/provider.yaml          |  26 +++---
 .../provider.yaml => tableau/sensors/__init__.py}  |  33 -------
 .../sensors/tableau_job_status.py                  |   2 +-
 .../apache-airflow-providers-tableau/index.rst     |  29 +++++-
 docs/integration-logos/tableau/tableau.png         | Bin 0 -> 4142 bytes
 docs/spelling_wordlist.txt                         |   1 +
 .../run_install_and_test_provider_packages.sh      |   2 +-
 setup.py                                           |   4 +-
 tests/core/test_providers_manager.py               |   1 +
 .../providers/tableau/hooks/__init__.py            |  34 +------
 .../{salesforce => tableau}/hooks/test_tableau.py  |  32 +++++--
 .../providers/tableau/operators/__init__.py        |  33 -------
 .../operators/test_tableau_refresh_workbook.py     |  26 +++++-
 .../providers/tableau/sensors/__init__.py          |  33 -------
 .../sensors/test_tableau_job_status.py             |  16 +++-
 30 files changed, 162 insertions(+), 538 deletions(-)

diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 0a6f381..857d3bb 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -654,6 +654,7 @@ microsoft.mssql            odbc
 mysql                      amazon,presto,vertica
 opsgenie                   http
 postgres                   amazon
+salesforce                 tableau
 sftp                       ssh
 slack                      http
 snowflake                  slack
diff --git a/airflow/providers/dependencies.json b/airflow/providers/dependencies.json
index 836020c..b01e96c 100644
--- a/airflow/providers/dependencies.json
+++ b/airflow/providers/dependencies.json
@@ -67,6 +67,9 @@
   "postgres": [
     "amazon"
   ],
+  "salesforce": [
+    "tableau"
+  ],
   "sftp": [
     "ssh"
   ],
diff --git a/airflow/providers/salesforce/CHANGELOG.rst b/airflow/providers/salesforce/CHANGELOG.rst
index cef7dda..b4eb0ed 100644
--- a/airflow/providers/salesforce/CHANGELOG.rst
+++ b/airflow/providers/salesforce/CHANGELOG.rst
@@ -19,6 +19,22 @@
 Changelog
 ---------
 
+1.0.2
+.....
+
+Tableau provider moved to separate 'tableau' provider
+
+Things done:
+
+    - Tableau classes imports classes from 'tableau' provider with deprecation warning
+
+
+1.0.1
+.....
+
+Updated documentation and readme files.
+
+
 1.0.0
 .....
 
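For context, a minimal sketch (not part of the commit) of what 1.0.2 means for existing code: the old salesforce import path keeps working but now emits a DeprecationWarning at import time, so the forward-looking fix is to import from the new tableau provider. Assumes a fresh interpreter so the module-level warning actually fires:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Deprecated path kept for backwards compatibility by this commit:
        from airflow.providers.salesforce.hooks.tableau import TableauHook  # noqa: F401
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # Preferred path after the split into a separate tableau provider:
    from airflow.providers.tableau.hooks.tableau import TableauHook  # noqa: F811
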
diff --git a/airflow/providers/salesforce/hooks/tableau.py b/airflow/providers/salesforce/hooks/tableau.py
index 51c2f98..cf5f7f3 100644
--- a/airflow/providers/salesforce/hooks/tableau.py
+++ b/airflow/providers/salesforce/hooks/tableau.py
@@ -14,102 +14,14 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from enum import Enum
-from typing import Any, Optional
 
-from tableauserverclient import Pager, PersonalAccessTokenAuth, Server, TableauAuth
-from tableauserverclient.server import Auth
+import warnings
 
-from airflow.hooks.base import BaseHook
+# pylint: disable=unused-import
+from airflow.providers.tableau.hooks.tableau import TableauHook, TableauJobFinishCode  # noqa
 
-
-class TableauJobFinishCode(Enum):
-    """
-    The finish code indicates the status of the job.
-
-    .. seealso:: https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref.htm#query_job
-
-    """
-
-    PENDING = -1
-    SUCCESS = 0
-    ERROR = 1
-    CANCELED = 2
-
-
-class TableauHook(BaseHook):
-    """
-    Connects to the Tableau Server Instance and allows to communicate with it.
-
-    .. seealso:: https://tableau.github.io/server-client-python/docs/
-
-    :param site_id: The id of the site where the workbook belongs to.
-        It will connect to the default site if you don't provide an id.
-    :type site_id: Optional[str]
-    :param tableau_conn_id: The Tableau Connection id containing the credentials
-        to authenticate to the Tableau Server.
-    :type tableau_conn_id: str
-    """
-
-    conn_name_attr = 'tableau_conn_id'
-    default_conn_name = 'tableau_default'
-    conn_type = 'tableau'
-    hook_name = 'Tableau'
-
-    def __init__(self, site_id: Optional[str] = None, tableau_conn_id: str = default_conn_name) -> None:
-        super().__init__()
-        self.tableau_conn_id = tableau_conn_id
-        self.conn = self.get_connection(self.tableau_conn_id)
-        self.site_id = site_id or self.conn.extra_dejson.get('site_id', '')
-        self.server = Server(self.conn.host, use_server_version=True)
-        self.tableau_conn = None
-
-    def __enter__(self):
-        if not self.tableau_conn:
-            self.tableau_conn = self.get_conn()
-        return self
-
-    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
-        self.server.auth.sign_out()
-
-    def get_conn(self) -> Auth.contextmgr:
-        """
-        Signs in to the Tableau Server and automatically signs out if used as ContextManager.
-
-        :return: an authorized Tableau Server Context Manager object.
-        :rtype: tableauserverclient.server.Auth.contextmgr
-        """
-        if self.conn.login and self.conn.password:
-            return self._auth_via_password()
-        if 'token_name' in self.conn.extra_dejson and 'personal_access_token' in self.conn.extra_dejson:
-            return self._auth_via_token()
-        raise NotImplementedError('No Authentication method found for given Credentials!')
-
-    def _auth_via_password(self) -> Auth.contextmgr:
-        tableau_auth = TableauAuth(
-            username=self.conn.login, password=self.conn.password, site_id=self.site_id
-        )
-        return self.server.auth.sign_in(tableau_auth)
-
-    def _auth_via_token(self) -> Auth.contextmgr:
-        tableau_auth = PersonalAccessTokenAuth(
-            token_name=self.conn.extra_dejson['token_name'],
-            personal_access_token=self.conn.extra_dejson['personal_access_token'],
-            site_id=self.site_id,
-        )
-        return self.server.auth.sign_in_with_personal_access_token(tableau_auth)
-
-    def get_all(self, resource_name: str) -> Pager:
-        """
-        Get all items of the given resource.
-
-        .. seealso:: https://tableau.github.io/server-client-python/docs/page-through-results
-
-        :param resource_name: The name of the resource to paginate.
-            For example: jobs or workbooks
-        :type resource_name: str
-        :return: all items by returning a Pager.
-        :rtype: tableauserverclient.Pager
-        """
-        resource = getattr(self.server, resource_name)
-        return Pager(resource.get)
+warnings.warn(
+    "This module is deprecated. Please use `airflow.providers.tableau.hooks.tableau`.",
+    DeprecationWarning,
+    stacklevel=2,
+)
diff --git a/airflow/providers/salesforce/operators/tableau_refresh_workbook.py b/airflow/providers/salesforce/operators/tableau_refresh_workbook.py
index 7d4ffdc..309af33 100644
--- a/airflow/providers/salesforce/operators/tableau_refresh_workbook.py
+++ b/airflow/providers/salesforce/operators/tableau_refresh_workbook.py
@@ -14,84 +14,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Optional
 
-from tableauserverclient import WorkbookItem
+import warnings
 
-from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
-from airflow.providers.salesforce.hooks.tableau import TableauHook
-from airflow.utils.decorators import apply_defaults
+# pylint: disable=unused-import
+from airflow.providers.tableau.operators.tableau_refresh_workbook import (  # noqa
+    TableauRefreshWorkbookOperator,
+)
 
-
-class TableauRefreshWorkbookOperator(BaseOperator):
-    """
-    Refreshes a Tableau Workbook/Extract
-
-    .. seealso:: https://tableau.github.io/server-client-python/docs/api-ref#workbooks
-
-    :param workbook_name: The name of the workbook to refresh.
-    :type workbook_name: str
-    :param site_id: The id of the site where the workbook belongs to.
-    :type site_id: Optional[str]
-    :param blocking: By default the extract refresh will be blocking means it will wait until it has finished.
-    :type blocking: bool
-    :param tableau_conn_id: The Tableau Connection id containing the credentials
-        to authenticate to the Tableau Server.
-    :type tableau_conn_id: str
-    """
-
-    @apply_defaults
-    def __init__(
-        self,
-        *,
-        workbook_name: str,
-        site_id: Optional[str] = None,
-        blocking: bool = True,
-        tableau_conn_id: str = 'tableau_default',
-        **kwargs,
-    ) -> None:
-        super().__init__(**kwargs)
-        self.workbook_name = workbook_name
-        self.site_id = site_id
-        self.blocking = blocking
-        self.tableau_conn_id = tableau_conn_id
-
-    def execute(self, context: dict) -> str:
-        """
-        Executes the Tableau Extract Refresh and pushes the job id to xcom.
-
-        :param context: The task context during execution.
-        :type context: dict
-        :return: the id of the job that executes the extract refresh
-        :rtype: str
-        """
-        with TableauHook(self.site_id, self.tableau_conn_id) as tableau_hook:
-            workbook = self._get_workbook_by_name(tableau_hook)
-
-            job_id = self._refresh_workbook(tableau_hook, workbook.id)
-            if self.blocking:
-                from airflow.providers.salesforce.sensors.tableau_job_status import TableauJobStatusSensor
-
-                TableauJobStatusSensor(
-                    job_id=job_id,
-                    site_id=self.site_id,
-                    tableau_conn_id=self.tableau_conn_id,
-                    task_id='wait_until_succeeded',
-                    dag=None,
-                ).execute(context={})
-                self.log.info('Workbook %s has been successfully refreshed.', self.workbook_name)
-            return job_id
-
-    def _get_workbook_by_name(self, tableau_hook: TableauHook) -> WorkbookItem:
-        for workbook in tableau_hook.get_all(resource_name='workbooks'):
-            if workbook.name == self.workbook_name:
-                self.log.info('Found matching workbook with id %s', workbook.id)
-                return workbook
-
-        raise AirflowException(f'Workbook {self.workbook_name} not found!')
-
-    def _refresh_workbook(self, tableau_hook: TableauHook, workbook_id: str) -> str:
-        job = tableau_hook.server.workbooks.refresh(workbook_id)
-        self.log.info('Refreshing Workbook %s...', self.workbook_name)
-        return job.id
+warnings.warn(
+    "This module is deprecated. Please use `airflow.providers.tableau.operators.tableau_refresh_workbook`.",
+    DeprecationWarning,
+    stacklevel=2,
+)
diff --git a/airflow/providers/salesforce/provider.yaml b/airflow/providers/salesforce/provider.yaml
index fe739ff..c0992d8 100644
--- a/airflow/providers/salesforce/provider.yaml
+++ b/airflow/providers/salesforce/provider.yaml
@@ -22,6 +22,8 @@ description: |
     `Salesforce <https://www.salesforce.com/>`__
 
 versions:
+  - 1.0.2
+  - 1.0.1
   - 1.0.0
 
 integrations:
@@ -40,10 +42,12 @@ sensors:
       - airflow.providers.salesforce.sensors.tableau_job_status
 
 hooks:
+  - integration-name: Tableau
+    python-modules:
+      - airflow.providers.salesforce.hooks.tableau
   - integration-name: Salesforce
     python-modules:
       - airflow.providers.salesforce.hooks.salesforce
-      - airflow.providers.salesforce.hooks.tableau
 
 hook-class-names:
   - airflow.providers.salesforce.hooks.tableau.TableauHook
diff --git a/airflow/providers/salesforce/sensors/tableau_job_status.py b/airflow/providers/salesforce/sensors/tableau_job_status.py
index 4939203..076159e 100644
--- a/airflow/providers/salesforce/sensors/tableau_job_status.py
+++ b/airflow/providers/salesforce/sensors/tableau_job_status.py
@@ -14,63 +14,17 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Optional
 
-from airflow.exceptions import AirflowException
-from airflow.providers.salesforce.hooks.tableau import TableauHook, TableauJobFinishCode
-from airflow.sensors.base import BaseSensorOperator
-from airflow.utils.decorators import apply_defaults
+import warnings
 
+# pylint: disable=unused-import
+from airflow.providers.tableau.sensors.tableau_job_status import (  # noqa
+    TableauJobFailedException,
+    TableauJobStatusSensor,
+)
 
-class TableauJobFailedException(AirflowException):
-    """An exception that indicates that a Job failed to complete."""
-
-
-class TableauJobStatusSensor(BaseSensorOperator):
-    """
-    Watches the status of a Tableau Server Job.
-
-    .. seealso:: https://tableau.github.io/server-client-python/docs/api-ref#jobs
-
-    :param job_id: The job to watch.
-    :type job_id: str
-    :param site_id: The id of the site where the workbook belongs to.
-    :type site_id: Optional[str]
-    :param tableau_conn_id: The Tableau Connection id containing the credentials
-        to authenticate to the Tableau Server.
-    :type tableau_conn_id: str
-    """
-
-    template_fields = ('job_id',)
-
-    @apply_defaults
-    def __init__(
-        self,
-        *,
-        job_id: str,
-        site_id: Optional[str] = None,
-        tableau_conn_id: str = 'tableau_default',
-        **kwargs,
-    ) -> None:
-        super().__init__(**kwargs)
-        self.tableau_conn_id = tableau_conn_id
-        self.job_id = job_id
-        self.site_id = site_id
-
-    def poke(self, context: dict) -> bool:
-        """
-        Pokes until the job has successfully finished.
-
-        :param context: The task context during execution.
-        :type context: dict
-        :return: True if it succeeded and False if not.
-        :rtype: bool
-        """
-        with TableauHook(self.site_id, self.tableau_conn_id) as tableau_hook:
-            finish_code = TableauJobFinishCode(
-                int(tableau_hook.server.jobs.get_by_id(self.job_id).finish_code)
-            )
-            self.log.info('Current finishCode is %s (%s)', finish_code.name, finish_code.value)
-            if finish_code in [TableauJobFinishCode.ERROR, TableauJobFinishCode.CANCELED]:
-                raise TableauJobFailedException('The Tableau Refresh Workbook Job failed!')
-            return finish_code == TableauJobFinishCode.SUCCESS
+warnings.warn(
+    "This module is deprecated. Please use `airflow.providers.tableau.sensors.tableau_job_status`.",
+    DeprecationWarning,
+    stacklevel=2,
+)
diff --git a/airflow/providers/salesforce/CHANGELOG.rst b/airflow/providers/tableau/CHANGELOG.rst
similarity index 100%
copy from airflow/providers/salesforce/CHANGELOG.rst
copy to airflow/providers/tableau/CHANGELOG.rst
diff --git a/airflow/providers/salesforce/provider.yaml b/airflow/providers/tableau/__init__.py
similarity index 50%
copy from airflow/providers/salesforce/provider.yaml
copy to airflow/providers/tableau/__init__.py
index fe739ff..217e5db 100644
--- a/airflow/providers/salesforce/provider.yaml
+++ b/airflow/providers/tableau/__init__.py
@@ -1,3 +1,4 @@
+#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -14,36 +15,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
----
-package-name: apache-airflow-providers-salesforce
-name: Salesforce
-description: |
-    `Salesforce <https://www.salesforce.com/>`__
-
-versions:
-  - 1.0.0
-
-integrations:
-  - integration-name: Salesforce
-    external-doc-url: https://www.salesforce.com/
-    tags: [service]
-
-operators:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.operators.tableau_refresh_workbook
-
-sensors:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.sensors.tableau_job_status
-
-hooks:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.hooks.salesforce
-      - airflow.providers.salesforce.hooks.tableau
-
-hook-class-names:
-  - airflow.providers.salesforce.hooks.tableau.TableauHook
diff --git a/airflow/providers/salesforce/provider.yaml b/airflow/providers/tableau/example_dags/__init__.py
similarity index 50%
copy from airflow/providers/salesforce/provider.yaml
copy to airflow/providers/tableau/example_dags/__init__.py
index fe739ff..13a8339 100644
--- a/airflow/providers/salesforce/provider.yaml
+++ b/airflow/providers/tableau/example_dags/__init__.py
@@ -14,36 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
----
-package-name: apache-airflow-providers-salesforce
-name: Salesforce
-description: |
-    `Salesforce <https://www.salesforce.com/>`__
-
-versions:
-  - 1.0.0
-
-integrations:
-  - integration-name: Salesforce
-    external-doc-url: https://www.salesforce.com/
-    tags: [service]
-
-operators:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.operators.tableau_refresh_workbook
-
-sensors:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.sensors.tableau_job_status
-
-hooks:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.hooks.salesforce
-      - airflow.providers.salesforce.hooks.tableau
-
-hook-class-names:
-  - airflow.providers.salesforce.hooks.tableau.TableauHook
diff --git a/airflow/providers/salesforce/example_dags/example_tableau_refresh_workbook.py b/airflow/providers/tableau/example_dags/example_tableau_refresh_workbook.py
similarity index 92%
rename from airflow/providers/salesforce/example_dags/example_tableau_refresh_workbook.py
rename to airflow/providers/tableau/example_dags/example_tableau_refresh_workbook.py
index 32b347c..da1cc8b 100644
--- a/airflow/providers/salesforce/example_dags/example_tableau_refresh_workbook.py
+++ b/airflow/providers/tableau/example_dags/example_tableau_refresh_workbook.py
@@ -23,8 +23,8 @@ when the operation actually finishes. That's why we have another task that check
 from datetime import timedelta
 
 from airflow import DAG
-from airflow.providers.salesforce.operators.tableau_refresh_workbook import TableauRefreshWorkbookOperator
-from airflow.providers.salesforce.sensors.tableau_job_status import TableauJobStatusSensor
+from airflow.providers.tableau.operators.tableau_refresh_workbook import TableauRefreshWorkbookOperator
+from airflow.providers.tableau.sensors.tableau_job_status import TableauJobStatusSensor
 from airflow.utils.dates import days_ago
 
 DEFAULT_ARGS = {
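
For reference, a stripped-down sketch of the kind of DAG that example file builds, using only the new tableau provider import paths (not the full example shipped with the provider; the DAG id, schedule and workbook name are placeholders):

    from datetime import timedelta

    from airflow import DAG
    from airflow.providers.tableau.operators.tableau_refresh_workbook import TableauRefreshWorkbookOperator
    from airflow.providers.tableau.sensors.tableau_job_status import TableauJobStatusSensor
    from airflow.utils.dates import days_ago

    with DAG(
        dag_id="example_tableau_refresh_workbook",
        start_date=days_ago(1),
        schedule_interval=timedelta(days=1),
    ) as dag:
        # Trigger the refresh without blocking; the job id is returned and pushed to XCom.
        refresh = TableauRefreshWorkbookOperator(
            task_id="refresh_workbook",
            workbook_name="MyWorkbook",  # placeholder workbook name
            blocking=False,
        )
        # Poll the Tableau job until it reports SUCCESS (fails on ERROR/CANCELED).
        wait = TableauJobStatusSensor(
            task_id="wait_for_refresh",
            job_id="{{ ti.xcom_pull(task_ids='refresh_workbook') }}",
        )
        refresh >> wait
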
diff --git a/airflow/providers/salesforce/provider.yaml b/airflow/providers/tableau/hooks/__init__.py
similarity index 50%
copy from airflow/providers/salesforce/provider.yaml
copy to airflow/providers/tableau/hooks/__init__.py
index fe739ff..217e5db 100644
--- a/airflow/providers/salesforce/provider.yaml
+++ b/airflow/providers/tableau/hooks/__init__.py
@@ -1,3 +1,4 @@
+#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -14,36 +15,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
----
-package-name: apache-airflow-providers-salesforce
-name: Salesforce
-description: |
-    `Salesforce <https://www.salesforce.com/>`__
-
-versions:
-  - 1.0.0
-
-integrations:
-  - integration-name: Salesforce
-    external-doc-url: https://www.salesforce.com/
-    tags: [service]
-
-operators:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.operators.tableau_refresh_workbook
-
-sensors:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.sensors.tableau_job_status
-
-hooks:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.hooks.salesforce
-      - airflow.providers.salesforce.hooks.tableau
-
-hook-class-names:
-  - airflow.providers.salesforce.hooks.tableau.TableauHook
diff --git a/airflow/providers/salesforce/hooks/tableau.py b/airflow/providers/tableau/hooks/tableau.py
similarity index 100%
copy from airflow/providers/salesforce/hooks/tableau.py
copy to airflow/providers/tableau/hooks/tableau.py
diff --git a/airflow/providers/salesforce/provider.yaml b/airflow/providers/tableau/operators/__init__.py
similarity index 50%
copy from airflow/providers/salesforce/provider.yaml
copy to airflow/providers/tableau/operators/__init__.py
index fe739ff..13a8339 100644
--- a/airflow/providers/salesforce/provider.yaml
+++ b/airflow/providers/tableau/operators/__init__.py
@@ -14,36 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
----
-package-name: apache-airflow-providers-salesforce
-name: Salesforce
-description: |
-    `Salesforce <https://www.salesforce.com/>`__
-
-versions:
-  - 1.0.0
-
-integrations:
-  - integration-name: Salesforce
-    external-doc-url: https://www.salesforce.com/
-    tags: [service]
-
-operators:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.operators.tableau_refresh_workbook
-
-sensors:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.sensors.tableau_job_status
-
-hooks:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.hooks.salesforce
-      - airflow.providers.salesforce.hooks.tableau
-
-hook-class-names:
-  - airflow.providers.salesforce.hooks.tableau.TableauHook
diff --git a/airflow/providers/salesforce/operators/tableau_refresh_workbook.py b/airflow/providers/tableau/operators/tableau_refresh_workbook.py
similarity index 95%
copy from airflow/providers/salesforce/operators/tableau_refresh_workbook.py
copy to airflow/providers/tableau/operators/tableau_refresh_workbook.py
index 7d4ffdc..25ca77b 100644
--- a/airflow/providers/salesforce/operators/tableau_refresh_workbook.py
+++ b/airflow/providers/tableau/operators/tableau_refresh_workbook.py
@@ -20,7 +20,7 @@ from tableauserverclient import WorkbookItem
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
-from airflow.providers.salesforce.hooks.tableau import TableauHook
+from airflow.providers.tableau.hooks.tableau import TableauHook
 from airflow.utils.decorators import apply_defaults
 
 
@@ -71,7 +71,7 @@ class TableauRefreshWorkbookOperator(BaseOperator):
 
             job_id = self._refresh_workbook(tableau_hook, workbook.id)
             if self.blocking:
-                from airflow.providers.salesforce.sensors.tableau_job_status import TableauJobStatusSensor
+                from airflow.providers.tableau.sensors.tableau_job_status import TableauJobStatusSensor
 
                 TableauJobStatusSensor(
                     job_id=job_id,
diff --git a/airflow/providers/salesforce/provider.yaml b/airflow/providers/tableau/provider.yaml
similarity index 61%
copy from airflow/providers/salesforce/provider.yaml
copy to airflow/providers/tableau/provider.yaml
index fe739ff..e777947 100644
--- a/airflow/providers/salesforce/provider.yaml
+++ b/airflow/providers/tableau/provider.yaml
@@ -16,34 +16,34 @@
 # under the License.
 
 ---
-package-name: apache-airflow-providers-salesforce
-name: Salesforce
+package-name: apache-airflow-providers-tableau
+name: Tableau
 description: |
-    `Salesforce <https://www.salesforce.com/>`__
+    `Tableau <https://www.tableau.com/>`__
 
 versions:
   - 1.0.0
 
 integrations:
-  - integration-name: Salesforce
-    external-doc-url: https://www.salesforce.com/
+  - integration-name: Tableau
+    external-doc-url: https://www.tableau.com/
+    logo: /integration-logos/tableau/tableau.png
     tags: [service]
 
 operators:
-  - integration-name: Salesforce
+  - integration-name: Tableau
     python-modules:
-      - airflow.providers.salesforce.operators.tableau_refresh_workbook
+      - airflow.providers.tableau.operators.tableau_refresh_workbook
 
 sensors:
-  - integration-name: Salesforce
+  - integration-name: Tableau
     python-modules:
-      - airflow.providers.salesforce.sensors.tableau_job_status
+      - airflow.providers.tableau.sensors.tableau_job_status
 
 hooks:
-  - integration-name: Salesforce
+  - integration-name: Tableau
     python-modules:
-      - airflow.providers.salesforce.hooks.salesforce
-      - airflow.providers.salesforce.hooks.tableau
+      - airflow.providers.tableau.hooks.tableau
 
 hook-class-names:
-  - airflow.providers.salesforce.hooks.tableau.TableauHook
+  - airflow.providers.tableau.hooks.tableau.TableauHook
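
Once the apache-airflow-providers-tableau package carrying this provider.yaml is installed, the providers manager should discover it like any other provider. A quick sanity-check sketch (assumes the package is installed; the key layout follows the providers-manager test touched below):

    from airflow.providers_manager import ProvidersManager

    providers = ProvidersManager().providers
    # Keys are distribution names such as 'apache-airflow-providers-tableau'.
    assert "apache-airflow-providers-tableau" in providers
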
diff --git a/airflow/providers/salesforce/provider.yaml b/airflow/providers/tableau/sensors/__init__.py
similarity index 50%
copy from airflow/providers/salesforce/provider.yaml
copy to airflow/providers/tableau/sensors/__init__.py
index fe739ff..13a8339 100644
--- a/airflow/providers/salesforce/provider.yaml
+++ b/airflow/providers/tableau/sensors/__init__.py
@@ -14,36 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
----
-package-name: apache-airflow-providers-salesforce
-name: Salesforce
-description: |
-    `Salesforce <https://www.salesforce.com/>`__
-
-versions:
-  - 1.0.0
-
-integrations:
-  - integration-name: Salesforce
-    external-doc-url: https://www.salesforce.com/
-    tags: [service]
-
-operators:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.operators.tableau_refresh_workbook
-
-sensors:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.sensors.tableau_job_status
-
-hooks:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.hooks.salesforce
-      - airflow.providers.salesforce.hooks.tableau
-
-hook-class-names:
-  - airflow.providers.salesforce.hooks.tableau.TableauHook
diff --git a/airflow/providers/salesforce/sensors/tableau_job_status.py b/airflow/providers/tableau/sensors/tableau_job_status.py
similarity index 96%
copy from airflow/providers/salesforce/sensors/tableau_job_status.py
copy to airflow/providers/tableau/sensors/tableau_job_status.py
index 4939203..518e2f0 100644
--- a/airflow/providers/salesforce/sensors/tableau_job_status.py
+++ b/airflow/providers/tableau/sensors/tableau_job_status.py
@@ -17,7 +17,7 @@
 from typing import Optional
 
 from airflow.exceptions import AirflowException
-from airflow.providers.salesforce.hooks.tableau import TableauHook, TableauJobFinishCode
+from airflow.providers.tableau.hooks.tableau import TableauHook, TableauJobFinishCode
 from airflow.sensors.base import BaseSensorOperator
 from airflow.utils.decorators import apply_defaults
 
diff --git a/airflow/providers/salesforce/CHANGELOG.rst b/docs/apache-airflow-providers-tableau/index.rst
similarity index 56%
copy from airflow/providers/salesforce/CHANGELOG.rst
copy to docs/apache-airflow-providers-tableau/index.rst
index cef7dda..47ace94 100644
--- a/airflow/providers/salesforce/CHANGELOG.rst
+++ b/docs/apache-airflow-providers-tableau/index.rst
@@ -1,3 +1,4 @@
+
  .. Licensed to the Apache Software Foundation (ASF) under one
     or more contributor license agreements.  See the NOTICE file
     distributed with this work for additional information
@@ -15,11 +16,29 @@
     specific language governing permissions and limitations
     under the License.
 
+``apache-airflow-providers-tableau``
+=======================================
+
+Content
+-------
+
+.. toctree::
+    :maxdepth: 1
+    :caption: Guides
+
+    Connection types <connections/tableau>
+
+.. toctree::
+    :maxdepth: 1
+    :caption: References
+
+    Python API <_api/airflow/providers/tableau/index>
 
-Changelog
----------
+.. toctree::
+    :maxdepth: 1
+    :caption: Resources
 
-1.0.0
-.....
+    Example DAGs <https://github.com/apache/airflow/tree/master/airflow/providers/tableau/example_dags>
+    PyPI Repository <https://pypi.org/project/apache-airflow-providers-tableau/>
 
-Initial version of the provider.
+.. THE REMINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN AT RELEASE TIME!
diff --git a/docs/integration-logos/tableau/tableau.png b/docs/integration-logos/tableau/tableau.png
new file mode 100644
index 0000000..4ec356c
Binary files /dev/null and b/docs/integration-logos/tableau/tableau.png differ
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 71f9e34..0e89285 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1280,6 +1280,7 @@ sync'ed
 sys
 syspath
 systemd
+tableau
 tableauserverclient
 tablefmt
 tagKey
diff --git a/scripts/in_container/run_install_and_test_provider_packages.sh b/scripts/in_container/run_install_and_test_provider_packages.sh
index 76d41e4..5eb039a 100755
--- a/scripts/in_container/run_install_and_test_provider_packages.sh
+++ b/scripts/in_container/run_install_and_test_provider_packages.sh
@@ -95,7 +95,7 @@ function discover_all_provider_packages() {
     # Columns is to force it wider, so it doesn't wrap at 80 characters
     COLUMNS=180 airflow providers list
 
-    local expected_number_of_providers=63
+    local expected_number_of_providers=64
     local actual_number_of_providers
     actual_providers=$(airflow providers list --output yaml | grep package_name)
     actual_number_of_providers=$(wc -l <<<"$actual_providers")
diff --git a/setup.py b/setup.py
index 2867b36..4ee7a5c 100644
--- a/setup.py
+++ b/setup.py
@@ -444,7 +444,7 @@ statsd = [
     'statsd>=3.3.0, <4.0',
 ]
 tableau = [
-    'tableauserverclient~=0.12',
+    'tableauserverclient',
 ]
 telegram = [
     'python-telegram-bot==13.0',
@@ -576,6 +576,7 @@ PROVIDERS_REQUIREMENTS: Dict[str, List[str]] = {
     'snowflake': snowflake,
     'sqlite': [],
     'ssh': ssh,
+    'tableau': tableau,
     'telegram': telegram,
     'vertica': vertica,
     'yandex': yandex,
@@ -608,7 +609,6 @@ CORE_EXTRAS_REQUIREMENTS: Dict[str, List[str]] = {
     'rabbitmq': rabbitmq,
     'sentry': sentry,
     'statsd': statsd,
-    'tableau': tableau,
     'virtualenv': virtualenv,
 }
 
diff --git a/tests/core/test_providers_manager.py b/tests/core/test_providers_manager.py
index 39ee588..9112d5e 100644
--- a/tests/core/test_providers_manager.py
+++ b/tests/core/test_providers_manager.py
@@ -81,6 +81,7 @@ ALL_PROVIDERS = [
     # 'apache-airflow-providers-snowflake',
     'apache-airflow-providers-sqlite',
     'apache-airflow-providers-ssh',
+    'apache-airflow-providers-tableau',
     'apache-airflow-providers-telegram',
     'apache-airflow-providers-vertica',
     'apache-airflow-providers-yandex',
diff --git a/airflow/providers/salesforce/provider.yaml b/tests/providers/tableau/hooks/__init__.py
similarity index 50%
copy from airflow/providers/salesforce/provider.yaml
copy to tests/providers/tableau/hooks/__init__.py
index fe739ff..217e5db 100644
--- a/airflow/providers/salesforce/provider.yaml
+++ b/tests/providers/tableau/hooks/__init__.py
@@ -1,3 +1,4 @@
+#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -14,36 +15,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
----
-package-name: apache-airflow-providers-salesforce
-name: Salesforce
-description: |
-    `Salesforce <https://www.salesforce.com/>`__
-
-versions:
-  - 1.0.0
-
-integrations:
-  - integration-name: Salesforce
-    external-doc-url: https://www.salesforce.com/
-    tags: [service]
-
-operators:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.operators.tableau_refresh_workbook
-
-sensors:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.sensors.tableau_job_status
-
-hooks:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.hooks.salesforce
-      - airflow.providers.salesforce.hooks.tableau
-
-hook-class-names:
-  - airflow.providers.salesforce.hooks.tableau.TableauHook
diff --git a/tests/providers/salesforce/hooks/test_tableau.py b/tests/providers/tableau/hooks/test_tableau.py
similarity index 81%
rename from tests/providers/salesforce/hooks/test_tableau.py
rename to tests/providers/tableau/hooks/test_tableau.py
index 130746d..66ecdf7 100644
--- a/tests/providers/salesforce/hooks/test_tableau.py
+++ b/tests/providers/tableau/hooks/test_tableau.py
@@ -19,12 +19,19 @@ import unittest
 from unittest.mock import patch
 
 from airflow import configuration, models
-from airflow.providers.salesforce.hooks.tableau import TableauHook
+from airflow.providers.tableau.hooks.tableau import TableauHook
 from airflow.utils import db
 
 
 class TestTableauHook(unittest.TestCase):
+    """
+    Test class for TableauHook
+    """
+
     def setUp(self):
+        """
+        setup
+        """
         configuration.conf.load_test_config()
 
         db.merge_conn(
@@ -46,9 +53,12 @@ class TestTableauHook(unittest.TestCase):
             )
         )
 
-    @patch('airflow.providers.salesforce.hooks.tableau.TableauAuth')
-    @patch('airflow.providers.salesforce.hooks.tableau.Server')
+    @patch('airflow.providers.tableau.hooks.tableau.TableauAuth')
+    @patch('airflow.providers.tableau.hooks.tableau.Server')
     def test_get_conn_auth_via_password_and_site_in_connection(self, mock_server, mock_tableau_auth):
+        """
+        Test get conn auth via password
+        """
         with TableauHook(tableau_conn_id='tableau_test_password') as tableau_hook:
             mock_server.assert_called_once_with(tableau_hook.conn.host, use_server_version=True)
             mock_tableau_auth.assert_called_once_with(
@@ -59,9 +69,12 @@ class TestTableauHook(unittest.TestCase):
             mock_server.return_value.auth.sign_in.assert_called_once_with(mock_tableau_auth.return_value)
         mock_server.return_value.auth.sign_out.assert_called_once_with()
 
-    @patch('airflow.providers.salesforce.hooks.tableau.PersonalAccessTokenAuth')
-    @patch('airflow.providers.salesforce.hooks.tableau.Server')
+    @patch('airflow.providers.tableau.hooks.tableau.PersonalAccessTokenAuth')
+    @patch('airflow.providers.tableau.hooks.tableau.Server')
     def test_get_conn_auth_via_token_and_site_in_init(self, mock_server, mock_tableau_auth):
+        """
+        Test get conn auth via token
+        """
         with TableauHook(site_id='test', tableau_conn_id='tableau_test_token') as tableau_hook:
             mock_server.assert_called_once_with(tableau_hook.conn.host, use_server_version=True)
             mock_tableau_auth.assert_called_once_with(
@@ -74,10 +87,13 @@ class TestTableauHook(unittest.TestCase):
             )
         mock_server.return_value.auth.sign_out.assert_called_once_with()
 
-    @patch('airflow.providers.salesforce.hooks.tableau.TableauAuth')
-    @patch('airflow.providers.salesforce.hooks.tableau.Server')
-    @patch('airflow.providers.salesforce.hooks.tableau.Pager', return_value=[1, 2, 3])
+    @patch('airflow.providers.tableau.hooks.tableau.TableauAuth')
+    @patch('airflow.providers.tableau.hooks.tableau.Server')
+    @patch('airflow.providers.tableau.hooks.tableau.Pager', return_value=[1, 2, 3])
     def test_get_all(self, mock_pager, mock_server, mock_tableau_auth):  # pylint: disable=unused-argument
+        """
+        Test get all
+        """
         with TableauHook(tableau_conn_id='tableau_test_password') as tableau_hook:
             jobs = tableau_hook.get_all(resource_name='jobs')
             assert jobs == mock_pager.return_value
diff --git a/airflow/providers/salesforce/provider.yaml b/tests/providers/tableau/operators/__init__.py
similarity index 50%
copy from airflow/providers/salesforce/provider.yaml
copy to tests/providers/tableau/operators/__init__.py
index fe739ff..13a8339 100644
--- a/airflow/providers/salesforce/provider.yaml
+++ b/tests/providers/tableau/operators/__init__.py
@@ -14,36 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
----
-package-name: apache-airflow-providers-salesforce
-name: Salesforce
-description: |
-    `Salesforce <https://www.salesforce.com/>`__
-
-versions:
-  - 1.0.0
-
-integrations:
-  - integration-name: Salesforce
-    external-doc-url: https://www.salesforce.com/
-    tags: [service]
-
-operators:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.operators.tableau_refresh_workbook
-
-sensors:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.sensors.tableau_job_status
-
-hooks:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.hooks.salesforce
-      - airflow.providers.salesforce.hooks.tableau
-
-hook-class-names:
-  - airflow.providers.salesforce.hooks.tableau.TableauHook
diff --git a/tests/providers/salesforce/operators/test_tableau_refresh_workbook.py b/tests/providers/tableau/operators/test_tableau_refresh_workbook.py
similarity index 80%
rename from tests/providers/salesforce/operators/test_tableau_refresh_workbook.py
rename to tests/providers/tableau/operators/test_tableau_refresh_workbook.py
index 77139c1..72377a5 100644
--- a/tests/providers/salesforce/operators/test_tableau_refresh_workbook.py
+++ b/tests/providers/tableau/operators/test_tableau_refresh_workbook.py
@@ -21,11 +21,18 @@ from unittest.mock import Mock, patch
 import pytest
 
 from airflow.exceptions import AirflowException
-from airflow.providers.salesforce.operators.tableau_refresh_workbook import TableauRefreshWorkbookOperator
+from airflow.providers.tableau.operators.tableau_refresh_workbook import TableauRefreshWorkbookOperator
 
 
 class TestTableauRefreshWorkbookOperator(unittest.TestCase):
+    """
+    Test class for TableauRefreshWorkbookOperator
+    """
+
     def setUp(self):
+        """
+        setup
+        """
         self.mocked_workbooks = []
         for i in range(3):
             mock_workbook = Mock()
@@ -34,8 +41,11 @@ class TestTableauRefreshWorkbookOperator(unittest.TestCase):
             self.mocked_workbooks.append(mock_workbook)
         self.kwargs = {'site_id': 'test_site', 'task_id': 'task', 'dag': None}
 
-    @patch('airflow.providers.salesforce.operators.tableau_refresh_workbook.TableauHook')
+    @patch('airflow.providers.tableau.operators.tableau_refresh_workbook.TableauHook')
     def test_execute(self, mock_tableau_hook):
+        """
+        Test Execute
+        """
         mock_tableau_hook.get_all = Mock(return_value=self.mocked_workbooks)
         mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
         operator = TableauRefreshWorkbookOperator(blocking=False, workbook_name='wb_2', **self.kwargs)
@@ -45,9 +55,12 @@ class TestTableauRefreshWorkbookOperator(unittest.TestCase):
         mock_tableau_hook.server.workbooks.refresh.assert_called_once_with(2)
         assert mock_tableau_hook.server.workbooks.refresh.return_value.id == job_id
 
-    @patch('airflow.providers.salesforce.sensors.tableau_job_status.TableauJobStatusSensor')
-    @patch('airflow.providers.salesforce.operators.tableau_refresh_workbook.TableauHook')
+    @patch('airflow.providers.tableau.sensors.tableau_job_status.TableauJobStatusSensor')
+    @patch('airflow.providers.tableau.operators.tableau_refresh_workbook.TableauHook')
     def test_execute_blocking(self, mock_tableau_hook, mock_tableau_job_status_sensor):
+        """
+        Test execute blocking
+        """
         mock_tableau_hook.get_all = Mock(return_value=self.mocked_workbooks)
         mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
         operator = TableauRefreshWorkbookOperator(workbook_name='wb_2', **self.kwargs)
@@ -64,8 +77,11 @@ class TestTableauRefreshWorkbookOperator(unittest.TestCase):
             dag=None,
         )
 
-    @patch('airflow.providers.salesforce.operators.tableau_refresh_workbook.TableauHook')
+    @patch('airflow.providers.tableau.operators.tableau_refresh_workbook.TableauHook')
     def test_execute_missing_workbook(self, mock_tableau_hook):
+        """
+        Test execute missing workbook
+        """
         mock_tableau_hook.get_all = Mock(return_value=self.mocked_workbooks)
         mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
         operator = TableauRefreshWorkbookOperator(workbook_name='test', **self.kwargs)
diff --git a/airflow/providers/salesforce/provider.yaml b/tests/providers/tableau/sensors/__init__.py
similarity index 50%
copy from airflow/providers/salesforce/provider.yaml
copy to tests/providers/tableau/sensors/__init__.py
index fe739ff..13a8339 100644
--- a/airflow/providers/salesforce/provider.yaml
+++ b/tests/providers/tableau/sensors/__init__.py
@@ -14,36 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
----
-package-name: apache-airflow-providers-salesforce
-name: Salesforce
-description: |
-    `Salesforce <https://www.salesforce.com/>`__
-
-versions:
-  - 1.0.0
-
-integrations:
-  - integration-name: Salesforce
-    external-doc-url: https://www.salesforce.com/
-    tags: [service]
-
-operators:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.operators.tableau_refresh_workbook
-
-sensors:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.sensors.tableau_job_status
-
-hooks:
-  - integration-name: Salesforce
-    python-modules:
-      - airflow.providers.salesforce.hooks.salesforce
-      - airflow.providers.salesforce.hooks.tableau
-
-hook-class-names:
-  - airflow.providers.salesforce.hooks.tableau.TableauHook
diff --git a/tests/providers/salesforce/sensors/test_tableau_job_status.py b/tests/providers/tableau/sensors/test_tableau_job_status.py
similarity index 84%
rename from tests/providers/salesforce/sensors/test_tableau_job_status.py
rename to tests/providers/tableau/sensors/test_tableau_job_status.py
index 7f01011..ea6eeb2 100644
--- a/tests/providers/salesforce/sensors/test_tableau_job_status.py
+++ b/tests/providers/tableau/sensors/test_tableau_job_status.py
@@ -21,18 +21,25 @@ from unittest.mock import Mock, patch
 import pytest
 from parameterized import parameterized
 
-from airflow.providers.salesforce.sensors.tableau_job_status import (
+from airflow.providers.tableau.sensors.tableau_job_status import (
     TableauJobFailedException,
     TableauJobStatusSensor,
 )
 
 
 class TestTableauJobStatusSensor(unittest.TestCase):
+    """
+    Test Class for JobStatusSensor
+    """
+
     def setUp(self):
         self.kwargs = {'job_id': 'job_2', 'site_id': 'test_site', 'task_id': 'task', 'dag': None}
 
-    @patch('airflow.providers.salesforce.sensors.tableau_job_status.TableauHook')
+    @patch('airflow.providers.tableau.sensors.tableau_job_status.TableauHook')
     def test_poke(self, mock_tableau_hook):
+        """
+        Test poke
+        """
         mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
         mock_get = mock_tableau_hook.server.jobs.get_by_id
         mock_get.return_value.finish_code = '0'
@@ -44,8 +51,11 @@ class TestTableauJobStatusSensor(unittest.TestCase):
         mock_get.assert_called_once_with(sensor.job_id)
 
     @parameterized.expand([('1',), ('2',)])
-    @patch('airflow.providers.salesforce.sensors.tableau_job_status.TableauHook')
+    @patch('airflow.providers.tableau.sensors.tableau_job_status.TableauHook')
     def test_poke_failed(self, finish_code, mock_tableau_hook):
+        """
+        Test poke failed
+        """
         mock_tableau_hook.return_value.__enter__ = Mock(return_value=mock_tableau_hook)
         mock_get = mock_tableau_hook.server.jobs.get_by_id
         mock_get.return_value.finish_code = finish_code


[airflow] 18/28: Refactor DataprocOperators to support google-cloud-dataproc 2.0 (#13256)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 21831aaf3c00d8a941b76f44d823238b93b85b7a
Author: Tomek Urbaszek <tu...@gmail.com>
AuthorDate: Mon Jan 18 17:49:19 2021 +0100

    Refactor DataprocOperators to support google-cloud-dataproc 2.0 (#13256)
    
    (cherry picked from commit 309788e5e2023c598095a4ee00df417d94b6a5df)
---
 airflow/providers/google/ADDITIONAL_INFO.md        |   2 +
 airflow/providers/google/cloud/hooks/dataproc.py   | 104 ++++++++---------
 .../providers/google/cloud/operators/dataproc.py   |  30 +++--
 airflow/providers/google/cloud/sensors/dataproc.py |  12 +-
 setup.py                                           |   2 +-
 .../providers/google/cloud/hooks/test_dataproc.py  | 129 ++++++++++++---------
 .../google/cloud/operators/test_dataproc.py        |  14 ++-
 .../google/cloud/sensors/test_dataproc.py          |   8 +-
 8 files changed, 157 insertions(+), 144 deletions(-)

diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md
index c696e1b..16a6683 100644
--- a/airflow/providers/google/ADDITIONAL_INFO.md
+++ b/airflow/providers/google/ADDITIONAL_INFO.md
@@ -32,11 +32,13 @@ Details are covered in the UPDATING.md files for each library, but there are som
 | [``google-cloud-automl``](https://pypi.org/project/google-cloud-automl/) | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-automl/blob/master/UPGRADING.md) |
 | [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
 | [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
+| [``google-cloud-dataproc``](https://pypi.org/project/google-cloud-dataproc/) | ``>=1.0.1,<2.0.0`` | ``>=2.2.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-dataproc/blob/master/UPGRADING.md) |
 | [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
 | [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
 | [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
 | [``google-cloud-tasks``](https://pypi.org/project/google-cloud-tasks/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-tasks/blob/master/UPGRADING.md) |
 
+
 ### The field names use the snake_case convention
 
 If your DAG uses an object from the above mentioned libraries passed by XCom, it is necessary to update the naming convention of the fields that are read. Previously, the fields used the CamelSnake convention, now the snake_case convention is used.
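
A hedged illustration of that convention change for a DAG reading a Dataproc cluster description from XCom (the field names are examples only, not an exhaustive mapping):

    def check_cluster_name(ti):
        """Illustrative PythonOperator callable using the new snake_case keys."""
        cluster = ti.xcom_pull(task_ids="create_cluster")
        # Pre-upgrade the serialized field was spelled 'clusterName';
        # after the upgrade the same field is exposed as 'cluster_name'.
        return cluster["cluster_name"]
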
diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py
index 12d5941..35d4786 100644
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -26,18 +26,16 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 from google.api_core.exceptions import ServerError
 from google.api_core.retry import Retry
 from google.cloud.dataproc_v1beta2 import (  # pylint: disable=no-name-in-module
-    ClusterControllerClient,
-    JobControllerClient,
-    WorkflowTemplateServiceClient,
-)
-from google.cloud.dataproc_v1beta2.types import (  # pylint: disable=no-name-in-module
     Cluster,
-    Duration,
-    FieldMask,
+    ClusterControllerClient,
     Job,
+    JobControllerClient,
     JobStatus,
     WorkflowTemplate,
+    WorkflowTemplateServiceClient,
 )
+from google.protobuf.duration_pb2 import Duration
+from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
@@ -291,10 +289,12 @@ class DataprocHook(GoogleBaseHook):
 
         client = self.get_cluster_client(location=region)
         result = client.create_cluster(
-            project_id=project_id,
-            region=region,
-            cluster=cluster,
-            request_id=request_id,
+            request={
+                'project_id': project_id,
+                'region': region,
+                'cluster': cluster,
+                'request_id': request_id,
+            },
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -340,11 +340,13 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_cluster_client(location=region)
         result = client.delete_cluster(
-            project_id=project_id,
-            region=region,
-            cluster_name=cluster_name,
-            cluster_uuid=cluster_uuid,
-            request_id=request_id,
+            request={
+                'project_id': project_id,
+                'region': region,
+                'cluster_name': cluster_name,
+                'cluster_uuid': cluster_uuid,
+                'request_id': request_id,
+            },
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -382,9 +384,7 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_cluster_client(location=region)
         operation = client.diagnose_cluster(
-            project_id=project_id,
-            region=region,
-            cluster_name=cluster_name,
+            request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -423,9 +423,7 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_cluster_client(location=region)
         result = client.get_cluster(
-            project_id=project_id,
-            region=region,
-            cluster_name=cluster_name,
+            request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -467,10 +465,7 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_cluster_client(location=region)
         result = client.list_clusters(
-            project_id=project_id,
-            region=region,
-            filter_=filter_,
-            page_size=page_size,
+            request={'project_id': project_id, 'region': region, 'filter': filter_, 'page_size': page_size},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -551,13 +546,15 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_cluster_client(location=location)
         operation = client.update_cluster(
-            project_id=project_id,
-            region=location,
-            cluster_name=cluster_name,
-            cluster=cluster,
-            update_mask=update_mask,
-            graceful_decommission_timeout=graceful_decommission_timeout,
-            request_id=request_id,
+            request={
+                'project_id': project_id,
+                'region': location,
+                'cluster_name': cluster_name,
+                'cluster': cluster,
+                'update_mask': update_mask,
+                'graceful_decommission_timeout': graceful_decommission_timeout,
+                'request_id': request_id,
+            },
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -593,10 +590,11 @@ class DataprocHook(GoogleBaseHook):
         :param metadata: Additional metadata that is provided to the method.
         :type metadata: Sequence[Tuple[str, str]]
         """
+        metadata = metadata or ()
         client = self.get_template_client(location)
-        parent = client.region_path(project_id, location)
+        parent = f'projects/{project_id}/regions/{location}'
         return client.create_workflow_template(
-            parent=parent, template=template, retry=retry, timeout=timeout, metadata=metadata
+            request={'parent': parent, 'template': template}, retry=retry, timeout=timeout, metadata=metadata
         )
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -643,13 +641,11 @@ class DataprocHook(GoogleBaseHook):
         :param metadata: Additional metadata that is provided to the method.
         :type metadata: Sequence[Tuple[str, str]]
         """
+        metadata = metadata or ()
         client = self.get_template_client(location)
-        name = client.workflow_template_path(project_id, location, template_name)
+        name = f'projects/{project_id}/regions/{location}/workflowTemplates/{template_name}'
         operation = client.instantiate_workflow_template(
-            name=name,
-            version=version,
-            parameters=parameters,
-            request_id=request_id,
+            request={'name': name, 'version': version, 'request_id': request_id, 'parameters': parameters},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -690,12 +686,11 @@ class DataprocHook(GoogleBaseHook):
         :param metadata: Additional metadata that is provided to the method.
         :type metadata: Sequence[Tuple[str, str]]
         """
+        metadata = metadata or ()
         client = self.get_template_client(location)
-        parent = client.region_path(project_id, location)
+        parent = f'projects/{project_id}/regions/{location}'
         operation = client.instantiate_inline_workflow_template(
-            parent=parent,
-            template=template,
-            request_id=request_id,
+            request={'parent': parent, 'template': template, 'request_id': request_id},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -722,19 +717,19 @@ class DataprocHook(GoogleBaseHook):
         """
         state = None
         start = time.monotonic()
-        while state not in (JobStatus.ERROR, JobStatus.DONE, JobStatus.CANCELLED):
+        while state not in (JobStatus.State.ERROR, JobStatus.State.DONE, JobStatus.State.CANCELLED):
             if timeout and start + timeout < time.monotonic():
                 raise AirflowException(f"Timeout: dataproc job {job_id} is not ready after {timeout}s")
             time.sleep(wait_time)
             try:
-                job = self.get_job(location=location, job_id=job_id, project_id=project_id)
+                job = self.get_job(project_id=project_id, location=location, job_id=job_id)
                 state = job.status.state
             except ServerError as err:
                 self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err)
 
-        if state == JobStatus.ERROR:
+        if state == JobStatus.State.ERROR:
             raise AirflowException(f'Job failed:\n{job}')
-        if state == JobStatus.CANCELLED:
+        if state == JobStatus.State.CANCELLED:
             raise AirflowException(f'Job was cancelled:\n{job}')
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -767,9 +762,7 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_job_client(location=location)
         job = client.get_job(
-            project_id=project_id,
-            region=location,
-            job_id=job_id,
+            request={'project_id': project_id, 'region': location, 'job_id': job_id},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -812,10 +805,7 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_job_client(location=location)
         return client.submit_job(
-            project_id=project_id,
-            region=location,
-            job=job,
-            request_id=request_id,
+            request={'project_id': project_id, 'region': location, 'job': job, 'request_id': request_id},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -884,9 +874,7 @@ class DataprocHook(GoogleBaseHook):
         client = self.get_job_client(location=location)
 
         job = client.cancel_job(
-            project_id=project_id,
-            region=location,
-            job_id=job_id,
+            request={'project_id': project_id, 'region': location, 'job_id': job_id},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index ac93915..13b7026 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -17,7 +17,6 @@
 # under the License.
 #
 """This module contains Google Dataproc operators."""
-# pylint: disable=C0302
 
 import inspect
 import ntpath
@@ -31,12 +30,9 @@ from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
 
 from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.retry import Retry, exponential_sleep_generator
-from google.cloud.dataproc_v1beta2.types import (  # pylint: disable=no-name-in-module
-    Cluster,
-    Duration,
-    FieldMask,
-)
-from google.protobuf.json_format import MessageToDict
+from google.cloud.dataproc_v1beta2 import Cluster  # pylint: disable=no-name-in-module
+from google.protobuf.duration_pb2 import Duration
+from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -562,7 +558,7 @@ class DataprocCreateClusterOperator(BaseOperator):
         )
 
     def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
-        if cluster.status.state != cluster.status.ERROR:
+        if cluster.status.state != cluster.status.State.ERROR:
             return
         self.log.info("Cluster is in ERROR state")
         gcs_uri = hook.diagnose_cluster(
@@ -590,7 +586,7 @@ class DataprocCreateClusterOperator(BaseOperator):
         time_left = self.timeout
         cluster = self._get_cluster(hook)
         for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
-            if cluster.status.state != cluster.status.CREATING:
+            if cluster.status.state != cluster.status.State.CREATING:
                 break
             if time_left < 0:
                 raise AirflowException(f"Cluster {self.cluster_name} is still CREATING state, aborting")
@@ -613,18 +609,18 @@ class DataprocCreateClusterOperator(BaseOperator):
 
         # Check if cluster is not in ERROR state
         self._handle_error_state(hook, cluster)
-        if cluster.status.state == cluster.status.CREATING:
+        if cluster.status.state == cluster.status.State.CREATING:
             # Wait for cluster to be created
             cluster = self._wait_for_cluster_in_creating_state(hook)
             self._handle_error_state(hook, cluster)
-        elif cluster.status.state == cluster.status.DELETING:
+        elif cluster.status.state == cluster.status.State.DELETING:
             # Wait for cluster to be deleted
             self._wait_for_cluster_in_deleting_state(hook)
             # Create new cluster
             cluster = self._create_cluster(hook)
             self._handle_error_state(hook, cluster)
 
-        return MessageToDict(cluster)
+        return Cluster.to_dict(cluster)
 
 
 class DataprocScaleClusterOperator(BaseOperator):
@@ -1790,7 +1786,7 @@ class DataprocSubmitJobOperator(BaseOperator):
     :type wait_timeout: int
     """
 
-    template_fields = ('project_id', 'location', 'job', 'impersonation_chain')
+    template_fields = ('project_id', 'location', 'job', 'impersonation_chain', 'request_id')
     template_fields_renderers = {"job": "json"}
 
     @apply_defaults
@@ -1876,14 +1872,14 @@ class DataprocUpdateClusterOperator(BaseOperator):
         example, to change the number of workers in a cluster to 5, the ``update_mask`` parameter would be
         specified as ``config.worker_config.num_instances``, and the ``PATCH`` request body would specify the
         new value. If a dict is provided, it must be of the same form as the protobuf message
-        :class:`~google.cloud.dataproc_v1beta2.types.FieldMask`
-    :type update_mask: Union[Dict, google.cloud.dataproc_v1beta2.types.FieldMask]
+        :class:`~google.protobuf.field_mask_pb2.FieldMask`
+    :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
     :param graceful_decommission_timeout: Optional. Timeout for graceful YARN decommissioning. Graceful
         decommissioning allows removing nodes from the cluster without interrupting jobs in progress. Timeout
         specifies how long to wait for jobs in progress to finish before forcefully removing nodes (and
         potentially interrupting jobs). Default timeout is 0 (for forceful decommission), and the maximum
         allowed timeout is 1 day.
-    :type graceful_decommission_timeout: Union[Dict, google.cloud.dataproc_v1beta2.types.Duration]
+    :type graceful_decommission_timeout: Union[Dict, google.protobuf.duration_pb2.Duration]
     :param request_id: Optional. A unique id used to identify the request. If the server receives two
         ``UpdateClusterRequest`` requests with the same id, then the second request will be ignored and the
         first ``google.longrunning.Operation`` created and stored in the backend is returned.
@@ -1909,7 +1905,7 @@ class DataprocUpdateClusterOperator(BaseOperator):
     :type impersonation_chain: Union[str, Sequence[str]]
     """
 
-    template_fields = ('impersonation_chain',)
+    template_fields = ('impersonation_chain', 'cluster_name')
 
     @apply_defaults
     def __init__(  # pylint: disable=too-many-arguments
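
As a hedged aside on the operator change above: results are now serialized with the proto-plus to_dict helpers rather than MessageToDict, roughly like this (values are placeholders):

    from google.cloud.dataproc_v1beta2 import Cluster

    cluster = Cluster(cluster_name="example-cluster")  # placeholder message
    cluster_dict = Cluster.to_dict(cluster)            # replaces MessageToDict(cluster); keys are snake_case
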
diff --git a/airflow/providers/google/cloud/sensors/dataproc.py b/airflow/providers/google/cloud/sensors/dataproc.py
index 1777a22..93656df 100644
--- a/airflow/providers/google/cloud/sensors/dataproc.py
+++ b/airflow/providers/google/cloud/sensors/dataproc.py
@@ -65,14 +65,18 @@ class DataprocJobSensor(BaseSensorOperator):
         job = hook.get_job(job_id=self.dataproc_job_id, location=self.location, project_id=self.project_id)
         state = job.status.state
 
-        if state == JobStatus.ERROR:
+        if state == JobStatus.State.ERROR:
             raise AirflowException(f'Job failed:\n{job}')
-        elif state in {JobStatus.CANCELLED, JobStatus.CANCEL_PENDING, JobStatus.CANCEL_STARTED}:
+        elif state in {
+            JobStatus.State.CANCELLED,
+            JobStatus.State.CANCEL_PENDING,
+            JobStatus.State.CANCEL_STARTED,
+        }:
             raise AirflowException(f'Job was cancelled:\n{job}')
-        elif JobStatus.DONE == state:
+        elif JobStatus.State.DONE == state:
             self.log.debug("Job %s completed successfully.", self.dataproc_job_id)
             return True
-        elif JobStatus.ATTEMPT_FAILURE == state:
+        elif JobStatus.State.ATTEMPT_FAILURE == state:
             self.log.debug("Job %s attempt has failed.", self.dataproc_job_id)
 
         self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)
diff --git a/setup.py b/setup.py
index 520b059..0f40d88 100644
--- a/setup.py
+++ b/setup.py
@@ -288,7 +288,7 @@ google = [
     'google-cloud-bigtable>=1.0.0,<2.0.0',
     'google-cloud-container>=0.1.1,<2.0.0',
     'google-cloud-datacatalog>=3.0.0,<4.0.0',
-    'google-cloud-dataproc>=1.0.1,<2.0.0',
+    'google-cloud-dataproc>=2.2.0,<3.0.0',
     'google-cloud-dlp>=0.11.0,<2.0.0',
     'google-cloud-kms>=2.0.0,<3.0.0',
     'google-cloud-language>=1.1.1,<2.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py b/tests/providers/google/cloud/hooks/test_dataproc.py
index d09c91e..6842acc 100644
--- a/tests/providers/google/cloud/hooks/test_dataproc.py
+++ b/tests/providers/google/cloud/hooks/test_dataproc.py
@@ -20,7 +20,7 @@ import unittest
 from unittest import mock
 
 import pytest
-from google.cloud.dataproc_v1beta2.types import JobStatus  # pylint: disable=no-name-in-module
+from google.cloud.dataproc_v1beta2 import JobStatus  # pylint: disable=no-name-in-module
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder
@@ -43,8 +43,6 @@ CLUSTER = {
     "project_id": GCP_PROJECT,
 }
 
-PARENT = "parent"
-NAME = "name"
 BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
 DATAPROC_STRING = "airflow.providers.google.cloud.hooks.dataproc.{}"
 
@@ -113,11 +111,13 @@ class TestDataprocHook(unittest.TestCase):
         )
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.create_cluster.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            region=GCP_LOCATION,
-            cluster=CLUSTER,
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                cluster=CLUSTER,
+                request_id=None,
+            ),
             metadata=None,
-            request_id=None,
             retry=None,
             timeout=None,
         )
@@ -127,12 +127,14 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.delete_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.delete_cluster.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            region=GCP_LOCATION,
-            cluster_name=CLUSTER_NAME,
-            cluster_uuid=None,
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                cluster_name=CLUSTER_NAME,
+                cluster_uuid=None,
+                request_id=None,
+            ),
             metadata=None,
-            request_id=None,
             retry=None,
             timeout=None,
         )
@@ -142,9 +144,11 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.diagnose_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.diagnose_cluster.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            region=GCP_LOCATION,
-            cluster_name=CLUSTER_NAME,
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                cluster_name=CLUSTER_NAME,
+            ),
             metadata=None,
             retry=None,
             timeout=None,
@@ -156,9 +160,11 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.get_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.get_cluster.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            region=GCP_LOCATION,
-            cluster_name=CLUSTER_NAME,
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                cluster_name=CLUSTER_NAME,
+            ),
             metadata=None,
             retry=None,
             timeout=None,
@@ -171,10 +177,12 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.list_clusters(project_id=GCP_PROJECT, region=GCP_LOCATION, filter_=filter_)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.list_clusters.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            region=GCP_LOCATION,
-            filter_=filter_,
-            page_size=None,
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                filter=filter_,
+                page_size=None,
+            ),
             metadata=None,
             retry=None,
             timeout=None,
@@ -192,14 +200,16 @@ class TestDataprocHook(unittest.TestCase):
         )
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.update_cluster.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            region=GCP_LOCATION,
-            cluster=CLUSTER,
-            cluster_name=CLUSTER_NAME,
-            update_mask=update_mask,
-            graceful_decommission_timeout=None,
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                cluster=CLUSTER,
+                cluster_name=CLUSTER_NAME,
+                update_mask=update_mask,
+                graceful_decommission_timeout=None,
+                request_id=None,
+            ),
             metadata=None,
-            request_id=None,
             retry=None,
             timeout=None,
         )
@@ -207,44 +217,45 @@ class TestDataprocHook(unittest.TestCase):
     @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
     def test_create_workflow_template(self, mock_client):
         template = {"test": "test"}
-        mock_client.return_value.region_path.return_value = PARENT
+        parent = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}'
         self.hook.create_workflow_template(location=GCP_LOCATION, template=template, project_id=GCP_PROJECT)
-        mock_client.return_value.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
         mock_client.return_value.create_workflow_template.assert_called_once_with(
-            parent=PARENT, template=template, retry=None, timeout=None, metadata=None
+            request=dict(parent=parent, template=template), retry=None, timeout=None, metadata=()
         )
 
     @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
     def test_instantiate_workflow_template(self, mock_client):
         template_name = "template_name"
-        mock_client.return_value.workflow_template_path.return_value = NAME
+        name = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}/workflowTemplates/{template_name}'
         self.hook.instantiate_workflow_template(
             location=GCP_LOCATION, template_name=template_name, project_id=GCP_PROJECT
         )
-        mock_client.return_value.workflow_template_path.assert_called_once_with(
-            GCP_PROJECT, GCP_LOCATION, template_name
-        )
         mock_client.return_value.instantiate_workflow_template.assert_called_once_with(
-            name=NAME, version=None, parameters=None, request_id=None, retry=None, timeout=None, metadata=None
+            request=dict(name=name, version=None, parameters=None, request_id=None),
+            retry=None,
+            timeout=None,
+            metadata=(),
         )
 
     @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
     def test_instantiate_inline_workflow_template(self, mock_client):
         template = {"test": "test"}
-        mock_client.return_value.region_path.return_value = PARENT
+        parent = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}'
         self.hook.instantiate_inline_workflow_template(
             location=GCP_LOCATION, template=template, project_id=GCP_PROJECT
         )
-        mock_client.return_value.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
         mock_client.return_value.instantiate_inline_workflow_template.assert_called_once_with(
-            parent=PARENT, template=template, request_id=None, retry=None, timeout=None, metadata=None
+            request=dict(parent=parent, template=template, request_id=None),
+            retry=None,
+            timeout=None,
+            metadata=(),
         )
 
     @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job"))
     def test_wait_for_job(self, mock_get_job):
         mock_get_job.side_effect = [
-            mock.MagicMock(status=mock.MagicMock(state=JobStatus.RUNNING)),
-            mock.MagicMock(status=mock.MagicMock(state=JobStatus.ERROR)),
+            mock.MagicMock(status=mock.MagicMock(state=JobStatus.State.RUNNING)),
+            mock.MagicMock(status=mock.MagicMock(state=JobStatus.State.ERROR)),
         ]
         with pytest.raises(AirflowException):
             self.hook.wait_for_job(job_id=JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT, wait_time=0)
@@ -259,9 +270,11 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.get_job(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.get_job.assert_called_once_with(
-            region=GCP_LOCATION,
-            job_id=JOB_ID,
-            project_id=GCP_PROJECT,
+            request=dict(
+                region=GCP_LOCATION,
+                job_id=JOB_ID,
+                project_id=GCP_PROJECT,
+            ),
             retry=None,
             timeout=None,
             metadata=None,
@@ -272,10 +285,12 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.submit_job(location=GCP_LOCATION, job=JOB, project_id=GCP_PROJECT)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.submit_job.assert_called_once_with(
-            region=GCP_LOCATION,
-            job=JOB,
-            project_id=GCP_PROJECT,
-            request_id=None,
+            request=dict(
+                region=GCP_LOCATION,
+                job=JOB,
+                project_id=GCP_PROJECT,
+                request_id=None,
+            ),
             retry=None,
             timeout=None,
             metadata=None,
@@ -297,9 +312,11 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.cancel_job(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.cancel_job.assert_called_once_with(
-            region=GCP_LOCATION,
-            job_id=JOB_ID,
-            project_id=GCP_PROJECT,
+            request=dict(
+                region=GCP_LOCATION,
+                job_id=JOB_ID,
+                project_id=GCP_PROJECT,
+            ),
             retry=None,
             timeout=None,
             metadata=None,
@@ -311,9 +328,11 @@ class TestDataprocHook(unittest.TestCase):
             self.hook.cancel_job(job_id=JOB_ID, project_id=GCP_PROJECT)
         mock_client.assert_called_once_with(location='global')
         mock_client.return_value.cancel_job.assert_called_once_with(
-            region='global',
-            job_id=JOB_ID,
-            project_id=GCP_PROJECT,
+            request=dict(
+                region='global',
+                job_id=JOB_ID,
+                project_id=GCP_PROJECT,
+            ),
             retry=None,
             timeout=None,
             metadata=None,
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index ca8f706..791e8ea 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -204,8 +204,9 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
         assert_warning("Default region value", warning)
         self.assertEqual(op_default_region.region, 'global')
 
+    @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute(self, mock_hook):
+    def test_execute(self, mock_hook, to_dict_mock):
         op = DataprocCreateClusterOperator(
             task_id=TASK_ID,
             region=GCP_LOCATION,
@@ -233,9 +234,11 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result())
 
+    @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_if_cluster_exists(self, mock_hook):
+    def test_execute_if_cluster_exists(self, mock_hook, to_dict_mock):
         mock_hook.return_value.create_cluster.side_effect = [AlreadyExists("test")]
         mock_hook.return_value.get_cluster.return_value.status.state = 0
         op = DataprocCreateClusterOperator(
@@ -273,6 +276,7 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        to_dict_mock.assert_called_once_with(mock_hook.return_value.get_cluster.return_value)
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_if_cluster_exists_do_not_use(self, mock_hook):
@@ -300,7 +304,7 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
         mock_hook.return_value.create_cluster.side_effect = [AlreadyExists("test")]
         cluster_status = mock_hook.return_value.get_cluster.return_value.status
         cluster_status.state = 0
-        cluster_status.ERROR = 0
+        cluster_status.State.ERROR = 0
 
         op = DataprocCreateClusterOperator(
             task_id=TASK_ID,
@@ -335,11 +339,11 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
     ):
         cluster = mock.MagicMock()
         cluster.status.state = 0
-        cluster.status.DELETING = 0
+        cluster.status.State.DELETING = 0  # pylint: disable=no-member
 
         cluster2 = mock.MagicMock()
         cluster2.status.state = 0
-        cluster2.status.ERROR = 0
+        cluster2.status.State.ERROR = 0  # pylint: disable=no-member
 
         mock_create_cluster.side_effect = [AlreadyExists("test"), cluster2]
         mock_generator.return_value = [0]
diff --git a/tests/providers/google/cloud/sensors/test_dataproc.py b/tests/providers/google/cloud/sensors/test_dataproc.py
index 1ce8eea..6f2991a 100644
--- a/tests/providers/google/cloud/sensors/test_dataproc.py
+++ b/tests/providers/google/cloud/sensors/test_dataproc.py
@@ -45,7 +45,7 @@ class TestDataprocJobSensor(unittest.TestCase):
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_done(self, mock_hook):
-        job = self.create_job(JobStatus.DONE)
+        job = self.create_job(JobStatus.State.DONE)
         job_id = "job_id"
         mock_hook.return_value.get_job.return_value = job
 
@@ -66,7 +66,7 @@ class TestDataprocJobSensor(unittest.TestCase):
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_error(self, mock_hook):
-        job = self.create_job(JobStatus.ERROR)
+        job = self.create_job(JobStatus.State.ERROR)
         job_id = "job_id"
         mock_hook.return_value.get_job.return_value = job
 
@@ -88,7 +88,7 @@ class TestDataprocJobSensor(unittest.TestCase):
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_wait(self, mock_hook):
-        job = self.create_job(JobStatus.RUNNING)
+        job = self.create_job(JobStatus.State.RUNNING)
         job_id = "job_id"
         mock_hook.return_value.get_job.return_value = job
 
@@ -109,7 +109,7 @@ class TestDataprocJobSensor(unittest.TestCase):
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_cancelled(self, mock_hook):
-        job = self.create_job(JobStatus.CANCELLED)
+        job = self.create_job(JobStatus.State.CANCELLED)
         job_id = "job_id"
         mock_hook.return_value.get_job.return_value = job
 


[airflow] 13/28: Support google-cloud-bigquery-datatransfer>=3.0.0 (#13337)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 62d985b642233767ae7341940540881f65d79d3a
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Thu Dec 31 18:07:32 2020 +0100

    Support google-cloud-bigquery-datatransfer>=3.0.0 (#13337)
    
    (cherry picked from commit 9de71270838ad3cc59043f1ab0bb6ca97af13622)
---
 airflow/providers/google/ADDITIONAL_INFO.md        |  1 +
 .../cloud/example_dags/example_bigquery_dts.py     | 20 ++++------
 .../providers/google/cloud/hooks/bigquery_dts.py   | 45 ++++++++++++++--------
 .../google/cloud/operators/bigquery_dts.py         | 12 +++---
 .../providers/google/cloud/sensors/bigquery_dts.py | 35 ++++++++++++-----
 setup.py                                           |  2 +-
 .../google/cloud/hooks/test_bigquery_dts.py        | 39 ++++++++-----------
 .../google/cloud/operators/test_bigquery_dts.py    | 37 +++++++++++++-----
 .../google/cloud/sensors/test_bigquery_dts.py      | 39 ++++++++++++++++---
 9 files changed, 142 insertions(+), 88 deletions(-)

diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md
index b54b240..eca05df 100644
--- a/airflow/providers/google/ADDITIONAL_INFO.md
+++ b/airflow/providers/google/ADDITIONAL_INFO.md
@@ -29,6 +29,7 @@ Details are covered in the UPDATING.md files for each library, but there are som
 
 | Library name | Previous constraints | Current constraints | |
 | --- | --- | --- | --- |
+| [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
 | [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=1.0.0,<2.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
 | [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
 | [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py b/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py
index 260dc5d..da13c9d 100644
--- a/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py
+++ b/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py
@@ -22,9 +22,6 @@ Example Airflow DAG that creates and deletes Bigquery data transfer configuratio
 import os
 import time
 
-from google.cloud.bigquery_datatransfer_v1.types import TransferConfig
-from google.protobuf.json_format import ParseDict
-
 from airflow import models
 from airflow.providers.google.cloud.operators.bigquery_dts import (
     BigQueryCreateDataTransferOperator,
@@ -55,16 +52,13 @@ PARAMS = {
     "file_format": "CSV",
 }
 
-TRANSFER_CONFIG = ParseDict(
-    {
-        "destination_dataset_id": GCP_DTS_BQ_DATASET,
-        "display_name": "GCS Test Config",
-        "data_source_id": "google_cloud_storage",
-        "schedule_options": schedule_options,
-        "params": PARAMS,
-    },
-    TransferConfig(),
-)
+TRANSFER_CONFIG = {
+    "destination_dataset_id": GCP_DTS_BQ_DATASET,
+    "display_name": "GCS Test Config",
+    "data_source_id": "google_cloud_storage",
+    "schedule_options": schedule_options,
+    "params": PARAMS,
+}
 
 # [END howto_bigquery_dts_create_args]
 
diff --git a/airflow/providers/google/cloud/hooks/bigquery_dts.py b/airflow/providers/google/cloud/hooks/bigquery_dts.py
index 2d8d12b..37d42ef 100644
--- a/airflow/providers/google/cloud/hooks/bigquery_dts.py
+++ b/airflow/providers/google/cloud/hooks/bigquery_dts.py
@@ -27,7 +27,6 @@ from google.cloud.bigquery_datatransfer_v1.types import (
     TransferConfig,
     TransferRun,
 )
-from google.protobuf.json_format import MessageToDict, ParseDict
 from googleapiclient.discovery import Resource
 
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
@@ -71,7 +70,7 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook):
         :param config: Data transfer configuration to create.
         :type config: Union[dict, google.cloud.bigquery_datatransfer_v1.types.TransferConfig]
         """
-        config = MessageToDict(config) if isinstance(config, TransferConfig) else config
+        config = TransferConfig.to_dict(config) if isinstance(config, TransferConfig) else config
         new_config = copy(config)
         schedule_options = new_config.get("schedule_options")
         if schedule_options:
@@ -80,7 +79,11 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook):
                 schedule_options["disable_auto_scheduling"] = True
         else:
             new_config["schedule_options"] = {"disable_auto_scheduling": True}
-        return ParseDict(new_config, TransferConfig())
+        # HACK: TransferConfig.to_dict returns invalid representation
+        # See: https://github.com/googleapis/python-bigquery-datatransfer/issues/90
+        if isinstance(new_config.get('user_id'), str):
+            new_config['user_id'] = int(new_config['user_id'])
+        return TransferConfig(**new_config)
 
     def get_conn(self) -> DataTransferServiceClient:
         """
@@ -129,14 +132,16 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook):
         :return: A ``google.cloud.bigquery_datatransfer_v1.types.TransferConfig`` instance.
         """
         client = self.get_conn()
-        parent = client.project_path(project_id)
+        parent = f"projects/{project_id}"
         return client.create_transfer_config(
-            parent=parent,
-            transfer_config=self._disable_auto_scheduling(transfer_config),
-            authorization_code=authorization_code,
+            request={
+                'parent': parent,
+                'transfer_config': self._disable_auto_scheduling(transfer_config),
+                'authorization_code': authorization_code,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -169,8 +174,10 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook):
         :return: None
         """
         client = self.get_conn()
-        name = client.project_transfer_config_path(project=project_id, transfer_config=transfer_config_id)
-        return client.delete_transfer_config(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        name = f"projects/{project_id}/transferConfigs/{transfer_config_id}"
+        return client.delete_transfer_config(
+            request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
 
     @GoogleBaseHook.fallback_to_default_project_id
     def start_manual_transfer_runs(
@@ -216,14 +223,16 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook):
         :return: An ``google.cloud.bigquery_datatransfer_v1.types.StartManualTransferRunsResponse`` instance.
         """
         client = self.get_conn()
-        parent = client.project_transfer_config_path(project=project_id, transfer_config=transfer_config_id)
+        parent = f"projects/{project_id}/transferConfigs/{transfer_config_id}"
         return client.start_manual_transfer_runs(
-            parent=parent,
-            requested_time_range=requested_time_range,
-            requested_run_time=requested_run_time,
+            request={
+                'parent': parent,
+                'requested_time_range': requested_time_range,
+                'requested_run_time': requested_run_time,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -259,5 +268,7 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook):
         :return: An ``google.cloud.bigquery_datatransfer_v1.types.TransferRun`` instance.
         """
         client = self.get_conn()
-        name = client.project_run_path(project=project_id, transfer_config=transfer_config_id, run=run_id)
-        return client.get_transfer_run(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        name = f"projects/{project_id}/transferConfigs/{transfer_config_id}/runs/{run_id}"
+        return client.get_transfer_run(
+            request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
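
A hedged usage sketch of the hook after this change (project and dataset ids are placeholders; the hook still resolves credentials from its Airflow connection):

    from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook

    hook = BiqQueryDataTransferServiceHook()  # defaults to the google_cloud_default connection
    transfer_config = hook.create_transfer_config(
        transfer_config={
            "destination_dataset_id": "my_dataset",   # placeholder
            "display_name": "GCS Test Config",
            "data_source_id": "google_cloud_storage",
        },
        project_id="my-project",                      # placeholder
    )
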
diff --git a/airflow/providers/google/cloud/operators/bigquery_dts.py b/airflow/providers/google/cloud/operators/bigquery_dts.py
index e941bd4..656fc77 100644
--- a/airflow/providers/google/cloud/operators/bigquery_dts.py
+++ b/airflow/providers/google/cloud/operators/bigquery_dts.py
@@ -19,7 +19,7 @@
 from typing import Optional, Sequence, Tuple, Union
 
 from google.api_core.retry import Retry
-from google.protobuf.json_format import MessageToDict
+from google.cloud.bigquery_datatransfer_v1 import StartManualTransferRunsResponse, TransferConfig
 
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook, get_object_id
@@ -110,7 +110,7 @@ class BigQueryCreateDataTransferOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(response)
+        result = TransferConfig.to_dict(response)
         self.log.info("Created DTS transfer config %s", get_object_id(result))
         self.xcom_push(context, key="transfer_config_id", value=get_object_id(result))
         return result
@@ -289,10 +289,8 @@ class BigQueryDataTransferServiceStartTransferRunsOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(response)
-        run_id = None
-        if 'runs' in result:
-            run_id = get_object_id(result['runs'][0])
-            self.xcom_push(context, key="run_id", value=run_id)
+        result = StartManualTransferRunsResponse.to_dict(response)
+        run_id = get_object_id(result['runs'][0])
+        self.xcom_push(context, key="run_id", value=run_id)
         self.log.info('Transfer run %s submitted successfully.', run_id)
         return result
diff --git a/airflow/providers/google/cloud/sensors/bigquery_dts.py b/airflow/providers/google/cloud/sensors/bigquery_dts.py
index 5b851ed..49e124c 100644
--- a/airflow/providers/google/cloud/sensors/bigquery_dts.py
+++ b/airflow/providers/google/cloud/sensors/bigquery_dts.py
@@ -19,7 +19,7 @@
 from typing import Optional, Sequence, Set, Tuple, Union
 
 from google.api_core.retry import Retry
-from google.protobuf.json_format import MessageToDict
+from google.cloud.bigquery_datatransfer_v1 import TransferState
 
 from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook
 from airflow.sensors.base import BaseSensorOperator
@@ -81,7 +81,9 @@ class BigQueryDataTransferServiceTransferRunSensor(BaseSensorOperator):
         *,
         run_id: str,
         transfer_config_id: str,
-        expected_statuses: Union[Set[str], str] = 'SUCCEEDED',
+        expected_statuses: Union[
+            Set[Union[str, TransferState, int]], str, TransferState, int
+        ] = TransferState.SUCCEEDED,
         project_id: Optional[str] = None,
         gcp_conn_id: str = "google_cloud_default",
         retry: Optional[Retry] = None,
@@ -96,13 +98,29 @@ class BigQueryDataTransferServiceTransferRunSensor(BaseSensorOperator):
         self.retry = retry
         self.request_timeout = request_timeout
         self.metadata = metadata
-        self.expected_statuses = (
-            {expected_statuses} if isinstance(expected_statuses, str) else expected_statuses
-        )
+        self.expected_statuses = self._normalize_state_list(expected_statuses)
         self.project_id = project_id
         self.gcp_cloud_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    def _normalize_state_list(self, states) -> Set[TransferState]:
+        states = {states} if isinstance(states, (str, TransferState, int)) else states
+        result = set()
+        for state in states:
+            if isinstance(state, str):
+                result.add(TransferState[state.upper()])
+            elif isinstance(state, int):
+                result.add(TransferState(state))
+            elif isinstance(state, TransferState):
+                result.add(state)
+            else:
+                raise TypeError(
+                    f"Unsupported type. "
+                    f"Expected: str, int, google.cloud.bigquery_datatransfer_v1.TransferState."
+                    f"Current type: {type(state)}"
+                )
+        return result
+
     def poke(self, context: dict) -> bool:
         hook = BiqQueryDataTransferServiceHook(
             gcp_conn_id=self.gcp_cloud_conn_id,
@@ -116,8 +134,5 @@ class BigQueryDataTransferServiceTransferRunSensor(BaseSensorOperator):
             timeout=self.request_timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(run)
-        state = result["state"]
-        self.log.info("Status of %s run: %s", self.run_id, state)
-
-        return state in self.expected_statuses
+        self.log.info("Status of %s run: %s", self.run_id, str(run.state))
+        return run.state in self.expected_statuses
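
As the diff above shows, the sensor now normalizes expected_statuses to TransferState members, so DAGs can pass either strings or enum values. A minimal, hypothetical example:

    from google.cloud.bigquery_datatransfer_v1 import TransferState

    from airflow.providers.google.cloud.sensors.bigquery_dts import (
        BigQueryDataTransferServiceTransferRunSensor,
    )

    wait_for_run = BigQueryDataTransferServiceTransferRunSensor(
        task_id="wait_for_transfer_run",              # placeholder ids
        transfer_config_id="my-transfer-config-id",
        run_id="my-run-id",
        project_id="my-project",
        expected_statuses={TransferState.SUCCEEDED},  # enum members now accepted alongside strings
    )
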
diff --git a/setup.py b/setup.py
index 3df9e47..628ecd1 100644
--- a/setup.py
+++ b/setup.py
@@ -284,7 +284,7 @@ google = [
     'google-auth>=1.0.0,<2.0.0',
     'google-auth-httplib2>=0.0.1',
     'google-cloud-automl>=0.4.0,<2.0.0',
-    'google-cloud-bigquery-datatransfer>=0.4.0,<2.0.0',
+    'google-cloud-bigquery-datatransfer>=3.0.0,<4.0.0',
     'google-cloud-bigtable>=1.0.0,<2.0.0',
     'google-cloud-container>=0.1.1,<2.0.0',
     'google-cloud-datacatalog>=1.0.0,<2.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_bigquery_dts.py b/tests/providers/google/cloud/hooks/test_bigquery_dts.py
index 64ad79c..b53cb76 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery_dts.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery_dts.py
@@ -20,9 +20,7 @@ import unittest
 from copy import deepcopy
 from unittest import mock
 
-from google.cloud.bigquery_datatransfer_v1 import DataTransferServiceClient
 from google.cloud.bigquery_datatransfer_v1.types import TransferConfig
-from google.protobuf.json_format import ParseDict
 
 from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook
 from airflow.version import version
@@ -33,21 +31,18 @@ PROJECT_ID = "id"
 
 PARAMS = {
     "field_delimiter": ",",
-    "max_bad_records": "0",
-    "skip_leading_rows": "1",
+    "max_bad_records": 0,
+    "skip_leading_rows": 1,
     "data_path_template": "bucket",
     "destination_table_name_template": "name",
     "file_format": "CSV",
 }
 
-TRANSFER_CONFIG = ParseDict(
-    {
-        "destination_dataset_id": "dataset",
-        "display_name": "GCS Test Config",
-        "data_source_id": "google_cloud_storage",
-        "params": PARAMS,
-    },
-    TransferConfig(),
+TRANSFER_CONFIG = TransferConfig(
+    destination_dataset_id="dataset",
+    display_name="GCS Test Config",
+    data_source_id="google_cloud_storage",
+    params=PARAMS,
 )
 
 TRANSFER_CONFIG_ID = "id1234"
@@ -77,14 +72,12 @@ class BigQueryDataTransferHookTestCase(unittest.TestCase):
     )
     def test_create_transfer_config(self, service_mock):
         self.hook.create_transfer_config(transfer_config=TRANSFER_CONFIG, project_id=PROJECT_ID)
-        parent = DataTransferServiceClient.project_path(PROJECT_ID)
+        parent = f"projects/{PROJECT_ID}"
         expected_config = deepcopy(TRANSFER_CONFIG)
         expected_config.schedule_options.disable_auto_scheduling = True
         service_mock.assert_called_once_with(
-            parent=parent,
-            transfer_config=expected_config,
-            authorization_code=None,
-            metadata=None,
+            request=dict(parent=parent, transfer_config=expected_config, authorization_code=None),
+            metadata=(),
             retry=None,
             timeout=None,
         )
@@ -96,8 +89,8 @@ class BigQueryDataTransferHookTestCase(unittest.TestCase):
     def test_delete_transfer_config(self, service_mock):
         self.hook.delete_transfer_config(transfer_config_id=TRANSFER_CONFIG_ID, project_id=PROJECT_ID)
 
-        name = DataTransferServiceClient.project_transfer_config_path(PROJECT_ID, TRANSFER_CONFIG_ID)
-        service_mock.assert_called_once_with(name=name, metadata=None, retry=None, timeout=None)
+        name = f"projects/{PROJECT_ID}/transferConfigs/{TRANSFER_CONFIG_ID}"
+        service_mock.assert_called_once_with(request=dict(name=name), metadata=(), retry=None, timeout=None)
 
     @mock.patch(
         "airflow.providers.google.cloud.hooks.bigquery_dts."
@@ -106,12 +99,10 @@ class BigQueryDataTransferHookTestCase(unittest.TestCase):
     def test_start_manual_transfer_runs(self, service_mock):
         self.hook.start_manual_transfer_runs(transfer_config_id=TRANSFER_CONFIG_ID, project_id=PROJECT_ID)
 
-        parent = DataTransferServiceClient.project_transfer_config_path(PROJECT_ID, TRANSFER_CONFIG_ID)
+        parent = f"projects/{PROJECT_ID}/transferConfigs/{TRANSFER_CONFIG_ID}"
         service_mock.assert_called_once_with(
-            parent=parent,
-            requested_time_range=None,
-            requested_run_time=None,
-            metadata=None,
+            request=dict(parent=parent, requested_time_range=None, requested_run_time=None),
+            metadata=(),
             retry=None,
             timeout=None,
         )
diff --git a/tests/providers/google/cloud/operators/test_bigquery_dts.py b/tests/providers/google/cloud/operators/test_bigquery_dts.py
index 4d42352..d6071fa 100644
--- a/tests/providers/google/cloud/operators/test_bigquery_dts.py
+++ b/tests/providers/google/cloud/operators/test_bigquery_dts.py
@@ -18,6 +18,8 @@
 import unittest
 from unittest import mock
 
+from google.cloud.bigquery_datatransfer_v1 import StartManualTransferRunsResponse, TransferConfig, TransferRun
+
 from airflow.providers.google.cloud.operators.bigquery_dts import (
     BigQueryCreateDataTransferOperator,
     BigQueryDataTransferServiceStartTransferRunsOperator,
@@ -39,20 +41,23 @@ TRANSFER_CONFIG = {
 
 TRANSFER_CONFIG_ID = "id1234"
 
-NAME = "projects/123abc/locations/321cba/transferConfig/1a2b3c"
+TRANSFER_CONFIG_NAME = "projects/123abc/locations/321cba/transferConfig/1a2b3c"
+RUN_NAME = "projects/123abc/locations/321cba/transferConfig/1a2b3c/runs/123"
 
 
 class BigQueryCreateDataTransferOperatorTestCase(unittest.TestCase):
-    @mock.patch("airflow.providers.google.cloud.operators.bigquery_dts.BiqQueryDataTransferServiceHook")
-    @mock.patch("airflow.providers.google.cloud.operators.bigquery_dts.get_object_id")
-    def test_execute(self, mock_name, mock_hook):
-        mock_name.return_value = TRANSFER_CONFIG_ID
-        mock_xcom = mock.MagicMock()
+    @mock.patch(
+        "airflow.providers.google.cloud.operators.bigquery_dts.BiqQueryDataTransferServiceHook",
+        **{'return_value.create_transfer_config.return_value': TransferConfig(name=TRANSFER_CONFIG_NAME)},
+    )
+    def test_execute(self, mock_hook):
         op = BigQueryCreateDataTransferOperator(
             transfer_config=TRANSFER_CONFIG, project_id=PROJECT_ID, task_id="id"
         )
-        op.xcom_push = mock_xcom
-        op.execute(None)
+        ti = mock.MagicMock()
+
+        op.execute({'ti': ti})
+
         mock_hook.return_value.create_transfer_config.assert_called_once_with(
             authorization_code=None,
             metadata=None,
@@ -61,6 +66,7 @@ class BigQueryCreateDataTransferOperatorTestCase(unittest.TestCase):
             retry=None,
             timeout=None,
         )
+        ti.xcom_push.assert_called_once_with(execution_date=None, key='transfer_config_id', value='1a2b3c')
 
 
 class BigQueryDeleteDataTransferConfigOperatorTestCase(unittest.TestCase):
@@ -80,12 +86,22 @@ class BigQueryDeleteDataTransferConfigOperatorTestCase(unittest.TestCase):
 
 
 class BigQueryDataTransferServiceStartTransferRunsOperatorTestCase(unittest.TestCase):
-    @mock.patch("airflow.providers.google.cloud.operators.bigquery_dts.BiqQueryDataTransferServiceHook")
+    @mock.patch(
+        "airflow.providers.google.cloud.operators.bigquery_dts.BiqQueryDataTransferServiceHook",
+        **{
+            'return_value.start_manual_transfer_runs.return_value': StartManualTransferRunsResponse(
+                runs=[TransferRun(name=RUN_NAME)]
+            )
+        },
+    )
     def test_execute(self, mock_hook):
         op = BigQueryDataTransferServiceStartTransferRunsOperator(
             transfer_config_id=TRANSFER_CONFIG_ID, task_id="id", project_id=PROJECT_ID
         )
-        op.execute(None)
+        ti = mock.MagicMock()
+
+        op.execute({'ti': ti})
+
         mock_hook.return_value.start_manual_transfer_runs.assert_called_once_with(
             transfer_config_id=TRANSFER_CONFIG_ID,
             project_id=PROJECT_ID,
@@ -95,3 +111,4 @@ class BigQueryDataTransferServiceStartTransferRunsOperatorTestCase(unittest.Test
             retry=None,
             timeout=None,
         )
+        ti.xcom_push.assert_called_once_with(execution_date=None, key='run_id', value='123')
diff --git a/tests/providers/google/cloud/sensors/test_bigquery_dts.py b/tests/providers/google/cloud/sensors/test_bigquery_dts.py
index 92a116e..c8a0548 100644
--- a/tests/providers/google/cloud/sensors/test_bigquery_dts.py
+++ b/tests/providers/google/cloud/sensors/test_bigquery_dts.py
@@ -19,6 +19,8 @@
 import unittest
 from unittest import mock
 
+from google.cloud.bigquery_datatransfer_v1 import TransferState
+
 from airflow.providers.google.cloud.sensors.bigquery_dts import BigQueryDataTransferServiceTransferRunSensor
 
 TRANSFER_CONFIG_ID = "config_id"
@@ -27,20 +29,45 @@ PROJECT_ID = "project_id"
 
 
 class TestBigQueryDataTransferServiceTransferRunSensor(unittest.TestCase):
-    @mock.patch("airflow.providers.google.cloud.sensors.bigquery_dts.BiqQueryDataTransferServiceHook")
     @mock.patch(
-        "airflow.providers.google.cloud.sensors.bigquery_dts.MessageToDict",
-        return_value={"state": "success"},
+        "airflow.providers.google.cloud.sensors.bigquery_dts.BiqQueryDataTransferServiceHook",
+        **{'return_value.get_transfer_run.return_value.state': TransferState.FAILED},
+    )
+    def test_poke_returns_false(self, mock_hook):
+        op = BigQueryDataTransferServiceTransferRunSensor(
+            transfer_config_id=TRANSFER_CONFIG_ID,
+            run_id=RUN_ID,
+            task_id="id",
+            project_id=PROJECT_ID,
+            expected_statuses={"SUCCEEDED"},
+        )
+        result = op.poke({})
+
+        self.assertEqual(result, False)
+        mock_hook.return_value.get_transfer_run.assert_called_once_with(
+            transfer_config_id=TRANSFER_CONFIG_ID,
+            run_id=RUN_ID,
+            project_id=PROJECT_ID,
+            metadata=None,
+            retry=None,
+            timeout=None,
+        )
+
+    @mock.patch(
+        "airflow.providers.google.cloud.sensors.bigquery_dts.BiqQueryDataTransferServiceHook",
+        **{'return_value.get_transfer_run.return_value.state': TransferState.SUCCEEDED},
     )
-    def test_poke(self, mock_msg_to_dict, mock_hook):
+    def test_poke_returns_true(self, mock_hook):
         op = BigQueryDataTransferServiceTransferRunSensor(
             transfer_config_id=TRANSFER_CONFIG_ID,
             run_id=RUN_ID,
             task_id="id",
             project_id=PROJECT_ID,
-            expected_statuses={"success"},
+            expected_statuses={"SUCCEEDED"},
         )
-        op.poke(None)
+        result = op.poke({})
+
+        self.assertEqual(result, True)
         mock_hook.return_value.get_transfer_run.assert_called_once_with(
             transfer_config_id=TRANSFER_CONFIG_ID,
             run_id=RUN_ID,


[airflow] 05/28: Upgrade slack_sdk to v3 (#13745)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3c5173aa65778e495716a73f29f4b10632617f14
Author: Jyoti Dhiman <36...@users.noreply.github.com>
AuthorDate: Tue Jan 26 02:43:48 2021 +0530

    Upgrade slack_sdk to v3 (#13745)
    
    Co-authored-by: Kamil Breguła <ka...@polidea.com>
    Co-authored-by: Kamil Breguła <mi...@users.noreply.github.com>
    (cherry picked from commit 283945001363d8f492fbd25f2765d39fa06d757a)
---
 airflow/providers/slack/ADDITIONAL_INFO.md         | 25 ++++++++++++++++++++++
 .../providers/slack/BACKPORT_PROVIDER_README.md    |  2 +-
 airflow/providers/slack/README.md                  |  2 +-
 airflow/providers/slack/hooks/slack.py             |  4 ++--
 docs/conf.py                                       |  2 +-
 docs/spelling_wordlist.txt                         |  1 +
 scripts/ci/libraries/_verify_image.sh              |  2 +-
 setup.py                                           |  2 +-
 tests/providers/slack/hooks/test_slack.py          |  2 +-
 9 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/slack/ADDITIONAL_INFO.md b/airflow/providers/slack/ADDITIONAL_INFO.md
new file mode 100644
index 0000000..9b05d8a
--- /dev/null
+++ b/airflow/providers/slack/ADDITIONAL_INFO.md
@@ -0,0 +1,25 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied.  See the License for the
+ specific language governing permissions and limitations
+ under the License.
+ -->
+
+# Migration Guide
+
+## 2.0.0
+
+We replaced the `slackclient` dependency (``>=2.0.0,<3.0.0``) with `slack_sdk` (``>=3.0.0,<4.0.0``). In most cases this does not require any changes to your DAG files, but if you use the library directly you will need to update your code.
+For details, see [the Migration Guide](https://slack.dev/python-slack-sdk/v3-migration/index.html#from-slackclient-2-x) for Python Slack SDK.
diff --git a/airflow/providers/slack/BACKPORT_PROVIDER_README.md b/airflow/providers/slack/BACKPORT_PROVIDER_README.md
index 7863eb4..0e20d06 100644
--- a/airflow/providers/slack/BACKPORT_PROVIDER_README.md
+++ b/airflow/providers/slack/BACKPORT_PROVIDER_README.md
@@ -60,7 +60,7 @@ You can install this package on top of an existing airflow 1.10.* installation v
 
 | PIP package   | Version required   |
 |:--------------|:-------------------|
-| slackclient   | &gt;=2.0.0,&lt;3.0.0     |
+| slack_sdk   | &gt;=3.0.0,&lt;4.0.0     |
 
 ## Cross provider package dependencies
 
diff --git a/airflow/providers/slack/README.md b/airflow/providers/slack/README.md
index 7a630c6..ea4968a 100644
--- a/airflow/providers/slack/README.md
+++ b/airflow/providers/slack/README.md
@@ -61,7 +61,7 @@ You can install this package on top of an existing airflow 2.* installation via
 
 | PIP package   | Version required   |
 |:--------------|:-------------------|
-| slackclient   | &gt;=2.0.0,&lt;3.0.0     |
+| slack_sdk   | &gt;=3.0.0,&lt;4.0.0     |
 
 ## Cross provider package dependencies
 
diff --git a/airflow/providers/slack/hooks/slack.py b/airflow/providers/slack/hooks/slack.py
index 6f27091..da449a7 100644
--- a/airflow/providers/slack/hooks/slack.py
+++ b/airflow/providers/slack/hooks/slack.py
@@ -18,7 +18,7 @@
 """Hook for Slack"""
 from typing import Any, Optional
 
-from slack import WebClient
+from slack_sdk import WebClient
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
@@ -41,7 +41,7 @@ class SlackHook(BaseHook):  # noqa
         slack_hook.call("chat.postMessage", json={"channel": "#random", "text": "Hello world!"})
 
         # Call method from Slack SDK (you have to handle errors yourself)
-        #  For more details check https://slack.dev/python-slackclient/basic_usage.html#sending-a-message
+        #  For more details check https://slack.dev/python-slack-sdk/web/index.html#messaging
         slack_hook.client.chat_postMessage(channel="#random", text="Hello world!")
 
     :param token: Slack API token
diff --git a/docs/conf.py b/docs/conf.py
index a60bbe3..411796c 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -391,7 +391,7 @@ autodoc_mock_imports = [
     'qds_sdk',
     'redis',
     'simple_salesforce',
-    'slackclient',
+    'slack_sdk',
     'smbclient',
     'snowflake',
     'sshtunnel',
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index f8f8f83..71f9e34 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1204,6 +1204,7 @@ skipable
 sku
 sla
 slackclient
+slack_sdk
 slas
 smtp
 sortable
diff --git a/scripts/ci/libraries/_verify_image.sh b/scripts/ci/libraries/_verify_image.sh
index 2092bd2..05e91c6 100644
--- a/scripts/ci/libraries/_verify_image.sh
+++ b/scripts/ci/libraries/_verify_image.sh
@@ -190,7 +190,7 @@ function verify_image::verify_production_image_python_modules() {
     verify_image::check_command "Import: redis" "python -c 'import redis'"
     verify_image::check_command "Import: sendgrid" "python -c 'import sendgrid'"
     verify_image::check_command "Import: sftp/ssh" "python -c 'import paramiko, pysftp, sshtunnel'"
-    verify_image::check_command "Import: slack" "python -c 'import slack'"
+    verify_image::check_command "Import: slack" "python -c 'import slack_sdk'"
     verify_image::check_command "Import: statsd" "python -c 'import statsd'"
     verify_image::check_command "Import: virtualenv" "python -c 'import virtualenv'"
 
diff --git a/setup.py b/setup.py
index 50f6a2f..0689bd5 100644
--- a/setup.py
+++ b/setup.py
@@ -417,7 +417,7 @@ sentry = [
 ]
 singularity = ['spython>=0.0.56']
 slack = [
-    'slackclient>=2.0.0,<3.0.0',
+    'slack_sdk>=3.0.0,<4.0.0',
 ]
 snowflake = [
     # The `azure` provider uses legacy `azure-storage` library, where `snowflake` uses the
diff --git a/tests/providers/slack/hooks/test_slack.py b/tests/providers/slack/hooks/test_slack.py
index cbe3d26..5fef409 100644
--- a/tests/providers/slack/hooks/test_slack.py
+++ b/tests/providers/slack/hooks/test_slack.py
@@ -20,7 +20,7 @@ import unittest
 from unittest import mock
 
 import pytest
-from slack.errors import SlackApiError
+from slack_sdk.errors import SlackApiError
 
 from airflow.exceptions import AirflowException
 from airflow.providers.slack.hooks.slack import SlackHook

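For DAG authors who call the Slack SDK directly (outside SlackHook), the migration note above amounts to changing the import path; the client API exercised by the hook stays the same. A minimal sketch, assuming a valid Slack API token (placeholder value below):

    from slack_sdk import WebClient
    from slack_sdk.errors import SlackApiError

    client = WebClient(token="xoxb-your-token")  # placeholder token, supply your own
    try:
        # Same call the SlackHook docstring demonstrates above.
        client.chat_postMessage(channel="#random", text="Hello world!")
    except SlackApiError as err:
        print(f"Slack API call failed: {err}")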

[airflow] 25/28: Pin moto to <2 (#14433)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 982a3a24ff0643bd37ca25fa5da93616beac26fa
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Wed Feb 24 22:22:54 2021 +0000

    Pin moto to <2 (#14433)
    
    https://pypi.org/project/moto/#history -- moto 2.0.0 was released yesterday and is causing CI failures
    (cherry picked from commit 802159767baf1768d92c6047c2fdb2094ee7a2a8)
---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index ad4fdd5..2867b36 100644
--- a/setup.py
+++ b/setup.py
@@ -491,7 +491,7 @@ devel = [
     # See: https://github.com/spulec/moto/issues/3535
     'mock<4.0.3',
     'mongomock',
-    'moto',
+    'moto<2',
     'mypy==0.770',
     'parameterized',
     'paramiko',


[airflow] 17/28: Support google-cloud-tasks>=2.0.0 (#13347)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 7b31def58ed8a136d2ca9f5c55ffb55f585370d6
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Thu Jan 14 12:18:49 2021 +0100

    Support google-cloud-tasks>=2.0.0 (#13347)
    
    (cherry picked from commit ef8617ec9d6e4b7c433a29bd388f5102a7a17c11)
---
 airflow/providers/google/ADDITIONAL_INFO.md        |   4 +-
 airflow/providers/google/cloud/hooks/tasks.py      | 118 ++++++++---------
 airflow/providers/google/cloud/operators/tasks.py  |  41 +++---
 setup.py                                           |   2 +-
 tests/providers/google/cloud/hooks/test_tasks.py   |  86 ++++++-------
 .../providers/google/cloud/operators/test_tasks.py | 140 ++++++++++++++++-----
 6 files changed, 235 insertions(+), 156 deletions(-)

diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md
index 800703b..c696e1b 100644
--- a/airflow/providers/google/ADDITIONAL_INFO.md
+++ b/airflow/providers/google/ADDITIONAL_INFO.md
@@ -32,10 +32,10 @@ Details are covered in the UPDATING.md files for each library, but there are som
 | [``google-cloud-automl``](https://pypi.org/project/google-cloud-automl/) | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-automl/blob/master/UPGRADING.md) |
 | [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
 | [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
+| [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
 | [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
 | [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
-| [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
-
+| [``google-cloud-tasks``](https://pypi.org/project/google-cloud-tasks/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-tasks/blob/master/UPGRADING.md) |
 
 ### The field names use the snake_case convention
 
diff --git a/airflow/providers/google/cloud/hooks/tasks.py b/airflow/providers/google/cloud/hooks/tasks.py
index 1c3223d..633f227 100644
--- a/airflow/providers/google/cloud/hooks/tasks.py
+++ b/airflow/providers/google/cloud/hooks/tasks.py
@@ -21,11 +21,13 @@ This module contains a CloudTasksHook
 which allows you to connect to Google Cloud Tasks service,
 performing actions on queues or tasks.
 """
+
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 from google.api_core.retry import Retry
-from google.cloud.tasks_v2 import CloudTasksClient, enums
-from google.cloud.tasks_v2.types import FieldMask, Queue, Task
+from google.cloud.tasks_v2 import CloudTasksClient
+from google.cloud.tasks_v2.types import Queue, Task
+from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
@@ -120,20 +122,19 @@ class CloudTasksHook(GoogleBaseHook):
         client = self.get_conn()
 
         if queue_name:
-            full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
+            full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
             if isinstance(task_queue, Queue):
                 task_queue.name = full_queue_name
             elif isinstance(task_queue, dict):
                 task_queue['name'] = full_queue_name
             else:
                 raise AirflowException('Unable to set queue_name.')
-        full_location_path = CloudTasksClient.location_path(project_id, location)
+        full_location_path = f"projects/{project_id}/locations/{location}"
         return client.create_queue(
-            parent=full_location_path,
-            queue=task_queue,
+            request={'parent': full_location_path, 'queue': task_queue},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -167,7 +168,7 @@ class CloudTasksHook(GoogleBaseHook):
         :param update_mask: A mask used to specify which fields of the queue are being updated.
             If empty, then all fields will be updated.
             If a dict is provided, it must be of the same form as the protobuf message.
-        :type update_mask: dict or google.cloud.tasks_v2.types.FieldMask
+        :type update_mask: dict or google.protobuf.field_mask_pb2.FieldMask
         :param retry: (Optional) A retry object used to retry requests.
             If None is specified, requests will not be retried.
         :type retry: google.api_core.retry.Retry
@@ -182,7 +183,7 @@ class CloudTasksHook(GoogleBaseHook):
         client = self.get_conn()
 
         if queue_name and location:
-            full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
+            full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
             if isinstance(task_queue, Queue):
                 task_queue.name = full_queue_name
             elif isinstance(task_queue, dict):
@@ -190,11 +191,10 @@ class CloudTasksHook(GoogleBaseHook):
             else:
                 raise AirflowException('Unable to set queue_name.')
         return client.update_queue(
-            queue=task_queue,
-            update_mask=update_mask,
+            request={'queue': task_queue, 'update_mask': update_mask},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -230,8 +230,10 @@ class CloudTasksHook(GoogleBaseHook):
         """
         client = self.get_conn()
 
-        full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
-        return client.get_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
+        full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
+        return client.get_queue(
+            request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
 
     @GoogleBaseHook.fallback_to_default_project_id
     def list_queues(
@@ -270,14 +272,12 @@ class CloudTasksHook(GoogleBaseHook):
         """
         client = self.get_conn()
 
-        full_location_path = CloudTasksClient.location_path(project_id, location)
+        full_location_path = f"projects/{project_id}/locations/{location}"
         queues = client.list_queues(
-            parent=full_location_path,
-            filter_=results_filter,
-            page_size=page_size,
+            request={'parent': full_location_path, 'filter': results_filter, 'page_size': page_size},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return list(queues)
 
@@ -313,8 +313,10 @@ class CloudTasksHook(GoogleBaseHook):
         """
         client = self.get_conn()
 
-        full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
-        client.delete_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
+        full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
+        client.delete_queue(
+            request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
 
     @GoogleBaseHook.fallback_to_default_project_id
     def purge_queue(
@@ -349,8 +351,10 @@ class CloudTasksHook(GoogleBaseHook):
         """
         client = self.get_conn()
 
-        full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
-        return client.purge_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
+        full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
+        return client.purge_queue(
+            request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
 
     @GoogleBaseHook.fallback_to_default_project_id
     def pause_queue(
@@ -385,8 +389,10 @@ class CloudTasksHook(GoogleBaseHook):
         """
         client = self.get_conn()
 
-        full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
-        return client.pause_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
+        full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
+        return client.pause_queue(
+            request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
 
     @GoogleBaseHook.fallback_to_default_project_id
     def resume_queue(
@@ -421,8 +427,10 @@ class CloudTasksHook(GoogleBaseHook):
         """
         client = self.get_conn()
 
-        full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
-        return client.resume_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata)
+        full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
+        return client.resume_queue(
+            request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
 
     @GoogleBaseHook.fallback_to_default_project_id
     def create_task(
@@ -432,7 +440,7 @@ class CloudTasksHook(GoogleBaseHook):
         task: Union[Dict, Task],
         project_id: str,
         task_name: Optional[str] = None,
-        response_view: Optional[enums.Task.View] = None,
+        response_view: Optional = None,
         retry: Optional[Retry] = None,
         timeout: Optional[float] = None,
         metadata: Optional[Sequence[Tuple[str, str]]] = None,
@@ -455,7 +463,7 @@ class CloudTasksHook(GoogleBaseHook):
         :type task_name: str
         :param response_view: (Optional) This field specifies which subset of the Task will
             be returned.
-        :type response_view: google.cloud.tasks_v2.enums.Task.View
+        :type response_view: google.cloud.tasks_v2.Task.View
         :param retry: (Optional) A retry object used to retry requests.
             If None is specified, requests will not be retried.
         :type retry: google.api_core.retry.Retry
@@ -470,21 +478,21 @@ class CloudTasksHook(GoogleBaseHook):
         client = self.get_conn()
 
         if task_name:
-            full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name)
+            full_task_name = (
+                f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}"
+            )
             if isinstance(task, Task):
                 task.name = full_task_name
             elif isinstance(task, dict):
                 task['name'] = full_task_name
             else:
                 raise AirflowException('Unable to set task_name.')
-        full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
+        full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
         return client.create_task(
-            parent=full_queue_name,
-            task=task,
-            response_view=response_view,
+            request={'parent': full_queue_name, 'task': task, 'response_view': response_view},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -494,7 +502,7 @@ class CloudTasksHook(GoogleBaseHook):
         queue_name: str,
         task_name: str,
         project_id: str,
-        response_view: Optional[enums.Task.View] = None,
+        response_view: Optional = None,
         retry: Optional[Retry] = None,
         timeout: Optional[float] = None,
         metadata: Optional[Sequence[Tuple[str, str]]] = None,
@@ -513,7 +521,7 @@ class CloudTasksHook(GoogleBaseHook):
         :type project_id: str
         :param response_view: (Optional) This field specifies which subset of the Task will
             be returned.
-        :type response_view: google.cloud.tasks_v2.enums.Task.View
+        :type response_view: google.cloud.tasks_v2.Task.View
         :param retry: (Optional) A retry object used to retry requests.
             If None is specified, requests will not be retried.
         :type retry: google.api_core.retry.Retry
@@ -527,13 +535,12 @@ class CloudTasksHook(GoogleBaseHook):
         """
         client = self.get_conn()
 
-        full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name)
+        full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}"
         return client.get_task(
-            name=full_task_name,
-            response_view=response_view,
+            request={'name': full_task_name, 'response_view': response_view},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -542,7 +549,7 @@ class CloudTasksHook(GoogleBaseHook):
         location: str,
         queue_name: str,
         project_id: str,
-        response_view: Optional[enums.Task.View] = None,
+        response_view: Optional = None,
         page_size: Optional[int] = None,
         retry: Optional[Retry] = None,
         timeout: Optional[float] = None,
@@ -560,7 +567,7 @@ class CloudTasksHook(GoogleBaseHook):
         :type project_id: str
         :param response_view: (Optional) This field specifies which subset of the Task will
             be returned.
-        :type response_view: google.cloud.tasks_v2.enums.Task.View
+        :type response_view: google.cloud.tasks_v2.Task.View
         :param page_size: (Optional) The maximum number of resources contained in the
             underlying API response.
         :type page_size: int
@@ -576,14 +583,12 @@ class CloudTasksHook(GoogleBaseHook):
         :rtype: list[google.cloud.tasks_v2.types.Task]
         """
         client = self.get_conn()
-        full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name)
+        full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}"
         tasks = client.list_tasks(
-            parent=full_queue_name,
-            response_view=response_view,
-            page_size=page_size,
+            request={'parent': full_queue_name, 'response_view': response_view, 'page_size': page_size},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return list(tasks)
 
@@ -622,8 +627,10 @@ class CloudTasksHook(GoogleBaseHook):
         """
         client = self.get_conn()
 
-        full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name)
-        client.delete_task(name=full_task_name, retry=retry, timeout=timeout, metadata=metadata)
+        full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}"
+        client.delete_task(
+            request={'name': full_task_name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
 
     @GoogleBaseHook.fallback_to_default_project_id
     def run_task(
@@ -632,7 +639,7 @@ class CloudTasksHook(GoogleBaseHook):
         queue_name: str,
         task_name: str,
         project_id: str,
-        response_view: Optional[enums.Task.View] = None,
+        response_view: Optional = None,
         retry: Optional[Retry] = None,
         timeout: Optional[float] = None,
         metadata: Optional[Sequence[Tuple[str, str]]] = None,
@@ -651,7 +658,7 @@ class CloudTasksHook(GoogleBaseHook):
         :type project_id: str
         :param response_view: (Optional) This field specifies which subset of the Task will
             be returned.
-        :type response_view: google.cloud.tasks_v2.enums.Task.View
+        :type response_view: google.cloud.tasks_v2.Task.View
         :param retry: (Optional) A retry object used to retry requests.
             If None is specified, requests will not be retried.
         :type retry: google.api_core.retry.Retry
@@ -665,11 +672,10 @@ class CloudTasksHook(GoogleBaseHook):
         """
         client = self.get_conn()
 
-        full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name)
+        full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}"
         return client.run_task(
-            name=full_task_name,
-            response_view=response_view,
+            request={'name': full_task_name, 'response_view': response_view},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
diff --git a/airflow/providers/google/cloud/operators/tasks.py b/airflow/providers/google/cloud/operators/tasks.py
index 6598d66..2834b32 100644
--- a/airflow/providers/google/cloud/operators/tasks.py
+++ b/airflow/providers/google/cloud/operators/tasks.py
@@ -25,9 +25,8 @@ from typing import Dict, Optional, Sequence, Tuple, Union
 
 from google.api_core.exceptions import AlreadyExists
 from google.api_core.retry import Retry
-from google.cloud.tasks_v2 import enums
-from google.cloud.tasks_v2.types import FieldMask, Queue, Task
-from google.protobuf.json_format import MessageToDict
+from google.cloud.tasks_v2.types import Queue, Task
+from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.tasks import CloudTasksHook
@@ -136,7 +135,7 @@ class CloudTasksQueueCreateOperator(BaseOperator):
                 metadata=self.metadata,
             )
 
-        return MessageToDict(queue)
+        return Queue.to_dict(queue)
 
 
 class CloudTasksQueueUpdateOperator(BaseOperator):
@@ -159,7 +158,7 @@ class CloudTasksQueueUpdateOperator(BaseOperator):
     :param update_mask: A mask used to specify which fields of the queue are being updated.
         If empty, then all fields will be updated.
         If a dict is provided, it must be of the same form as the protobuf message.
-    :type update_mask: dict or google.cloud.tasks_v2.types.FieldMask
+    :type update_mask: dict or google.protobuf.field_mask_pb2.FieldMask
     :param retry: (Optional) A retry object used to retry requests.
         If None is specified, requests will not be retried.
     :type retry: google.api_core.retry.Retry
@@ -237,7 +236,7 @@ class CloudTasksQueueUpdateOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(queue)
+        return Queue.to_dict(queue)
 
 
 class CloudTasksQueueGetOperator(BaseOperator):
@@ -320,7 +319,7 @@ class CloudTasksQueueGetOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(queue)
+        return Queue.to_dict(queue)
 
 
 class CloudTasksQueuesListOperator(BaseOperator):
@@ -408,7 +407,7 @@ class CloudTasksQueuesListOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return [MessageToDict(q) for q in queues]
+        return [Queue.to_dict(q) for q in queues]
 
 
 class CloudTasksQueueDeleteOperator(BaseOperator):
@@ -571,7 +570,7 @@ class CloudTasksQueuePurgeOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(queue)
+        return Queue.to_dict(queue)
 
 
 class CloudTasksQueuePauseOperator(BaseOperator):
@@ -646,7 +645,7 @@ class CloudTasksQueuePauseOperator(BaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
-        queues = hook.pause_queue(
+        queue = hook.pause_queue(
             location=self.location,
             queue_name=self.queue_name,
             project_id=self.project_id,
@@ -654,7 +653,7 @@ class CloudTasksQueuePauseOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return [MessageToDict(q) for q in queues]
+        return Queue.to_dict(queue)
 
 
 class CloudTasksQueueResumeOperator(BaseOperator):
@@ -737,7 +736,7 @@ class CloudTasksQueueResumeOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(queue)
+        return Queue.to_dict(queue)
 
 
 class CloudTasksTaskCreateOperator(BaseOperator):
@@ -803,7 +802,7 @@ class CloudTasksTaskCreateOperator(BaseOperator):
         task: Union[Dict, Task],
         project_id: Optional[str] = None,
         task_name: Optional[str] = None,
-        response_view: Optional[enums.Task.View] = None,
+        response_view: Optional = None,
         retry: Optional[Retry] = None,
         timeout: Optional[float] = None,
         metadata: Optional[MetaData] = None,
@@ -840,7 +839,7 @@ class CloudTasksTaskCreateOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(task)
+        return Task.to_dict(task)
 
 
 class CloudTasksTaskGetOperator(BaseOperator):
@@ -900,7 +899,7 @@ class CloudTasksTaskGetOperator(BaseOperator):
         queue_name: str,
         task_name: str,
         project_id: Optional[str] = None,
-        response_view: Optional[enums.Task.View] = None,
+        response_view: Optional = None,
         retry: Optional[Retry] = None,
         timeout: Optional[float] = None,
         metadata: Optional[MetaData] = None,
@@ -935,7 +934,7 @@ class CloudTasksTaskGetOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(task)
+        return Task.to_dict(task)
 
 
 class CloudTasksTasksListOperator(BaseOperator):
@@ -994,7 +993,7 @@ class CloudTasksTasksListOperator(BaseOperator):
         location: str,
         queue_name: str,
         project_id: Optional[str] = None,
-        response_view: Optional[enums.Task.View] = None,
+        response_view: Optional = None,
         page_size: Optional[int] = None,
         retry: Optional[Retry] = None,
         timeout: Optional[float] = None,
@@ -1030,7 +1029,7 @@ class CloudTasksTasksListOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return [MessageToDict(t) for t in tasks]
+        return [Task.to_dict(t) for t in tasks]
 
 
 class CloudTasksTaskDeleteOperator(BaseOperator):
@@ -1134,7 +1133,7 @@ class CloudTasksTaskRunOperator(BaseOperator):
     :type project_id: str
     :param response_view: (Optional) This field specifies which subset of the Task will
         be returned.
-    :type response_view: google.cloud.tasks_v2.enums.Task.View
+    :type response_view: google.cloud.tasks_v2.Task.View
     :param retry: (Optional) A retry object used to retry requests.
         If None is specified, requests will not be retried.
     :type retry: google.api_core.retry.Retry
@@ -1176,7 +1175,7 @@ class CloudTasksTaskRunOperator(BaseOperator):
         queue_name: str,
         task_name: str,
         project_id: Optional[str] = None,
-        response_view: Optional[enums.Task.View] = None,
+        response_view: Optional = None,
         retry: Optional[Retry] = None,
         timeout: Optional[float] = None,
         metadata: Optional[MetaData] = None,
@@ -1211,4 +1210,4 @@ class CloudTasksTaskRunOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(task)
+        return Task.to_dict(task)
diff --git a/setup.py b/setup.py
index ff9e65d..520b059 100644
--- a/setup.py
+++ b/setup.py
@@ -302,7 +302,7 @@ google = [
     'google-cloud-spanner>=1.10.0,<2.0.0',
     'google-cloud-speech>=0.36.3,<2.0.0',
     'google-cloud-storage>=1.30,<2.0.0',
-    'google-cloud-tasks>=1.2.1,<2.0.0',
+    'google-cloud-tasks>=2.0.0,<3.0.0',
     'google-cloud-texttospeech>=0.4.0,<2.0.0',
     'google-cloud-translate>=1.5.0,<2.0.0',
     'google-cloud-videointelligence>=1.7.0,<2.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_tasks.py b/tests/providers/google/cloud/hooks/test_tasks.py
index 8be6686..6504595 100644
--- a/tests/providers/google/cloud/hooks/test_tasks.py
+++ b/tests/providers/google/cloud/hooks/test_tasks.py
@@ -72,11 +72,10 @@ class TestCloudTasksHook(unittest.TestCase):
         self.assertIs(result, API_RESPONSE)
 
         get_conn.return_value.create_queue.assert_called_once_with(
-            parent=FULL_LOCATION_PATH,
-            queue=Queue(name=FULL_QUEUE_PATH),
+            request=dict(parent=FULL_LOCATION_PATH, queue=Queue(name=FULL_QUEUE_PATH)),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
@@ -94,11 +93,10 @@ class TestCloudTasksHook(unittest.TestCase):
         self.assertIs(result, API_RESPONSE)
 
         get_conn.return_value.update_queue.assert_called_once_with(
-            queue=Queue(name=FULL_QUEUE_PATH, state=3),
-            update_mask=None,
+            request=dict(queue=Queue(name=FULL_QUEUE_PATH, state=3), update_mask=None),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
@@ -111,30 +109,28 @@ class TestCloudTasksHook(unittest.TestCase):
         self.assertIs(result, API_RESPONSE)
 
         get_conn.return_value.get_queue.assert_called_once_with(
-            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None
+            request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=()
         )
 
     @mock.patch(
         "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
-        **{"return_value.list_queues.return_value": API_RESPONSE},  # type: ignore
+        **{"return_value.list_queues.return_value": [Queue(name=FULL_QUEUE_PATH)]},  # type: ignore
     )
     def test_list_queues(self, get_conn):
         result = self.hook.list_queues(location=LOCATION, project_id=PROJECT_ID)
 
-        self.assertEqual(result, list(API_RESPONSE))
+        self.assertEqual(result, [Queue(name=FULL_QUEUE_PATH)])
 
         get_conn.return_value.list_queues.assert_called_once_with(
-            parent=FULL_LOCATION_PATH,
-            filter_=None,
-            page_size=None,
+            request=dict(parent=FULL_LOCATION_PATH, filter=None, page_size=None),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
         "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
-        **{"return_value.delete_queue.return_value": API_RESPONSE},  # type: ignore
+        **{"return_value.delete_queue.return_value": None},  # type: ignore
     )
     def test_delete_queue(self, get_conn):
         result = self.hook.delete_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)
@@ -142,51 +138,51 @@ class TestCloudTasksHook(unittest.TestCase):
         self.assertEqual(result, None)
 
         get_conn.return_value.delete_queue.assert_called_once_with(
-            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None
+            request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=()
         )
 
     @mock.patch(
         "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
-        **{"return_value.purge_queue.return_value": API_RESPONSE},  # type: ignore
+        **{"return_value.purge_queue.return_value": Queue(name=FULL_QUEUE_PATH)},  # type: ignore
     )
     def test_purge_queue(self, get_conn):
         result = self.hook.purge_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)
 
-        self.assertEqual(result, API_RESPONSE)
+        self.assertEqual(result, Queue(name=FULL_QUEUE_PATH))
 
         get_conn.return_value.purge_queue.assert_called_once_with(
-            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None
+            request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=()
         )
 
     @mock.patch(
         "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
-        **{"return_value.pause_queue.return_value": API_RESPONSE},  # type: ignore
+        **{"return_value.pause_queue.return_value": Queue(name=FULL_QUEUE_PATH)},  # type: ignore
     )
     def test_pause_queue(self, get_conn):
         result = self.hook.pause_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)
 
-        self.assertEqual(result, API_RESPONSE)
+        self.assertEqual(result, Queue(name=FULL_QUEUE_PATH))
 
         get_conn.return_value.pause_queue.assert_called_once_with(
-            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None
+            request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=()
         )
 
     @mock.patch(
         "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
-        **{"return_value.resume_queue.return_value": API_RESPONSE},  # type: ignore
+        **{"return_value.resume_queue.return_value": Queue(name=FULL_QUEUE_PATH)},  # type: ignore
     )
     def test_resume_queue(self, get_conn):
         result = self.hook.resume_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)
 
-        self.assertEqual(result, API_RESPONSE)
+        self.assertEqual(result, Queue(name=FULL_QUEUE_PATH))
 
         get_conn.return_value.resume_queue.assert_called_once_with(
-            name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None
+            request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=()
         )
 
     @mock.patch(
         "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
-        **{"return_value.create_task.return_value": API_RESPONSE},  # type: ignore
+        **{"return_value.create_task.return_value": Task(name=FULL_TASK_PATH)},  # type: ignore
     )
     def test_create_task(self, get_conn):
         result = self.hook.create_task(
@@ -197,20 +193,18 @@ class TestCloudTasksHook(unittest.TestCase):
             task_name=TASK_NAME,
         )
 
-        self.assertEqual(result, API_RESPONSE)
+        self.assertEqual(result, Task(name=FULL_TASK_PATH))
 
         get_conn.return_value.create_task.assert_called_once_with(
-            parent=FULL_QUEUE_PATH,
-            task=Task(name=FULL_TASK_PATH),
-            response_view=None,
+            request=dict(parent=FULL_QUEUE_PATH, task=Task(name=FULL_TASK_PATH), response_view=None),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
         "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
-        **{"return_value.get_task.return_value": API_RESPONSE},  # type: ignore
+        **{"return_value.get_task.return_value": Task(name=FULL_TASK_PATH)},  # type: ignore
     )
     def test_get_task(self, get_conn):
         result = self.hook.get_task(
@@ -220,37 +214,34 @@ class TestCloudTasksHook(unittest.TestCase):
             project_id=PROJECT_ID,
         )
 
-        self.assertEqual(result, API_RESPONSE)
+        self.assertEqual(result, Task(name=FULL_TASK_PATH))
 
         get_conn.return_value.get_task.assert_called_once_with(
-            name=FULL_TASK_PATH,
-            response_view=None,
+            request=dict(name=FULL_TASK_PATH, response_view=None),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
         "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
-        **{"return_value.list_tasks.return_value": API_RESPONSE},  # type: ignore
+        **{"return_value.list_tasks.return_value": [Task(name=FULL_TASK_PATH)]},  # type: ignore
     )
     def test_list_tasks(self, get_conn):
         result = self.hook.list_tasks(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID)
 
-        self.assertEqual(result, list(API_RESPONSE))
+        self.assertEqual(result, [Task(name=FULL_TASK_PATH)])
 
         get_conn.return_value.list_tasks.assert_called_once_with(
-            parent=FULL_QUEUE_PATH,
-            response_view=None,
-            page_size=None,
+            request=dict(parent=FULL_QUEUE_PATH, response_view=None, page_size=None),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
         "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
-        **{"return_value.delete_task.return_value": API_RESPONSE},  # type: ignore
+        **{"return_value.delete_task.return_value": None},  # type: ignore
     )
     def test_delete_task(self, get_conn):
         result = self.hook.delete_task(
@@ -263,12 +254,12 @@ class TestCloudTasksHook(unittest.TestCase):
         self.assertEqual(result, None)
 
         get_conn.return_value.delete_task.assert_called_once_with(
-            name=FULL_TASK_PATH, retry=None, timeout=None, metadata=None
+            request=dict(name=FULL_TASK_PATH), retry=None, timeout=None, metadata=()
         )
 
     @mock.patch(
         "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn",
-        **{"return_value.run_task.return_value": API_RESPONSE},  # type: ignore
+        **{"return_value.run_task.return_value": Task(name=FULL_TASK_PATH)},  # type: ignore
     )
     def test_run_task(self, get_conn):
         result = self.hook.run_task(
@@ -278,12 +269,11 @@ class TestCloudTasksHook(unittest.TestCase):
             project_id=PROJECT_ID,
         )
 
-        self.assertEqual(result, API_RESPONSE)
+        self.assertEqual(result, Task(name=FULL_TASK_PATH))
 
         get_conn.return_value.run_task.assert_called_once_with(
-            name=FULL_TASK_PATH,
-            response_view=None,
+            request=dict(name=FULL_TASK_PATH, response_view=None),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
diff --git a/tests/providers/google/cloud/operators/test_tasks.py b/tests/providers/google/cloud/operators/test_tasks.py
index cac1441..ed76911 100644
--- a/tests/providers/google/cloud/operators/test_tasks.py
+++ b/tests/providers/google/cloud/operators/test_tasks.py
@@ -45,21 +45,26 @@ QUEUE_ID = "test-queue"
 FULL_QUEUE_PATH = "projects/test-project/locations/asia-east2/queues/test-queue"
 TASK_NAME = "test-task"
 FULL_TASK_PATH = "projects/test-project/locations/asia-east2/queues/test-queue/tasks/test-task"
+TEST_QUEUE = Queue(name=FULL_QUEUE_PATH)
+TEST_TASK = Task(app_engine_http_request={})
 
 
 class TestCloudTasksQueueCreate(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
     def test_create_queue(self, mock_hook):
-        mock_hook.return_value.create_queue.return_value = mock.MagicMock()
-        operator = CloudTasksQueueCreateOperator(location=LOCATION, task_queue=Queue(), task_id="id")
-        operator.execute(context=None)
+        mock_hook.return_value.create_queue.return_value = TEST_QUEUE
+        operator = CloudTasksQueueCreateOperator(location=LOCATION, task_queue=TEST_QUEUE, task_id="id")
+
+        result = operator.execute(context=None)
+
+        self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result)
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=None,
         )
         mock_hook.return_value.create_queue.assert_called_once_with(
             location=LOCATION,
-            task_queue=Queue(),
+            task_queue=TEST_QUEUE,
             project_id=None,
             queue_name=None,
             retry=None,
@@ -71,9 +76,12 @@ class TestCloudTasksQueueCreate(unittest.TestCase):
 class TestCloudTasksQueueUpdate(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
     def test_update_queue(self, mock_hook):
-        mock_hook.return_value.update_queue.return_value = mock.MagicMock()
+        mock_hook.return_value.update_queue.return_value = TEST_QUEUE
         operator = CloudTasksQueueUpdateOperator(task_queue=Queue(name=FULL_QUEUE_PATH), task_id="id")
-        operator.execute(context=None)
+
+        result = operator.execute(context=None)
+
+        self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result)
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=None,
@@ -93,9 +101,12 @@ class TestCloudTasksQueueUpdate(unittest.TestCase):
 class TestCloudTasksQueueGet(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
     def test_get_queue(self, mock_hook):
-        mock_hook.return_value.get_queue.return_value = mock.MagicMock()
+        mock_hook.return_value.get_queue.return_value = TEST_QUEUE
         operator = CloudTasksQueueGetOperator(location=LOCATION, queue_name=QUEUE_ID, task_id="id")
-        operator.execute(context=None)
+
+        result = operator.execute(context=None)
+
+        self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result)
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=None,
@@ -113,9 +124,12 @@ class TestCloudTasksQueueGet(unittest.TestCase):
 class TestCloudTasksQueuesList(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
     def test_list_queues(self, mock_hook):
-        mock_hook.return_value.list_queues.return_value = mock.MagicMock()
+        mock_hook.return_value.list_queues.return_value = [TEST_QUEUE]
         operator = CloudTasksQueuesListOperator(location=LOCATION, task_id="id")
-        operator.execute(context=None)
+
+        result = operator.execute(context=None)
+
+        self.assertEqual([{'name': FULL_QUEUE_PATH, 'state': 0}], result)
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=None,
@@ -134,9 +148,12 @@ class TestCloudTasksQueuesList(unittest.TestCase):
 class TestCloudTasksQueueDelete(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
     def test_delete_queue(self, mock_hook):
-        mock_hook.return_value.delete_queue.return_value = mock.MagicMock()
+        mock_hook.return_value.delete_queue.return_value = None
         operator = CloudTasksQueueDeleteOperator(location=LOCATION, queue_name=QUEUE_ID, task_id="id")
-        operator.execute(context=None)
+
+        result = operator.execute(context=None)
+
+        self.assertEqual(None, result)
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=None,
@@ -154,9 +171,12 @@ class TestCloudTasksQueueDelete(unittest.TestCase):
 class TestCloudTasksQueuePurge(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
     def test_delete_queue(self, mock_hook):
-        mock_hook.return_value.purge_queue.return_value = mock.MagicMock()
+        mock_hook.return_value.purge_queue.return_value = TEST_QUEUE
         operator = CloudTasksQueuePurgeOperator(location=LOCATION, queue_name=QUEUE_ID, task_id="id")
-        operator.execute(context=None)
+
+        result = operator.execute(context=None)
+
+        self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result)
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=None,
@@ -174,9 +194,12 @@ class TestCloudTasksQueuePurge(unittest.TestCase):
 class TestCloudTasksQueuePause(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
     def test_pause_queue(self, mock_hook):
-        mock_hook.return_value.pause_queue.return_value = mock.MagicMock()
+        mock_hook.return_value.pause_queue.return_value = TEST_QUEUE
         operator = CloudTasksQueuePauseOperator(location=LOCATION, queue_name=QUEUE_ID, task_id="id")
-        operator.execute(context=None)
+
+        result = operator.execute(context=None)
+
+        self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result)
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=None,
@@ -194,9 +217,12 @@ class TestCloudTasksQueuePause(unittest.TestCase):
 class TestCloudTasksQueueResume(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
     def test_resume_queue(self, mock_hook):
-        mock_hook.return_value.resume_queue.return_value = mock.MagicMock()
+        mock_hook.return_value.resume_queue.return_value = TEST_QUEUE
         operator = CloudTasksQueueResumeOperator(location=LOCATION, queue_name=QUEUE_ID, task_id="id")
-        operator.execute(context=None)
+
+        result = operator.execute(context=None)
+
+        self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result)
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=None,
@@ -214,11 +240,23 @@ class TestCloudTasksQueueResume(unittest.TestCase):
 class TestCloudTasksTaskCreate(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
     def test_create_task(self, mock_hook):
-        mock_hook.return_value.create_task.return_value = mock.MagicMock()
+        mock_hook.return_value.create_task.return_value = TEST_TASK
         operator = CloudTasksTaskCreateOperator(
             location=LOCATION, queue_name=QUEUE_ID, task=Task(), task_id="id"
         )
-        operator.execute(context=None)
+
+        result = operator.execute(context=None)
+
+        self.assertEqual(
+            {
+                'app_engine_http_request': {'body': '', 'headers': {}, 'http_method': 0, 'relative_uri': ''},
+                'dispatch_count': 0,
+                'name': '',
+                'response_count': 0,
+                'view': 0,
+            },
+            result,
+        )
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=None,
@@ -239,11 +277,23 @@ class TestCloudTasksTaskCreate(unittest.TestCase):
 class TestCloudTasksTaskGet(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
     def test_get_task(self, mock_hook):
-        mock_hook.return_value.get_task.return_value = mock.MagicMock()
+        mock_hook.return_value.get_task.return_value = TEST_TASK
         operator = CloudTasksTaskGetOperator(
             location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, task_id="id"
         )
-        operator.execute(context=None)
+
+        result = operator.execute(context=None)
+
+        self.assertEqual(
+            {
+                'app_engine_http_request': {'body': '', 'headers': {}, 'http_method': 0, 'relative_uri': ''},
+                'dispatch_count': 0,
+                'name': '',
+                'response_count': 0,
+                'view': 0,
+            },
+            result,
+        )
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=None,
@@ -263,9 +313,28 @@ class TestCloudTasksTaskGet(unittest.TestCase):
 class TestCloudTasksTasksList(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
     def test_list_tasks(self, mock_hook):
-        mock_hook.return_value.list_tasks.return_value = mock.MagicMock()
+        mock_hook.return_value.list_tasks.return_value = [TEST_TASK]
         operator = CloudTasksTasksListOperator(location=LOCATION, queue_name=QUEUE_ID, task_id="id")
-        operator.execute(context=None)
+
+        result = operator.execute(context=None)
+
+        self.assertEqual(
+            [
+                {
+                    'app_engine_http_request': {
+                        'body': '',
+                        'headers': {},
+                        'http_method': 0,
+                        'relative_uri': '',
+                    },
+                    'dispatch_count': 0,
+                    'name': '',
+                    'response_count': 0,
+                    'view': 0,
+                }
+            ],
+            result,
+        )
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=None,
@@ -285,11 +354,14 @@ class TestCloudTasksTasksList(unittest.TestCase):
 class TestCloudTasksTaskDelete(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
     def test_delete_task(self, mock_hook):
-        mock_hook.return_value.delete_task.return_value = mock.MagicMock()
+        mock_hook.return_value.delete_task.return_value = None
         operator = CloudTasksTaskDeleteOperator(
             location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, task_id="id"
         )
-        operator.execute(context=None)
+
+        result = operator.execute(context=None)
+
+        self.assertEqual(None, result)
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=None,
@@ -308,11 +380,23 @@ class TestCloudTasksTaskDelete(unittest.TestCase):
 class TestCloudTasksTaskRun(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.tasks.CloudTasksHook")
     def test_run_task(self, mock_hook):
-        mock_hook.return_value.run_task.return_value = mock.MagicMock()
+        mock_hook.return_value.run_task.return_value = TEST_TASK
         operator = CloudTasksTaskRunOperator(
             location=LOCATION, queue_name=QUEUE_ID, task_name=TASK_NAME, task_id="id"
         )
-        operator.execute(context=None)
+
+        result = operator.execute(context=None)
+
+        self.assertEqual(
+            {
+                'app_engine_http_request': {'body': '', 'headers': {}, 'http_method': 0, 'relative_uri': ''},
+                'dispatch_count': 0,
+                'name': '',
+                'response_count': 0,
+                'view': 0,
+            },
+            result,
+        )
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=None,

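The hook changes above all follow the same pattern required by google-cloud-tasks 2.x: the old path-helper classmethods and flat keyword arguments give way to plain resource-name strings and a single request mapping. A minimal sketch of the new call style, assuming Application Default Credentials and placeholder project, location and queue names:

    from google.cloud.tasks_v2 import CloudTasksClient
    from google.cloud.tasks_v2.types import Queue

    client = CloudTasksClient()
    parent = "projects/example-project/locations/us-central1"

    # 2.x style: the resource name is a plain string and all arguments go into `request`.
    queue = client.create_queue(
        request={"parent": parent, "queue": Queue(name=f"{parent}/queues/example-queue")},
        metadata=(),
    )

    for q in client.list_queues(request={"parent": parent}):
        print(q.name)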

[airflow] 08/28: Support google-cloud-datacatalog>=1.0.0 (#13097)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 51d70e3b02409f085e18e9c3043ab71d9055a850
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Tue Dec 22 12:58:45 2020 +0100

    Support google-cloud-datacatalog>=1.0.0 (#13097)
    
    (cherry picked from commit 9a1d3820d6f1373df790da8751f25e723f9ce037)
---
 airflow/providers/google/cloud/hooks/datacatalog.py | 6 +++---
 setup.py                                            | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/datacatalog.py b/airflow/providers/google/cloud/hooks/datacatalog.py
index 9c689c3..70b488d 100644
--- a/airflow/providers/google/cloud/hooks/datacatalog.py
+++ b/airflow/providers/google/cloud/hooks/datacatalog.py
@@ -537,7 +537,7 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = DataCatalogClient.field_path(project_id, location, tag_template, field)
+        name = DataCatalogClient.tag_template_field_path(project_id, location, tag_template, field)
 
         self.log.info('Deleting a tag template field: name=%s', name)
         client.delete_tag_template_field(
@@ -860,7 +860,7 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = DataCatalogClient.field_path(project_id, location, tag_template, field)
+        name = DataCatalogClient.tag_template_field_path(project_id, location, tag_template, field)
 
         self.log.info(
             'Renaming field: old_name=%s, new_tag_template_field_id=%s', name, new_tag_template_field_id
@@ -1246,7 +1246,7 @@ class CloudDataCatalogHook(GoogleBaseHook):
         """
         client = self.get_conn()
         if project_id and location and tag_template and tag_template_field_id:
-            tag_template_field_name = DataCatalogClient.field_path(
+            tag_template_field_name = DataCatalogClient.tag_template_field_path(
                 project_id, location, tag_template, tag_template_field_id
             )
 
diff --git a/setup.py b/setup.py
index 0586bf3..63dd6d7 100644
--- a/setup.py
+++ b/setup.py
@@ -287,7 +287,7 @@ google = [
     'google-cloud-bigquery-datatransfer>=0.4.0,<2.0.0',
     'google-cloud-bigtable>=1.0.0,<2.0.0',
     'google-cloud-container>=0.1.1,<2.0.0',
-    'google-cloud-datacatalog>=0.5.0, <0.8',  # TODO: we should migrate to 1.0 likely and add <2.0.0 then
+    'google-cloud-datacatalog>=1.0.0,<2.0.0',
     'google-cloud-dataproc>=1.0.1,<2.0.0',
     'google-cloud-dlp>=0.11.0,<2.0.0',
     'google-cloud-kms>=1.2.1,<2.0.0',
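
For reference, a minimal sketch of the renamed path helper used in the hook changes above, assuming DataCatalogClient is importable from google.cloud.datacatalog_v1beta1 (the project, location, template and field names are placeholders):

    from google.cloud.datacatalog_v1beta1 import DataCatalogClient

    # Before 1.0.0 this helper was DataCatalogClient.field_path; from 1.0.0 the same
    # positional arguments go to tag_template_field_path instead.
    name = DataCatalogClient.tag_template_field_path(
        "my-project", "us-central1", "my-template", "my-field"
    )
    # name is roughly:
    # projects/my-project/locations/us-central1/tagTemplates/my-template/fields/my-field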


[airflow] 01/28: Add Neo4j hook and operator (#13324)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 8543471e89bd61d7244f72d68c1b9cd2eb4090af
Author: Kanthi <su...@gmail.com>
AuthorDate: Thu Jan 14 11:27:50 2021 -0500

    Add Neo4j hook and operator (#13324)
    
    Close: #12873
    (cherry picked from commit 1d2977f6a4c67fa6174c79dcdc4e9ee3ce06f1b1)
---
 CONTRIBUTING.rst                                   |   9 +-
 INSTALL                                            |   9 +-
 airflow/providers/neo4j/README.md                  |  18 +++
 airflow/providers/neo4j/__init__.py                |  17 +++
 airflow/providers/neo4j/example_dags/__init__.py   |  17 +++
 .../providers/neo4j/example_dags/example_neo4j.py  |  48 ++++++++
 airflow/providers/neo4j/hooks/__init__.py          |  17 +++
 airflow/providers/neo4j/hooks/neo4j.py             | 117 +++++++++++++++++++
 airflow/providers/neo4j/operators/__init__.py      |  17 +++
 airflow/providers/neo4j/operators/neo4j.py         |  62 +++++++++++
 airflow/providers/neo4j/provider.yaml              |  44 ++++++++
 docs/apache-airflow-providers-neo4j/commits.rst    |  10 +-
 .../connections/neo4j.rst                          |  63 +++++++++++
 docs/apache-airflow-providers-neo4j/index.rst      | 124 +++++++++++++++++++++
 .../operators/neo4j.rst                            |  50 +++++++++
 docs/apache-airflow/concepts.rst                   |   4 +-
 docs/apache-airflow/extra-packages-ref.rst         |   2 +
 docs/apache-airflow/start/local.rst                |   2 +-
 docs/spelling_wordlist.txt                         |   4 +
 .../run_install_and_test_provider_packages.sh      |   4 +-
 setup.py                                           |   3 +
 tests/core/test_providers_manager.py               |   2 +
 tests/providers/neo4j/__init__.py                  |  17 +++
 tests/providers/neo4j/hooks/__init__.py            |  17 +++
 tests/providers/neo4j/hooks/test_neo4j.py          |  65 +++++++++++
 tests/providers/neo4j/operators/__init__.py        |  17 +++
 tests/providers/neo4j/operators/test_neo4j.py      |  61 ++++++++++
 27 files changed, 802 insertions(+), 18 deletions(-)

diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index ff8c80a..6d0e224 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -578,10 +578,11 @@ async, atlas, aws, azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes,
 databricks, datadog, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker,
 druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github_enterprise, google, google_auth,
 grpc, hashicorp, hdfs, hive, http, imap, jdbc, jenkins, jira, kerberos, kubernetes, ldap,
-microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, odbc, openfaas, opsgenie,
-oracle, pagerduty, papermill, password, pinot, plexus, postgres, presto, qds, qubole, rabbitmq,
-redis, s3, salesforce, samba, segment, sendgrid, sentry, sftp, singularity, slack, snowflake, spark,
-sqlite, ssh, statsd, tableau, telegram, vertica, virtualenv, webhdfs, winrm, yandex, zendesk
+microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas,
+opsgenie, oracle, pagerduty, papermill, password, pinot, plexus, postgres, presto, qds, qubole,
+rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry, sftp, singularity, slack,
+snowflake, spark, sqlite, ssh, statsd, tableau, telegram, vertica, virtualenv, webhdfs, winrm,
+yandex, zendesk
 
   .. END EXTRAS HERE
 
diff --git a/INSTALL b/INSTALL
index 4ee3f2b..e1ef456 100644
--- a/INSTALL
+++ b/INSTALL
@@ -103,10 +103,11 @@ async, atlas, aws, azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes,
 databricks, datadog, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker,
 druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github_enterprise, google, google_auth,
 grpc, hashicorp, hdfs, hive, http, imap, jdbc, jenkins, jira, kerberos, kubernetes, ldap,
-microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, odbc, openfaas, opsgenie,
-oracle, pagerduty, papermill, password, pinot, plexus, postgres, presto, qds, qubole, rabbitmq,
-redis, s3, salesforce, samba, segment, sendgrid, sentry, sftp, singularity, slack, snowflake, spark,
-sqlite, ssh, statsd, tableau, telegram, vertica, virtualenv, webhdfs, winrm, yandex, zendesk
+microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas,
+opsgenie, oracle, pagerduty, papermill, password, pinot, plexus, postgres, presto, qds, qubole,
+rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry, sftp, singularity, slack,
+snowflake, spark, sqlite, ssh, statsd, tableau, telegram, vertica, virtualenv, webhdfs, winrm,
+yandex, zendesk
 
 # END EXTRAS HERE
 
diff --git a/airflow/providers/neo4j/README.md b/airflow/providers/neo4j/README.md
new file mode 100644
index 0000000..ef14aff
--- /dev/null
+++ b/airflow/providers/neo4j/README.md
@@ -0,0 +1,18 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied.  See the License for the
+ specific language governing permissions and limitations
+ under the License.
+ -->
diff --git a/airflow/providers/neo4j/__init__.py b/airflow/providers/neo4j/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/airflow/providers/neo4j/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/neo4j/example_dags/__init__.py b/airflow/providers/neo4j/example_dags/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/airflow/providers/neo4j/example_dags/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/neo4j/example_dags/example_neo4j.py b/airflow/providers/neo4j/example_dags/example_neo4j.py
new file mode 100644
index 0000000..7d6f2fc
--- /dev/null
+++ b/airflow/providers/neo4j/example_dags/example_neo4j.py
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example use of Neo4j related operators.
+"""
+
+from airflow import DAG
+from airflow.providers.neo4j.operators.neo4j import Neo4jOperator
+from airflow.utils.dates import days_ago
+
+default_args = {
+    'owner': 'airflow',
+}
+
+dag = DAG(
+    'example_neo4j',
+    default_args=default_args,
+    start_date=days_ago(2),
+    tags=['example'],
+)
+
+# [START run_query_neo4j_operator]
+
+neo4j_task = Neo4jOperator(
+    task_id='run_neo4j_query',
+    neo4j_conn_id='neo4j_conn_id',
+    sql='MATCH (tom {name: "Tom Hanks"}) RETURN tom',
+    dag=dag,
+)
+
+# [END run_query_neo4j_operator]
+
+neo4j_task
diff --git a/airflow/providers/neo4j/hooks/__init__.py b/airflow/providers/neo4j/hooks/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/airflow/providers/neo4j/hooks/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/neo4j/hooks/neo4j.py b/airflow/providers/neo4j/hooks/neo4j.py
new file mode 100644
index 0000000..d473b01
--- /dev/null
+++ b/airflow/providers/neo4j/hooks/neo4j.py
@@ -0,0 +1,117 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module allows to connect to a Neo4j database."""
+
+from neo4j import GraphDatabase, Neo4jDriver, Result
+
+from airflow.hooks.base import BaseHook
+from airflow.models import Connection
+
+
+class Neo4jHook(BaseHook):
+    """
+    Interact with Neo4j.
+
+    Establishes a connection to Neo4j and runs the query.
+    """
+
+    conn_name_attr = 'neo4j_conn_id'
+    default_conn_name = 'neo4j_default'
+    conn_type = 'neo4j'
+    hook_name = 'Neo4j'
+
+    def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.neo4j_conn_id = conn_id
+        self.connection = kwargs.pop("connection", None)
+        self.client = None
+        self.extras = None
+        self.uri = None
+
+    def get_conn(self) -> Neo4jDriver:
+        """
+        Function that initiates a new Neo4j connection
+        with username, password and database schema.
+        """
+        self.connection = self.get_connection(self.neo4j_conn_id)
+        self.extras = self.connection.extra_dejson.copy()
+
+        self.uri = self.get_uri(self.connection)
+        self.log.info('URI: %s', self.uri)
+
+        if self.client is not None:
+            return self.client
+
+        is_encrypted = self.connection.extra_dejson.get('encrypted', False)
+
+        self.client = GraphDatabase.driver(
+            self.uri, auth=(self.connection.login, self.connection.password), encrypted=is_encrypted
+        )
+
+        return self.client
+
+    def get_uri(self, conn: Connection) -> str:
+        """
+        Build the URI based on extras:
+        - Default - uses bolt scheme (bolt://)
+        - neo4j_scheme - neo4j://
+        - certs_self_signed - neo4j+ssc://
+        - certs_trusted_ca - neo4j+s://
+        :param conn: connection object.
+        :return: uri
+        """
+        use_neo4j_scheme = conn.extra_dejson.get('neo4j_scheme', False)
+        scheme = 'neo4j' if use_neo4j_scheme else 'bolt'
+
+        # Self signed certificates
+        ssc = conn.extra_dejson.get('certs_self_signed', False)
+
+        # Only certificates signed by CA.
+        trusted_ca = conn.extra_dejson.get('certs_trusted_ca', False)
+        encryption_scheme = ''
+
+        if ssc:
+            encryption_scheme = '+ssc'
+        elif trusted_ca:
+            encryption_scheme = '+s'
+
+        return '{scheme}{encryption_scheme}://{host}:{port}'.format(
+            scheme=scheme,
+            encryption_scheme=encryption_scheme,
+            host=conn.host,
+            port='7687' if conn.port is None else f'{conn.port}',
+        )
+
+    def run(self, query) -> Result:
+        """
+        Function to create a Neo4j session
+        and execute the query in the session.
+
+
+        :param query: Neo4j query
+        :return: Result
+        """
+        driver = self.get_conn()
+        if not self.connection.schema:
+            with driver.session() as session:
+                result = session.run(query)
+        else:
+            with driver.session(database=self.connection.schema) as session:
+                result = session.run(query)
+        return result
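
A minimal sketch of how the scheme selection in get_uri() above plays out; the connection values are illustrative only and are not taken from this commit:

    from airflow.models import Connection
    from airflow.providers.neo4j.hooks.neo4j import Neo4jHook

    # Illustrative connection: neo4j:// scheme secured with a trusted CA (-> neo4j+s://).
    conn = Connection(
        conn_type="neo4j",
        host="example-host",
        login="neo4j",
        password="secret",
        extra='{"neo4j_scheme": true, "certs_trusted_ca": true}',
    )

    # get_uri() only reads the extras, host and port; a missing port falls back to 7687.
    assert Neo4jHook().get_uri(conn) == "neo4j+s://example-host:7687"
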
diff --git a/airflow/providers/neo4j/operators/__init__.py b/airflow/providers/neo4j/operators/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/airflow/providers/neo4j/operators/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/neo4j/operators/neo4j.py b/airflow/providers/neo4j/operators/neo4j.py
new file mode 100644
index 0000000..20df9cb
--- /dev/null
+++ b/airflow/providers/neo4j/operators/neo4j.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Dict, Iterable, Mapping, Optional, Union
+
+from airflow.models import BaseOperator
+from airflow.providers.neo4j.hooks.neo4j import Neo4jHook
+from airflow.utils.decorators import apply_defaults
+
+
+class Neo4jOperator(BaseOperator):
+    """
+    Executes SQL code in a specific Neo4j database
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:Neo4jOperator`
+
+    :param sql: the sql code to be executed. Can receive a str representing a
+        sql statement, a list of str (sql statements)
+    :type sql: str or list[str]
+    :param neo4j_conn_id: reference to a specific Neo4j database
+    :type neo4j_conn_id: str
+    """
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        sql: str,
+        neo4j_conn_id: str = 'neo4j_default',
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.neo4j_conn_id = neo4j_conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.hook = None
+
+    def get_hook(self):
+        """Function to retrieve the Neo4j Hook."""
+        return Neo4jHook(conn_id=self.neo4j_conn_id)
+
+    def execute(self, context: Dict) -> None:
+        self.log.info('Executing: %s', self.sql)
+        self.hook = self.get_hook()
+        self.hook.run(self.sql)
diff --git a/airflow/providers/neo4j/provider.yaml b/airflow/providers/neo4j/provider.yaml
new file mode 100644
index 0000000..9081694
--- /dev/null
+++ b/airflow/providers/neo4j/provider.yaml
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+---
+package-name: apache-airflow-providers-neo4j
+name: Neo4j
+description: |
+    `Neo4j <https://neo4j.com/>`__
+
+versions:
+  - 1.0.0
+integrations:
+  - integration-name: Neo4j
+    external-doc-url: https://neo4j.com/
+    how-to-guide:
+      - /docs/apache-airflow-providers-neo4j/operators/neo4j.rst
+    tags: [software]
+
+operators:
+  - integration-name: Neo4j
+    python-modules:
+      - airflow.providers.neo4j.operators.neo4j
+
+hooks:
+  - integration-name: Neo4j
+    python-modules:
+      - airflow.providers.neo4j.hooks.neo4j
+
+hook-class-names:
+  - airflow.providers.neo4j.hooks.neo4j.Neo4jHook
diff --git a/docs/apache-airflow-providers-neo4j/commits.rst b/docs/apache-airflow-providers-neo4j/commits.rst
index 76dc03e..bfe9721 100644
--- a/docs/apache-airflow-providers-neo4j/commits.rst
+++ b/docs/apache-airflow-providers-neo4j/commits.rst
@@ -31,11 +31,11 @@ For high-level changelog, see :doc:`package information including changelog <ind
 1.0.0
 .....
 
-Latest change: 2021-01-31
+Latest change: 2021-02-01
 
-================================================================================================  ===========  ========================================
+================================================================================================  ===========  ================================================
 Commit                                                                                            Committed    Subject
-================================================================================================  ===========  ========================================
-`4a9ce091b <https://github.com/apache/airflow/commit/4a9ce091b11b901e4f73d36457de29d5a2154159>`_  2021-01-31   ``Implement provider versioning tools``
+================================================================================================  ===========  ================================================
+`ac2f72c98 <https://github.com/apache/airflow/commit/ac2f72c98dc0821b33721054588adbf2bb53bb0b>`_  2021-02-01   ``Implement provider versioning tools (#13767)``
 `1d2977f6a <https://github.com/apache/airflow/commit/1d2977f6a4c67fa6174c79dcdc4e9ee3ce06f1b1>`_  2021-01-14   ``Add Neo4j hook and operator (#13324)``
-================================================================================================  ===========  ========================================
+================================================================================================  ===========  ================================================
diff --git a/docs/apache-airflow-providers-neo4j/connections/neo4j.rst b/docs/apache-airflow-providers-neo4j/connections/neo4j.rst
new file mode 100644
index 0000000..33fd6b5
--- /dev/null
+++ b/docs/apache-airflow-providers-neo4j/connections/neo4j.rst
@@ -0,0 +1,63 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+
+
+Neo4j Connection
+================
+The Neo4j connection type provides a connection to a Neo4j database.
+
+Configuring the Connection
+--------------------------
+Host (required)
+    The host to connect to.
+
+Schema (optional)
+    Specify the schema name to be used in the database.
+
+Login (required)
+    Specify the user name to connect.
+
+Password (required)
+    Specify the password to connect.
+
+Extra (optional)
+    Specify the extra parameters (as json dictionary) that can be used in Neo4j
+    connection.
+
+    The following extras are supported:
+
+        - Default - uses bolt scheme (bolt://)
+        - neo4j_scheme - neo4j://
+        - certs_self_signed - neo4j+ssc://
+        - certs_trusted_ca - neo4j+s://
+
+      * ``encrypted``: Sets encrypted=True/False for GraphDatabase.driver, Set to ``True`` for Neo4j Aura.
+      * ``neo4j_scheme``: Specifies the scheme to ``neo4j://``, default is ``bolt://``
+      * ``certs_self_signed``: Sets the URI scheme to support self-signed certificates(``neo4j+ssc://``)
+      * ``certs_trusted_ca``: Sets the URI scheme to support only trusted CA(``neo4j+s://``)
+
+      Example "extras" field:
+
+      .. code-block:: json
+
+         {
+            "encrypted": true,
+            "neo4j_scheme": true,
+            "certs_self_signed": true,
+            "certs_trusted_ca": false
+         }
diff --git a/docs/apache-airflow-providers-neo4j/index.rst b/docs/apache-airflow-providers-neo4j/index.rst
new file mode 100644
index 0000000..dd995fb
--- /dev/null
+++ b/docs/apache-airflow-providers-neo4j/index.rst
@@ -0,0 +1,124 @@
+
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+``apache-airflow-providers-neo4j``
+==================================
+
+Content
+-------
+
+.. toctree::
+    :maxdepth: 1
+    :caption: Guides
+
+    Connection types <connections/neo4j>
+    Operators <operators/neo4j>
+
+.. toctree::
+    :maxdepth: 1
+    :caption: References
+
+    Python API <_api/airflow/providers/neo4j/index>
+
+.. toctree::
+    :maxdepth: 1
+    :caption: Resources
+
+    Example DAGs <https://github.com/apache/airflow/tree/master/airflow/providers/neo4j/example_dags>
+
+.. toctree::
+    :maxdepth: 1
+    :caption: Resources
+
+    PyPI Repository <https://pypi.org/project/apache-airflow-providers-neo4j/>
+
+.. THE REMAINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN AT RELEASE TIME!
+
+
+.. toctree::
+    :maxdepth: 1
+    :caption: Commits
+
+    Detailed list of commits <commits>
+
+
+Package apache-airflow-providers-neo4j
+------------------------------------------------------
+
+`Neo4j <https://neo4j.com/>`__
+
+
+Release: 1.0.0
+
+Provider package
+----------------
+
+This is a provider package for the ``neo4j`` provider. All classes for this provider package
+are in ``airflow.providers.neo4j`` python package.
+
+Installation
+------------
+
+.. note::
+
+    In November 2020, a new version of pip (20.3) was released with a new 2020 resolver. This resolver
+    does not yet work with Apache Airflow and might lead to errors during installation, depending on your
+    choice of extras. In order to install Airflow you need to either downgrade pip to version 20.2.4
+    ``pip install --upgrade pip==20.2.4`` or, if you use pip 20.3, you need to add the option
+    ``--use-deprecated legacy-resolver`` to your pip install command.
+
+
+You can install this package on top of an existing airflow 2.* installation via
+``pip install apache-airflow-providers-neo4j``
+
+PIP requirements
+----------------
+
+=============  ==================
+PIP package    Version required
+=============  ==================
+``neo4j``      ``>=4.2.1``
+=============  ==================
+
+
+
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+Changelog
+---------
+
+
+1.0.0
+.....
+
+Initial version of the provider.
diff --git a/docs/apache-airflow-providers-neo4j/operators/neo4j.rst b/docs/apache-airflow-providers-neo4j/operators/neo4j.rst
new file mode 100644
index 0000000..411aa0c
--- /dev/null
+++ b/docs/apache-airflow-providers-neo4j/operators/neo4j.rst
@@ -0,0 +1,50 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+
+
+.. _howto/operator:Neo4jOperator:
+
+Neo4jOperator
+=============
+
+Use the :class:`~airflow.providers.neo4j.operators.neo4j.Neo4jOperator` to execute
+SQL commands in a `Neo4j <https://neo4j.com/>`__ database.
+
+
+Using the Operator
+^^^^^^^^^^^^^^^^^^
+
+Use the ``neo4j_conn_id`` argument to connect to your Neo4j instance where
+the connection metadata is structured as follows:
+
+.. list-table:: Neo4j Airflow Connection Metadata
+   :widths: 25 25
+   :header-rows: 1
+
+   * - Parameter
+     - Input
+   * - Host: string
+     - Neo4j hostname
+   * - Schema: string
+     - Database name
+   * - Login: string
+     - Neo4j user
+   * - Password: string
+     - Neo4j user password
+   * - Port: int
+     - Neo4j port
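
A minimal usage sketch for the connection metadata above, mirroring the example DAG added in this commit (the DAG id and connection id are illustrative):

    from airflow import DAG
    from airflow.providers.neo4j.operators.neo4j import Neo4jOperator
    from airflow.utils.dates import days_ago

    with DAG("example_neo4j_usage", start_date=days_ago(2), tags=["example"]) as dag:
        # Runs a Cypher query against the database referenced by the connection.
        run_query = Neo4jOperator(
            task_id="run_neo4j_query",
            neo4j_conn_id="neo4j_conn_id",
            sql='MATCH (tom {name: "Tom Hanks"}) RETURN tom',
        )
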
diff --git a/docs/apache-airflow/concepts.rst b/docs/apache-airflow/concepts.rst
index 346f6c0..0522c0f 100644
--- a/docs/apache-airflow/concepts.rst
+++ b/docs/apache-airflow/concepts.rst
@@ -1321,8 +1321,8 @@ In case of DAG and task policies users may raise :class:`~airflow.exceptions.Air
 to prevent a DAG from being imported or prevent a task from being executed if the task is not compliant with
 users' check.
 
-Please note, cluster policy will have precedence over task attributes defined in DAG meaning
-if ``task.sla`` is defined in dag and also mutated via cluster policy then later will have precedence.
+Please note, cluster policy will have precedence over task attributes defined in DAG meaning that
+if ``task.sla`` is defined in dag and also mutated via cluster policy then the latter will have precedence.
 
 In next sections we show examples of each type of cluster policy.
 
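A minimal sketch of the precedence described above: a task_policy defined in airflow_local_settings.py that overrides whatever task.sla a DAG author set (the four-hour value is illustrative):

    # airflow_local_settings.py, somewhere on the scheduler's PYTHONPATH
    from datetime import timedelta

    from airflow.models.baseoperator import BaseOperator


    def task_policy(task: BaseOperator) -> None:
        # Applied to every task when the DAG is parsed; this value takes
        # precedence over any sla set in the DAG file itself.
        task.sla = timedelta(hours=4)
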
diff --git a/docs/apache-airflow/extra-packages-ref.rst b/docs/apache-airflow/extra-packages-ref.rst
index c565a93..b2549ae 100644
--- a/docs/apache-airflow/extra-packages-ref.rst
+++ b/docs/apache-airflow/extra-packages-ref.rst
@@ -213,6 +213,8 @@ Those are extras that add dependencies needed for integration with other softwar
 +---------------------+-----------------------------------------------------+-------------------------------------------+
 | mysql               | ``pip install 'apache-airflow[mysql]'``             | MySQL operators and hook                  |
 +---------------------+-----------------------------------------------------+-------------------------------------------+
+| neo4j               | ``pip install 'apache-airflow[neo4j]'``             | Neo4j operators and hook                  |
++---------------------+-----------------------------------------------------+-------------------------------------------+
 | odbc                | ``pip install 'apache-airflow[odbc]'``              | ODBC data sources including MS SQL Server |
 +---------------------+-----------------------------------------------------+-------------------------------------------+
 | openfaas            | ``pip install 'apache-airflow[openfaas]'``          | OpenFaaS hooks                            |
diff --git a/docs/apache-airflow/start/local.rst b/docs/apache-airflow/start/local.rst
index 7b0bb33..64aaa7a 100644
--- a/docs/apache-airflow/start/local.rst
+++ b/docs/apache-airflow/start/local.rst
@@ -86,7 +86,7 @@ the ``Admin->Configuration`` menu. The PID file for the webserver will be stored
 in ``$AIRFLOW_HOME/airflow-webserver.pid`` or in ``/run/airflow/webserver.pid``
 if started by systemd.
 
-Out of the box, Airflow uses a sqlite database, which you should outgrow
+Out of the box, Airflow uses a SQLite database, which you should outgrow
 fairly quickly since no parallelization is possible using this database
 backend. It works in conjunction with the
 :class:`~airflow.executors.sequential_executor.SequentialExecutor` which will
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index c541f06..db4342a 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -252,6 +252,7 @@ NaN
 Naik
 Namenode
 Namespace
+Neo4j
 Nextdoor
 Nones
 NotFound
@@ -992,6 +993,8 @@ navbar
 nd
 ndjson
 neighbours
+neo
+neo4j
 neq
 networkUri
 nginx
@@ -1219,6 +1222,7 @@ sqlsensor
 sqoop
 src
 srv
+ssc
 ssd
 sshHook
 sshtunnel
diff --git a/scripts/in_container/run_install_and_test_provider_packages.sh b/scripts/in_container/run_install_and_test_provider_packages.sh
index b3ee63b..969fa29 100755
--- a/scripts/in_container/run_install_and_test_provider_packages.sh
+++ b/scripts/in_container/run_install_and_test_provider_packages.sh
@@ -95,7 +95,7 @@ function discover_all_provider_packages() {
     # Columns is to force it wider, so it doesn't wrap at 80 characters
     COLUMNS=180 airflow providers list
 
-    local expected_number_of_providers=61
+    local expected_number_of_providers=62
     local actual_number_of_providers
     actual_providers=$(airflow providers list --output yaml | grep package_name)
     actual_number_of_providers=$(wc -l <<<"$actual_providers")
@@ -118,7 +118,7 @@ function discover_all_hooks() {
     group_start "Listing available hooks via 'airflow providers hooks'"
     COLUMNS=180 airflow providers hooks
 
-    local expected_number_of_hooks=59
+    local expected_number_of_hooks=60
     local actual_number_of_hooks
     actual_number_of_hooks=$(airflow providers hooks --output table | grep -c "| apache" | xargs)
     if [[ ${actual_number_of_hooks} != "${expected_number_of_hooks}" ]]; then
diff --git a/setup.py b/setup.py
index e967781..210b12f 100644
--- a/setup.py
+++ b/setup.py
@@ -360,6 +360,7 @@ mysql = [
     'mysql-connector-python>=8.0.11, <=8.0.22',
     'mysqlclient>=1.3.6,<1.4',
 ]
+neo4j = ['neo4j>=4.2.1']
 odbc = [
     'pyodbc',
 ]
@@ -557,6 +558,7 @@ PROVIDERS_REQUIREMENTS: Dict[str, List[str]] = {
     'microsoft.winrm': winrm,
     'mongo': mongo,
     'mysql': mysql,
+    'neo4j': neo4j,
     'odbc': odbc,
     'openfaas': [],
     'opsgenie': [],
@@ -711,6 +713,7 @@ ALL_DB_PROVIDERS = [
     'microsoft.mssql',
     'mongo',
     'mysql',
+    'neo4j',
     'postgres',
     'presto',
     'vertica',
diff --git a/tests/core/test_providers_manager.py b/tests/core/test_providers_manager.py
index 4c03984..7d80c58 100644
--- a/tests/core/test_providers_manager.py
+++ b/tests/core/test_providers_manager.py
@@ -57,6 +57,7 @@ ALL_PROVIDERS = [
     'apache-airflow-providers-microsoft-winrm',
     'apache-airflow-providers-mongo',
     'apache-airflow-providers-mysql',
+    'apache-airflow-providers-neo4j',
     'apache-airflow-providers-odbc',
     'apache-airflow-providers-openfaas',
     'apache-airflow-providers-opsgenie',
@@ -122,6 +123,7 @@ CONNECTIONS_LIST = [
     'mongo',
     'mssql',
     'mysql',
+    'neo4j',
     'odbc',
     'oracle',
     'pig_cli',
diff --git a/tests/providers/neo4j/__init__.py b/tests/providers/neo4j/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/tests/providers/neo4j/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/neo4j/hooks/__init__.py b/tests/providers/neo4j/hooks/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/tests/providers/neo4j/hooks/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/neo4j/hooks/test_neo4j.py b/tests/providers/neo4j/hooks/test_neo4j.py
new file mode 100644
index 0000000..7f64fc4
--- /dev/null
+++ b/tests/providers/neo4j/hooks/test_neo4j.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import json
+import unittest
+from unittest import mock
+
+from airflow.models import Connection
+from airflow.providers.neo4j.hooks.neo4j import Neo4jHook
+
+
+class TestNeo4jHookConn(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.neo4j_hook = Neo4jHook()
+        self.connection = Connection(
+            conn_type='neo4j', login='login', password='password', host='host', schema='schema'
+        )
+
+    def test_get_uri_neo4j_scheme(self):
+
+        self.neo4j_hook.get_connection = mock.Mock()
+        self.neo4j_hook.get_connection.return_value = self.connection
+        uri = self.neo4j_hook.get_uri(self.connection)
+
+        self.assertEqual(uri, "bolt://host:7687")
+
+    def test_get_uri_bolt_scheme(self):
+
+        self.connection.extra = json.dumps({"bolt_scheme": True})
+        self.neo4j_hook.get_connection = mock.Mock()
+        self.neo4j_hook.get_connection.return_value = self.connection
+        uri = self.neo4j_hook.get_uri(self.connection)
+
+        self.assertEqual(uri, "bolt://host:7687")
+
+    def test_get_uri_bolt_ssc_scheme(self):
+        self.connection.extra = json.dumps({"certs_self_signed": True, "bolt_scheme": True})
+        self.neo4j_hook.get_connection = mock.Mock()
+        self.neo4j_hook.get_connection.return_value = self.connection
+        uri = self.neo4j_hook.get_uri(self.connection)
+
+        self.assertEqual(uri, "bolt+ssc://host:7687")
+
+    def test_get_uri_bolt_trusted_ca_scheme(self):
+        self.connection.extra = json.dumps({"certs_trusted_ca": True, "bolt_scheme": True})
+        self.neo4j_hook.get_connection = mock.Mock()
+        self.neo4j_hook.get_connection.return_value = self.connection
+        uri = self.neo4j_hook.get_uri(self.connection)
+
+        self.assertEqual(uri, "bolt+s://host:7687")
diff --git a/tests/providers/neo4j/operators/__init__.py b/tests/providers/neo4j/operators/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/tests/providers/neo4j/operators/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/neo4j/operators/test_neo4j.py b/tests/providers/neo4j/operators/test_neo4j.py
new file mode 100644
index 0000000..39c8d69
--- /dev/null
+++ b/tests/providers/neo4j/operators/test_neo4j.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import unittest
+from unittest import mock
+
+from airflow.models.dag import DAG
+from airflow.providers.neo4j.operators.neo4j import Neo4jOperator
+from airflow.utils import timezone
+
+DEFAULT_DATE = timezone.datetime(2015, 1, 1)
+DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
+DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
+TEST_DAG_ID = 'unit_test_dag'
+
+
+class TestNeo4jOperator(unittest.TestCase):
+    def setUp(self):
+        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
+        dag = DAG(TEST_DAG_ID, default_args=args)
+        self.dag = dag
+
+    @mock.patch('airflow.providers.neo4j.operators.neo4j.Neo4jOperator.get_hook')
+    def test_neo4j_operator_test(self, mock_hook):
+
+        sql = """
+            MATCH (tom {name: "Tom Hanks"}) RETURN tom
+            """
+        op = Neo4jOperator(task_id='basic_neo4j', sql=sql, dag=self.dag)
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)


[airflow] 22/28: Remove reinstalling azure-storage steps from CI / Breeze (#14102)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit cac851c59dbc787b9f96c44767cbb7510d740fc6
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Mon Feb 8 21:56:46 2021 +0000

    Remove reinstalling azure-storage steps from CI / Breeze (#14102)
    
    Since https://github.com/apache/airflow/pull/12188 was merged, we no
    longer need these steps.
    
    This step also caused the docker build step for 2.0.1rc2 to fail
    
    Co-authored-by: Jarek Potiuk <ja...@potiuk.com>
    (cherry picked from commit 3ffd21745d25e6239254fe3f5688b34f5f6f77e8)
---
 scripts/docker/install_airflow.sh                  |  7 +------
 scripts/in_container/_in_container_utils.sh        | 22 ++++------------------
 scripts/in_container/entrypoint_ci.sh              |  4 ++--
 scripts/in_container/run_ci_tests.sh               |  2 --
 .../run_install_and_test_provider_packages.sh      |  5 ++---
 .../run_prepare_provider_documentation.sh          |  1 -
 setup.py                                           | 18 +++++-------------
 7 files changed, 14 insertions(+), 45 deletions(-)

diff --git a/scripts/docker/install_airflow.sh b/scripts/docker/install_airflow.sh
index bfe88be..5f1e9d9 100755
--- a/scripts/docker/install_airflow.sh
+++ b/scripts/docker/install_airflow.sh
@@ -66,9 +66,7 @@ function install_airflow() {
             pip install ${AIRFLOW_INSTALL_EDITABLE_FLAG} \
                 "${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}"
         fi
-        # Work around to install azure-storage-blob
-        pip uninstall azure-storage azure-storage-blob azure-storage-file --yes
-        pip install azure-storage-blob azure-storage-file
+
         # make sure correct PIP version is used
         pip install ${AIRFLOW_INSTALL_USER_FLAG} --upgrade "pip==${AIRFLOW_PIP_VERSION}"
         pip check || ${CONTINUE_ON_PIP_CHECK_FAILURE}
@@ -85,9 +83,6 @@ function install_airflow() {
         pip install ${AIRFLOW_INSTALL_USER_FLAG} --upgrade --upgrade-strategy only-if-needed \
             ${AIRFLOW_INSTALL_EDITABLE_FLAG} \
             "${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}" \
-        # Work around to install azure-storage-blob
-        pip uninstall azure-storage azure-storage-blob azure-storage-file --yes
-        pip install azure-storage-blob azure-storage-file
         # make sure correct PIP version is used
         pip install ${AIRFLOW_INSTALL_USER_FLAG} --upgrade "pip==${AIRFLOW_PIP_VERSION}"
         pip check || ${CONTINUE_ON_PIP_CHECK_FAILURE}
diff --git a/scripts/in_container/_in_container_utils.sh b/scripts/in_container/_in_container_utils.sh
index 0a9db95..1e9a192 100644
--- a/scripts/in_container/_in_container_utils.sh
+++ b/scripts/in_container/_in_container_utils.sh
@@ -275,7 +275,7 @@ function install_airflow_from_wheel() {
         >&2 echo
         exit 4
     fi
-    pip install "${airflow_package}${1}"
+    pip install "${airflow_package}${extras}"
 }
 
 function install_airflow_from_sdist() {
@@ -292,20 +292,7 @@ function install_airflow_from_sdist() {
         >&2 echo
         exit 4
     fi
-    pip install "${airflow_package}${1}"
-}
-
-function reinstall_azure_storage_blob() {
-    group_start "Reinstalls azure-storage-blob (temporary workaround)"
-    # Reinstall azure-storage-blob here until https://github.com/apache/airflow/pull/12188 is solved
-    # Azure-storage-blob need to be reinstalled to overwrite azure-storage-blob installed by old version
-    # of the `azure-storage` library
-    echo
-    echo "Reinstalling azure-storage-blob"
-    echo
-    pip uninstall azure-storage azure-storage-blob azure-storage-file --yes
-    pip install azure-storage-blob azure-storage-file --no-deps --force-reinstall
-    group_end
+    pip install "${airflow_package}${extras}"
 }
 
 function install_remaining_dependencies() {
@@ -338,13 +325,12 @@ function uninstall_airflow_and_providers() {
 
 function install_released_airflow_version() {
     local version="${1}"
-    local extras="${2}"
     echo
-    echo "Installing released ${version} version of airflow with extras ${extras}"
+    echo "Installing released ${version} version of airflow without extras"
     echo
 
     rm -rf "${AIRFLOW_SOURCES}"/*.egg-info
-    pip install --upgrade "apache-airflow${extras}==${version}"
+    pip install --upgrade "apache-airflow==${version}"
 }
 
 function install_local_airflow_with_eager_upgrade() {
diff --git a/scripts/in_container/entrypoint_ci.sh b/scripts/in_container/entrypoint_ci.sh
index 3761a3b..b99cdc1 100755
--- a/scripts/in_container/entrypoint_ci.sh
+++ b/scripts/in_container/entrypoint_ci.sh
@@ -98,9 +98,9 @@ elif [[ ${INSTALL_AIRFLOW_VERSION} == "sdist"  ]]; then
     uninstall_providers
 else
     echo
-    echo "Install airflow from PyPI including [${AIRFLOW_EXTRAS}] extras"
+    echo "Install airflow from PyPI without extras"
     echo
-    install_released_airflow_version "${INSTALL_AIRFLOW_VERSION}" "[${AIRFLOW_EXTRAS}]"
+    install_released_airflow_version "${INSTALL_AIRFLOW_VERSION}"
 fi
 if [[ ${INSTALL_PACKAGES_FROM_DIST=} == "true" ]]; then
     echo
diff --git a/scripts/in_container/run_ci_tests.sh b/scripts/in_container/run_ci_tests.sh
index 43be453..ca3c41d 100755
--- a/scripts/in_container/run_ci_tests.sh
+++ b/scripts/in_container/run_ci_tests.sh
@@ -18,8 +18,6 @@
 # shellcheck source=scripts/in_container/_in_container_script_init.sh
 . "$( dirname "${BASH_SOURCE[0]}" )/_in_container_script_init.sh"
 
-reinstall_azure_storage_blob
-
 echo
 echo "Starting the tests with those pytest arguments:" "${@}"
 echo
diff --git a/scripts/in_container/run_install_and_test_provider_packages.sh b/scripts/in_container/run_install_and_test_provider_packages.sh
index 9b951c7..76d41e4 100755
--- a/scripts/in_container/run_install_and_test_provider_packages.sh
+++ b/scripts/in_container/run_install_and_test_provider_packages.sh
@@ -67,9 +67,9 @@ function install_airflow_as_specified() {
         uninstall_providers
     else
         echo
-        echo "Install airflow from PyPI including [${AIRFLOW_EXTRAS}] extras"
+        echo "Install airflow from PyPI without extras"
         echo
-        install_released_airflow_version "${INSTALL_AIRFLOW_VERSION}" "[${AIRFLOW_EXTRAS}]"
+        install_released_airflow_version "${INSTALL_AIRFLOW_VERSION}"
         uninstall_providers
     fi
     group_end
@@ -197,7 +197,6 @@ setup_provider_packages
 verify_parameters
 install_airflow_as_specified
 install_remaining_dependencies
-reinstall_azure_storage_blob
 install_provider_packages
 import_all_provider_classes
 
diff --git a/scripts/in_container/run_prepare_provider_documentation.sh b/scripts/in_container/run_prepare_provider_documentation.sh
index e88cdfc..1a0bfa8 100755
--- a/scripts/in_container/run_prepare_provider_documentation.sh
+++ b/scripts/in_container/run_prepare_provider_documentation.sh
@@ -100,7 +100,6 @@ install_supported_pip_version
 # install extra packages missing in devel_ci
 # TODO: remove it when devel_all == devel_ci
 install_remaining_dependencies
-reinstall_azure_storage_blob
 
 if [[ ${BACKPORT_PACKAGES} != "true" ]]; then
     import_all_provider_classes
diff --git a/setup.py b/setup.py
index cd38ef2..a752d82 100644
--- a/setup.py
+++ b/setup.py
@@ -219,6 +219,8 @@ azure = [
     'azure-mgmt-containerinstance>=1.5.0,<2.0',
     'azure-mgmt-datalake-store>=0.5.0',
     'azure-mgmt-resource>=2.2.0',
+    'azure-storage-blob>=12.7.0',
+    'azure-storage-common>=2.1.0',
     'azure-storage-file>=2.1.0',
 ]
 cassandra = [
@@ -423,19 +425,9 @@ slack = [
     'slack_sdk>=3.0.0,<4.0.0',
 ]
 snowflake = [
-    # The `azure` provider uses legacy `azure-storage` library, where `snowflake` uses the
-    # newer and more stable versions of those libraries. Most of `azure` operators and hooks work
-    # fine together with `snowflake` because the deprecated library does not overlap with the
-    # new libraries except the `blob` classes. So while `azure` works fine for most cases
-    # blob is the only exception
-    # Solution to that is being worked on in https://github.com/apache/airflow/pull/12188
-    # once it is merged, we can move those two back to `azure` extra.
-    'azure-core>=1.10.0',
-    'azure-storage-blob',
-    'azure-storage-common',
-    # Snowflake conector > 2.3.8 is needed because it has vendored urrllib3 and requests libraries which
-    # are monkey-patched. In earlier versions of the library, monkeypatching the libraries by snowflake
-    # caused other providers to fail (Google, Amazon etc.)
+    # Snowflake connector > 2.3.8 is needed because it has vendored-in, patched urllib and requests libraries
+    # In earlier versions of the snowflake library, monkey-patching the libraries caused other
+    # providers to fail (Google, Amazon etc.)
     'snowflake-connector-python>=2.3.8',
     'snowflake-sqlalchemy>=1.1.0',
 ]


[airflow] 10/28: Support google-cloud-pubsub>=2.0.0 (#13127)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 7d49baabe34d98b593491f68193011d536ee9359
Author: Kamil BreguĊ‚a <mi...@users.noreply.github.com>
AuthorDate: Tue Dec 22 13:02:59 2020 +0100

    Support google-cloud-pubsub>=2.0.0 (#13127)
    
    (cherry picked from commit 8c00ec89b97aa6e725379d08c8ff29a01be47e73)
---
 airflow/providers/google/cloud/hooks/pubsub.py     |  81 ++++----
 airflow/providers/google/cloud/operators/pubsub.py |   3 +-
 airflow/providers/google/cloud/sensors/pubsub.py   |   3 +-
 setup.py                                           |   2 +-
 tests/providers/google/cloud/hooks/test_pubsub.py  | 221 +++++++++++----------
 .../google/cloud/operators/test_pubsub.py          |  16 +-
 .../providers/google/cloud/sensors/test_pubsub.py  |  16 +-
 7 files changed, 177 insertions(+), 165 deletions(-)
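
For orientation before the hunks below, a minimal sketch of the google-cloud-pubsub>=2.0.0 call style the hook migrates to (the project and topic are placeholders, and application default credentials are assumed to be configured):

    from google.cloud.pubsub_v1 import PublisherClient

    publisher = PublisherClient()
    topic_path = "projects/my-project/topics/my-topic"

    # In 2.x the generated client takes a single ``request`` mapping instead of
    # individual keyword arguments, and empty metadata is passed as an empty tuple.
    publisher.create_topic(
        request={"name": topic_path, "labels": {"airflow-version": "v2-0-0"}},
        metadata=(),
    )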

diff --git a/airflow/providers/google/cloud/hooks/pubsub.py b/airflow/providers/google/cloud/hooks/pubsub.py
index f2ae190..37240a2 100644
--- a/airflow/providers/google/cloud/hooks/pubsub.py
+++ b/airflow/providers/google/cloud/hooks/pubsub.py
@@ -111,7 +111,7 @@ class PubSubHook(GoogleBaseHook):
         self._validate_messages(messages)
 
         publisher = self.get_conn()
-        topic_path = PublisherClient.topic_path(project_id, topic)  # pylint: disable=no-member
+        topic_path = f"projects/{project_id}/topics/{topic}"
 
         self.log.info("Publish %d messages to topic (path) %s", len(messages), topic_path)
         try:
@@ -206,7 +206,7 @@ class PubSubHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]]
         """
         publisher = self.get_conn()
-        topic_path = PublisherClient.topic_path(project_id, topic)  # pylint: disable=no-member
+        topic_path = f"projects/{project_id}/topics/{topic}"
 
         # Add airflow-version label to the topic
         labels = labels or {}
@@ -216,13 +216,15 @@ class PubSubHook(GoogleBaseHook):
         try:
             # pylint: disable=no-member
             publisher.create_topic(
-                name=topic_path,
-                labels=labels,
-                message_storage_policy=message_storage_policy,
-                kms_key_name=kms_key_name,
+                request={
+                    "name": topic_path,
+                    "labels": labels,
+                    "message_storage_policy": message_storage_policy,
+                    "kms_key_name": kms_key_name,
+                },
                 retry=retry,
                 timeout=timeout,
-                metadata=metadata,
+                metadata=metadata or (),
             )
         except AlreadyExists:
             self.log.warning('Topic already exists: %s', topic)
@@ -266,16 +268,13 @@ class PubSubHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]]
         """
         publisher = self.get_conn()
-        topic_path = PublisherClient.topic_path(project_id, topic)  # pylint: disable=no-member
+        topic_path = f"projects/{project_id}/topics/{topic}"
 
         self.log.info("Deleting topic (path) %s", topic_path)
         try:
             # pylint: disable=no-member
             publisher.delete_topic(
-                topic=topic_path,
-                retry=retry,
-                timeout=timeout,
-                metadata=metadata,
+                request={"topic": topic_path}, retry=retry, timeout=timeout, metadata=metadata or ()
             )
         except NotFound:
             self.log.warning('Topic does not exist: %s', topic_path)
@@ -401,27 +400,29 @@ class PubSubHook(GoogleBaseHook):
         labels['airflow-version'] = 'v' + version.replace('.', '-').replace('+', '-')
 
         # pylint: disable=no-member
-        subscription_path = SubscriberClient.subscription_path(subscription_project_id, subscription)
-        topic_path = SubscriberClient.topic_path(project_id, topic)
+        subscription_path = f"projects/{subscription_project_id}/subscriptions/{subscription}"
+        topic_path = f"projects/{project_id}/topics/{topic}"
 
         self.log.info("Creating subscription (path) %s for topic (path) %a", subscription_path, topic_path)
         try:
             subscriber.create_subscription(
-                name=subscription_path,
-                topic=topic_path,
-                push_config=push_config,
-                ack_deadline_seconds=ack_deadline_secs,
-                retain_acked_messages=retain_acked_messages,
-                message_retention_duration=message_retention_duration,
-                labels=labels,
-                enable_message_ordering=enable_message_ordering,
-                expiration_policy=expiration_policy,
-                filter_=filter_,
-                dead_letter_policy=dead_letter_policy,
-                retry_policy=retry_policy,
+                request={
+                    "name": subscription_path,
+                    "topic": topic_path,
+                    "push_config": push_config,
+                    "ack_deadline_seconds": ack_deadline_secs,
+                    "retain_acked_messages": retain_acked_messages,
+                    "message_retention_duration": message_retention_duration,
+                    "labels": labels,
+                    "enable_message_ordering": enable_message_ordering,
+                    "expiration_policy": expiration_policy,
+                    "filter": filter_,
+                    "dead_letter_policy": dead_letter_policy,
+                    "retry_policy": retry_policy,
+                },
                 retry=retry,
                 timeout=timeout,
-                metadata=metadata,
+                metadata=metadata or (),
             )
         except AlreadyExists:
             self.log.warning('Subscription already exists: %s', subscription_path)
@@ -466,13 +467,16 @@ class PubSubHook(GoogleBaseHook):
         """
         subscriber = self.subscriber_client
         # noqa E501 # pylint: disable=no-member
-        subscription_path = SubscriberClient.subscription_path(project_id, subscription)
+        subscription_path = f"projects/{project_id}/subscriptions/{subscription}"
 
         self.log.info("Deleting subscription (path) %s", subscription_path)
         try:
             # pylint: disable=no-member
             subscriber.delete_subscription(
-                subscription=subscription_path, retry=retry, timeout=timeout, metadata=metadata
+                request={"subscription": subscription_path},
+                retry=retry,
+                timeout=timeout,
+                metadata=metadata or (),
             )
 
         except NotFound:
@@ -527,18 +531,20 @@ class PubSubHook(GoogleBaseHook):
         """
         subscriber = self.subscriber_client
         # noqa E501 # pylint: disable=no-member,line-too-long
-        subscription_path = SubscriberClient.subscription_path(project_id, subscription)
+        subscription_path = f"projects/{project_id}/subscriptions/{subscription}"
 
         self.log.info("Pulling max %d messages from subscription (path) %s", max_messages, subscription_path)
         try:
             # pylint: disable=no-member
             response = subscriber.pull(
-                subscription=subscription_path,
-                max_messages=max_messages,
-                return_immediately=return_immediately,
+                request={
+                    "subscription": subscription_path,
+                    "max_messages": max_messages,
+                    "return_immediately": return_immediately,
+                },
                 retry=retry,
                 timeout=timeout,
-                metadata=metadata,
+                metadata=metadata or (),
             )
             result = getattr(response, 'received_messages', [])
             self.log.info("Pulled %d messages from subscription (path) %s", len(result), subscription_path)
@@ -591,17 +597,16 @@ class PubSubHook(GoogleBaseHook):
 
         subscriber = self.subscriber_client
         # noqa E501 # pylint: disable=no-member
-        subscription_path = SubscriberClient.subscription_path(project_id, subscription)
+        subscription_path = f"projects/{project_id}/subscriptions/{subscription}"
 
         self.log.info("Acknowledging %d ack_ids from subscription (path) %s", len(ack_ids), subscription_path)
         try:
             # pylint: disable=no-member
             subscriber.acknowledge(
-                subscription=subscription_path,
-                ack_ids=ack_ids,
+                request={"subscription": subscription_path, "ack_ids": ack_ids},
                 retry=retry,
                 timeout=timeout,
-                metadata=metadata,
+                metadata=metadata or (),
             )
         except (HttpError, GoogleAPICallError) as e:
             raise PubSubException(
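
For reference, the hook changes above adopt the calling convention of google-cloud-pubsub 2.x: resource paths are built as plain strings and all request fields travel in a single request mapping, with metadata falling back to an empty tuple. A minimal standalone sketch of that style, using hypothetical project and topic names and assuming default Google credentials are available:

    from google.cloud import pubsub_v1

    publisher = pubsub_v1.PublisherClient()
    # Resource paths are plain strings in the 2.x style used by the hook above.
    topic_path = "projects/my-project/topics/my-topic"

    publisher.create_topic(
        request={"name": topic_path, "labels": {"airflow-version": "v2-0-0"}},
        retry=None,
        timeout=None,
        metadata=(),  # the client expects a sequence, hence "metadata or ()" in the hook
    )
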
diff --git a/airflow/providers/google/cloud/operators/pubsub.py b/airflow/providers/google/cloud/operators/pubsub.py
index e8cf735..23b545f 100644
--- a/airflow/providers/google/cloud/operators/pubsub.py
+++ b/airflow/providers/google/cloud/operators/pubsub.py
@@ -29,7 +29,6 @@ from google.cloud.pubsub_v1.types import (
     ReceivedMessage,
     RetryPolicy,
 )
-from google.protobuf.json_format import MessageToDict
 
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.pubsub import PubSubHook
@@ -958,6 +957,6 @@ class PubSubPullOperator(BaseOperator):
         :param context: same as in `execute`
         :return: value to be saved to XCom.
         """
-        messages_json = [MessageToDict(m) for m in pulled_messages]
+        messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]
 
         return messages_json
diff --git a/airflow/providers/google/cloud/sensors/pubsub.py b/airflow/providers/google/cloud/sensors/pubsub.py
index d6e0be5..ff1f811 100644
--- a/airflow/providers/google/cloud/sensors/pubsub.py
+++ b/airflow/providers/google/cloud/sensors/pubsub.py
@@ -20,7 +20,6 @@ import warnings
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
 from google.cloud.pubsub_v1.types import ReceivedMessage
-from google.protobuf.json_format import MessageToDict
 
 from airflow.providers.google.cloud.hooks.pubsub import PubSubHook
 from airflow.sensors.base import BaseSensorOperator
@@ -200,6 +199,6 @@ class PubSubPullSensor(BaseSensorOperator):
         :param context: same as in `execute`
         :return: value to be saved to XCom.
         """
-        messages_json = [MessageToDict(m) for m in pulled_messages]
+        messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]
 
         return messages_json
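
The MessageToDict removal above reflects that ReceivedMessage is now a proto-plus message, so serialization goes through the to_dict class method instead of google.protobuf.json_format. A standalone sketch with illustrative values (no API calls involved):

    from google.cloud.pubsub_v1.types import ReceivedMessage

    msg = ReceivedMessage(
        ack_id="1",
        message={"data": b"Message 1", "attributes": {"type": "generated message"}},
    )
    # Class-method serialization replaces google.protobuf.json_format.MessageToDict
    print(ReceivedMessage.to_dict(msg)["ack_id"])
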
diff --git a/setup.py b/setup.py
index 1ec4f5d..ff9fd71 100644
--- a/setup.py
+++ b/setup.py
@@ -296,7 +296,7 @@ google = [
     'google-cloud-memcache>=0.2.0',
     'google-cloud-monitoring>=0.34.0,<2.0.0',
     'google-cloud-os-login>=2.0.0,<3.0.0',
-    'google-cloud-pubsub>=1.0.0,<2.0.0',
+    'google-cloud-pubsub>=2.0.0,<3.0.0',
     'google-cloud-redis>=0.3.0,<2.0.0',
     'google-cloud-secret-manager>=0.2.0,<2.0.0',
     'google-cloud-spanner>=1.10.0,<2.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_pubsub.py b/tests/providers/google/cloud/hooks/test_pubsub.py
index 0841806..eadb806 100644
--- a/tests/providers/google/cloud/hooks/test_pubsub.py
+++ b/tests/providers/google/cloud/hooks/test_pubsub.py
@@ -25,7 +25,6 @@ import pytest
 from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
 from google.cloud.exceptions import NotFound
 from google.cloud.pubsub_v1.types import ReceivedMessage
-from google.protobuf.json_format import ParseDict
 from googleapiclient.errors import HttpError
 from parameterized import parameterized
 
@@ -67,15 +66,12 @@ class TestPubSubHook(unittest.TestCase):
 
     def _generate_messages(self, count) -> List[ReceivedMessage]:
         return [
-            ParseDict(
-                {
-                    "ack_id": str(i),
-                    "message": {
-                        "data": f'Message {i}'.encode('utf8'),
-                        "attributes": {"type": "generated message"},
-                    },
+            ReceivedMessage(
+                ack_id=str(i),
+                message={
+                    "data": f'Message {i}'.encode('utf8'),
+                    "attributes": {"type": "generated message"},
                 },
-                ReceivedMessage(),
             )
             for i in range(1, count + 1)
         ]
@@ -112,20 +108,19 @@ class TestPubSubHook(unittest.TestCase):
         create_method = mock_service.return_value.create_topic
         self.pubsub_hook.create_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC)
         create_method.assert_called_once_with(
-            name=EXPANDED_TOPIC,
-            labels=LABELS,
-            message_storage_policy=None,
-            kms_key_name=None,
+            request=dict(name=EXPANDED_TOPIC, labels=LABELS, message_storage_policy=None, kms_key_name=None),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
     def test_delete_topic(self, mock_service):
         delete_method = mock_service.return_value.delete_topic
         self.pubsub_hook.delete_topic(project_id=TEST_PROJECT, topic=TEST_TOPIC)
-        delete_method.assert_called_once_with(topic=EXPANDED_TOPIC, retry=None, timeout=None, metadata=None)
+        delete_method.assert_called_once_with(
+            request=dict(topic=EXPANDED_TOPIC), retry=None, timeout=None, metadata=()
+        )
 
     @mock.patch(PUBSUB_STRING.format('PubSubHook.get_conn'))
     def test_delete_nonexisting_topic_failifnotexists(self, mock_service):
@@ -177,21 +172,23 @@ class TestPubSubHook(unittest.TestCase):
             project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION
         )
         create_method.assert_called_once_with(
-            name=EXPANDED_SUBSCRIPTION,
-            topic=EXPANDED_TOPIC,
-            push_config=None,
-            ack_deadline_seconds=10,
-            retain_acked_messages=None,
-            message_retention_duration=None,
-            labels=LABELS,
-            enable_message_ordering=False,
-            expiration_policy=None,
-            filter_=None,
-            dead_letter_policy=None,
-            retry_policy=None,
+            request=dict(
+                name=EXPANDED_SUBSCRIPTION,
+                topic=EXPANDED_TOPIC,
+                push_config=None,
+                ack_deadline_seconds=10,
+                retain_acked_messages=None,
+                message_retention_duration=None,
+                labels=LABELS,
+                enable_message_ordering=False,
+                expiration_policy=None,
+                filter=None,
+                dead_letter_policy=None,
+                retry_policy=None,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert TEST_SUBSCRIPTION == response
 
@@ -208,21 +205,23 @@ class TestPubSubHook(unittest.TestCase):
             'a-different-project', TEST_SUBSCRIPTION
         )
         create_method.assert_called_once_with(
-            name=expected_subscription,
-            topic=EXPANDED_TOPIC,
-            push_config=None,
-            ack_deadline_seconds=10,
-            retain_acked_messages=None,
-            message_retention_duration=None,
-            labels=LABELS,
-            enable_message_ordering=False,
-            expiration_policy=None,
-            filter_=None,
-            dead_letter_policy=None,
-            retry_policy=None,
+            request=dict(
+                name=expected_subscription,
+                topic=EXPANDED_TOPIC,
+                push_config=None,
+                ack_deadline_seconds=10,
+                retain_acked_messages=None,
+                message_retention_duration=None,
+                labels=LABELS,
+                enable_message_ordering=False,
+                expiration_policy=None,
+                filter=None,
+                dead_letter_policy=None,
+                retry_policy=None,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
         assert TEST_SUBSCRIPTION == response
@@ -232,7 +231,7 @@ class TestPubSubHook(unittest.TestCase):
         self.pubsub_hook.delete_subscription(project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION)
         delete_method = mock_service.delete_subscription
         delete_method.assert_called_once_with(
-            subscription=EXPANDED_SUBSCRIPTION, retry=None, timeout=None, metadata=None
+            request=dict(subscription=EXPANDED_SUBSCRIPTION), retry=None, timeout=None, metadata=()
         )
 
     @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
@@ -266,21 +265,23 @@ class TestPubSubHook(unittest.TestCase):
 
         response = self.pubsub_hook.create_subscription(project_id=TEST_PROJECT, topic=TEST_TOPIC)
         create_method.assert_called_once_with(
-            name=expected_name,
-            topic=EXPANDED_TOPIC,
-            push_config=None,
-            ack_deadline_seconds=10,
-            retain_acked_messages=None,
-            message_retention_duration=None,
-            labels=LABELS,
-            enable_message_ordering=False,
-            expiration_policy=None,
-            filter_=None,
-            dead_letter_policy=None,
-            retry_policy=None,
+            request=dict(
+                name=expected_name,
+                topic=EXPANDED_TOPIC,
+                push_config=None,
+                ack_deadline_seconds=10,
+                retain_acked_messages=None,
+                message_retention_duration=None,
+                labels=LABELS,
+                enable_message_ordering=False,
+                expiration_policy=None,
+                filter=None,
+                dead_letter_policy=None,
+                retry_policy=None,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert 'sub-%s' % TEST_UUID == response
 
@@ -292,21 +293,23 @@ class TestPubSubHook(unittest.TestCase):
             project_id=TEST_PROJECT, topic=TEST_TOPIC, subscription=TEST_SUBSCRIPTION, ack_deadline_secs=30
         )
         create_method.assert_called_once_with(
-            name=EXPANDED_SUBSCRIPTION,
-            topic=EXPANDED_TOPIC,
-            push_config=None,
-            ack_deadline_seconds=30,
-            retain_acked_messages=None,
-            message_retention_duration=None,
-            labels=LABELS,
-            enable_message_ordering=False,
-            expiration_policy=None,
-            filter_=None,
-            dead_letter_policy=None,
-            retry_policy=None,
+            request=dict(
+                name=EXPANDED_SUBSCRIPTION,
+                topic=EXPANDED_TOPIC,
+                push_config=None,
+                ack_deadline_seconds=30,
+                retain_acked_messages=None,
+                message_retention_duration=None,
+                labels=LABELS,
+                enable_message_ordering=False,
+                expiration_policy=None,
+                filter=None,
+                dead_letter_policy=None,
+                retry_policy=None,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert TEST_SUBSCRIPTION == response
 
@@ -321,21 +324,23 @@ class TestPubSubHook(unittest.TestCase):
             filter_='attributes.domain="com"',
         )
         create_method.assert_called_once_with(
-            name=EXPANDED_SUBSCRIPTION,
-            topic=EXPANDED_TOPIC,
-            push_config=None,
-            ack_deadline_seconds=10,
-            retain_acked_messages=None,
-            message_retention_duration=None,
-            labels=LABELS,
-            enable_message_ordering=False,
-            expiration_policy=None,
-            filter_='attributes.domain="com"',
-            dead_letter_policy=None,
-            retry_policy=None,
+            request=dict(
+                name=EXPANDED_SUBSCRIPTION,
+                topic=EXPANDED_TOPIC,
+                push_config=None,
+                ack_deadline_seconds=10,
+                retain_acked_messages=None,
+                message_retention_duration=None,
+                labels=LABELS,
+                enable_message_ordering=False,
+                expiration_policy=None,
+                filter='attributes.domain="com"',
+                dead_letter_policy=None,
+                retry_policy=None,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert TEST_SUBSCRIPTION == response
 
@@ -401,12 +406,14 @@ class TestPubSubHook(unittest.TestCase):
             project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=10
         )
         pull_method.assert_called_once_with(
-            subscription=EXPANDED_SUBSCRIPTION,
-            max_messages=10,
-            return_immediately=False,
+            request=dict(
+                subscription=EXPANDED_SUBSCRIPTION,
+                max_messages=10,
+                return_immediately=False,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert pulled_messages == response
 
@@ -419,12 +426,14 @@ class TestPubSubHook(unittest.TestCase):
             project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=10
         )
         pull_method.assert_called_once_with(
-            subscription=EXPANDED_SUBSCRIPTION,
-            max_messages=10,
-            return_immediately=False,
+            request=dict(
+                subscription=EXPANDED_SUBSCRIPTION,
+                max_messages=10,
+                return_immediately=False,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert [] == response
 
@@ -445,12 +454,14 @@ class TestPubSubHook(unittest.TestCase):
         with pytest.raises(PubSubException):
             self.pubsub_hook.pull(project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=10)
             pull_method.assert_called_once_with(
-                subscription=EXPANDED_SUBSCRIPTION,
-                max_messages=10,
-                return_immediately=False,
+                request=dict(
+                    subscription=EXPANDED_SUBSCRIPTION,
+                    max_messages=10,
+                    return_immediately=False,
+                ),
                 retry=None,
                 timeout=None,
-                metadata=None,
+                metadata=(),
             )
 
     @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
@@ -461,11 +472,13 @@ class TestPubSubHook(unittest.TestCase):
             project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, ack_ids=['1', '2', '3']
         )
         ack_method.assert_called_once_with(
-            subscription=EXPANDED_SUBSCRIPTION,
-            ack_ids=['1', '2', '3'],
+            request=dict(
+                subscription=EXPANDED_SUBSCRIPTION,
+                ack_ids=['1', '2', '3'],
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(PUBSUB_STRING.format('PubSubHook.subscriber_client'))
@@ -478,11 +491,13 @@ class TestPubSubHook(unittest.TestCase):
             messages=self._generate_messages(3),
         )
         ack_method.assert_called_once_with(
-            subscription=EXPANDED_SUBSCRIPTION,
-            ack_ids=['1', '2', '3'],
+            request=dict(
+                subscription=EXPANDED_SUBSCRIPTION,
+                ack_ids=['1', '2', '3'],
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @parameterized.expand(
@@ -504,11 +519,13 @@ class TestPubSubHook(unittest.TestCase):
                 project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, ack_ids=['1', '2', '3']
             )
             ack_method.assert_called_once_with(
-                subscription=EXPANDED_SUBSCRIPTION,
-                ack_ids=['1', '2', '3'],
+                request=dict(
+                    subscription=EXPANDED_SUBSCRIPTION,
+                    ack_ids=['1', '2', '3'],
+                ),
                 retry=None,
                 timeout=None,
-                metadata=None,
+                metadata=(),
             )
 
     @parameterized.expand(
diff --git a/tests/providers/google/cloud/operators/test_pubsub.py b/tests/providers/google/cloud/operators/test_pubsub.py
index 9ff71e6..6abfffa 100644
--- a/tests/providers/google/cloud/operators/test_pubsub.py
+++ b/tests/providers/google/cloud/operators/test_pubsub.py
@@ -21,7 +21,6 @@ from typing import Any, Dict, List
 from unittest import mock
 
 from google.cloud.pubsub_v1.types import ReceivedMessage
-from google.protobuf.json_format import MessageToDict, ParseDict
 
 from airflow.providers.google.cloud.operators.pubsub import (
     PubSubCreateSubscriptionOperator,
@@ -230,21 +229,18 @@ class TestPubSubPublishOperator(unittest.TestCase):
 class TestPubSubPullOperator(unittest.TestCase):
     def _generate_messages(self, count):
         return [
-            ParseDict(
-                {
-                    "ack_id": "%s" % i,
-                    "message": {
-                        "data": f'Message {i}'.encode('utf8'),
-                        "attributes": {"type": "generated message"},
-                    },
+            ReceivedMessage(
+                ack_id="%s" % i,
+                message={
+                    "data": f'Message {i}'.encode('utf8'),
+                    "attributes": {"type": "generated message"},
                 },
-                ReceivedMessage(),
             )
             for i in range(1, count + 1)
         ]
 
     def _generate_dicts(self, count):
-        return [MessageToDict(m) for m in self._generate_messages(count)]
+        return [ReceivedMessage.to_dict(m) for m in self._generate_messages(count)]
 
     @mock.patch('airflow.providers.google.cloud.operators.pubsub.PubSubHook')
     def test_execute_no_messages(self, mock_hook):
diff --git a/tests/providers/google/cloud/sensors/test_pubsub.py b/tests/providers/google/cloud/sensors/test_pubsub.py
index ba1aee9..795860b 100644
--- a/tests/providers/google/cloud/sensors/test_pubsub.py
+++ b/tests/providers/google/cloud/sensors/test_pubsub.py
@@ -22,7 +22,6 @@ from unittest import mock
 
 import pytest
 from google.cloud.pubsub_v1.types import ReceivedMessage
-from google.protobuf.json_format import MessageToDict, ParseDict
 
 from airflow.exceptions import AirflowSensorTimeout
 from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor
@@ -35,21 +34,18 @@ TEST_SUBSCRIPTION = 'test-subscription'
 class TestPubSubPullSensor(unittest.TestCase):
     def _generate_messages(self, count):
         return [
-            ParseDict(
-                {
-                    "ack_id": "%s" % i,
-                    "message": {
-                        "data": f'Message {i}'.encode('utf8'),
-                        "attributes": {"type": "generated message"},
-                    },
+            ReceivedMessage(
+                ack_id="%s" % i,
+                message={
+                    "data": f'Message {i}'.encode('utf8'),
+                    "attributes": {"type": "generated message"},
                 },
-                ReceivedMessage(),
             )
             for i in range(1, count + 1)
         ]
 
     def _generate_dicts(self, count):
-        return [MessageToDict(m) for m in self._generate_messages(count)]
+        return [ReceivedMessage.to_dict(m) for m in self._generate_messages(count)]
 
     @mock.patch('airflow.providers.google.cloud.sensors.pubsub.PubSubHook')
     def test_poke_no_messages(self, mock_hook):


[airflow] 16/28: Support google-cloud-automl >=2.1.0 (#13505)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 92c356eddb7a1eebe112091027ae3077a96a5829
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Mon Jan 11 09:39:44 2021 +0100

    Support google-cloud-automl >=2.1.0 (#13505)
    
    (cherry picked from commit a6f999b62e3c9aeb10ab24342674d3670a8ad259)
---
 airflow/providers/google/ADDITIONAL_INFO.md        |   1 +
 .../cloud/example_dags/example_automl_tables.py    |   6 +-
 airflow/providers/google/cloud/hooks/automl.py     | 103 +++++++++++----------
 airflow/providers/google/cloud/operators/automl.py |  36 +++----
 setup.py                                           |   2 +-
 tests/providers/google/cloud/hooks/test_automl.py  |  70 +++++++-------
 .../google/cloud/operators/test_automl.py          |  29 ++++--
 7 files changed, 134 insertions(+), 113 deletions(-)

diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md
index d80f9e1..800703b 100644
--- a/airflow/providers/google/ADDITIONAL_INFO.md
+++ b/airflow/providers/google/ADDITIONAL_INFO.md
@@ -29,6 +29,7 @@ Details are covered in the UPDATING.md files for each library, but there are som
 
 | Library name | Previous constraints | Current constraints | |
 | --- | --- | --- | --- |
+| [``google-cloud-automl``](https://pypi.org/project/google-cloud-automl/) | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-automl/blob/master/UPGRADING.md) |
 | [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
 | [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
 | [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
diff --git a/airflow/providers/google/cloud/example_dags/example_automl_tables.py b/airflow/providers/google/cloud/example_dags/example_automl_tables.py
index 4ff92b3..117bd34 100644
--- a/airflow/providers/google/cloud/example_dags/example_automl_tables.py
+++ b/airflow/providers/google/cloud/example_dags/example_automl_tables.py
@@ -47,7 +47,7 @@ GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1")
 GCP_AUTOML_DATASET_BUCKET = os.environ.get(
     "GCP_AUTOML_DATASET_BUCKET", "gs://cloud-ml-tables-data/bank-marketing.csv"
 )
-TARGET = os.environ.get("GCP_AUTOML_TARGET", "Class")
+TARGET = os.environ.get("GCP_AUTOML_TARGET", "Deposit")
 
 # Example values
 MODEL_ID = "TBL123456"
@@ -76,9 +76,9 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str:
     Using column name returns spec of the column.
     """
     for column in columns_specs:
-        if column["displayName"] == column_name:
+        if column["display_name"] == column_name:
             return extract_object_id(column)
-    return ""
+    raise Exception(f"Unknown target column: {column_name}")
 
 
 # Example DAG to create dataset, train model_id and deploy it.
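
The switch from "displayName" to "display_name" in the example DAG follows from the new serialization path: proto-plus to_dict keeps the original snake_case field names, while the previously used MessageToDict helper emitted lowerCamelCase keys by default. A standalone sketch with an illustrative column spec:

    from google.cloud.automl_v1beta1 import ColumnSpec

    spec = ColumnSpec(
        name="projects/p/locations/us-central1/datasets/d/tableSpecs/t/columnSpecs/c",
        display_name="Deposit",
    )
    print(ColumnSpec.to_dict(spec)["display_name"])  # -> 'Deposit'
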
diff --git a/airflow/providers/google/cloud/hooks/automl.py b/airflow/providers/google/cloud/hooks/automl.py
index 78ec4fb..75d7037 100644
--- a/airflow/providers/google/cloud/hooks/automl.py
+++ b/airflow/providers/google/cloud/hooks/automl.py
@@ -20,22 +20,23 @@
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 from cached_property import cached_property
+from google.api_core.operation import Operation
 from google.api_core.retry import Retry
-from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
-from google.cloud.automl_v1beta1.types import (
+from google.cloud.automl_v1beta1 import (
+    AutoMlClient,
     BatchPredictInputConfig,
     BatchPredictOutputConfig,
     ColumnSpec,
     Dataset,
     ExamplePayload,
-    FieldMask,
     ImageObjectDetectionModelDeploymentMetadata,
     InputConfig,
     Model,
-    Operation,
+    PredictionServiceClient,
     PredictResponse,
     TableSpec,
 )
+from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 
@@ -123,9 +124,9 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
         """
         client = self.get_conn()
-        parent = client.location_path(project_id, location)
+        parent = f"projects/{project_id}/locations/{location}"
         return client.create_model(
-            parent=parent, model=model, retry=retry, timeout=timeout, metadata=metadata
+            request={'parent': parent, 'model': model}, retry=retry, timeout=timeout, metadata=metadata or ()
         )
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -176,15 +177,17 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
         """
         client = self.prediction_client
-        name = client.model_path(project=project_id, location=location, model=model_id)
+        name = f"projects/{project_id}/locations/{location}/models/{model_id}"
         result = client.batch_predict(
-            name=name,
-            input_config=input_config,
-            output_config=output_config,
-            params=params,
+            request={
+                'name': name,
+                'input_config': input_config,
+                'output_config': output_config,
+                'params': params,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result
 
@@ -229,14 +232,12 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types.PredictResponse` instance
         """
         client = self.prediction_client
-        name = client.model_path(project=project_id, location=location, model=model_id)
+        name = f"projects/{project_id}/locations/{location}/models/{model_id}"
         result = client.predict(
-            name=name,
-            payload=payload,
-            params=params,
+            request={'name': name, 'payload': payload, 'params': params},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result
 
@@ -273,13 +274,12 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types.Dataset` instance.
         """
         client = self.get_conn()
-        parent = client.location_path(project=project_id, location=location)
+        parent = f"projects/{project_id}/locations/{location}"
         result = client.create_dataset(
-            parent=parent,
-            dataset=dataset,
+            request={'parent': parent, 'dataset': dataset},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result
 
@@ -319,13 +319,12 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
         """
         client = self.get_conn()
-        name = client.dataset_path(project=project_id, location=location, dataset=dataset_id)
+        name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
         result = client.import_data(
-            name=name,
-            input_config=input_config,
+            request={'name': name, 'input_config': input_config},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result
 
@@ -385,13 +384,10 @@ class CloudAutoMLHook(GoogleBaseHook):
             table_spec=table_spec_id,
         )
         result = client.list_column_specs(
-            parent=parent,
-            field_mask=field_mask,
-            filter_=filter_,
-            page_size=page_size,
+            request={'parent': parent, 'field_mask': field_mask, 'filter': filter_, 'page_size': page_size},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result
 
@@ -427,8 +423,10 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types.Model` instance.
         """
         client = self.get_conn()
-        name = client.model_path(project=project_id, location=location, model=model_id)
-        result = client.get_model(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+        result = client.get_model(
+            request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
         return result
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -463,8 +461,10 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance.
         """
         client = self.get_conn()
-        name = client.model_path(project=project_id, location=location, model=model_id)
-        result = client.delete_model(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+        result = client.delete_model(
+            request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
         return result
 
     def update_dataset(
@@ -497,11 +497,10 @@ class CloudAutoMLHook(GoogleBaseHook):
         """
         client = self.get_conn()
         result = client.update_dataset(
-            dataset=dataset,
-            update_mask=update_mask,
+            request={'dataset': dataset, 'update_mask': update_mask},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result
 
@@ -547,13 +546,15 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance.
         """
         client = self.get_conn()
-        name = client.model_path(project=project_id, location=location, model=model_id)
+        name = f"projects/{project_id}/locations/{location}/models/{model_id}"
         result = client.deploy_model(
-            name=name,
+            request={
+                'name': name,
+                'image_object_detection_model_deployment_metadata': image_detection_metadata,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
-            image_object_detection_model_deployment_metadata=image_detection_metadata,
+            metadata=metadata or (),
         )
         return result
 
@@ -601,14 +602,12 @@ class CloudAutoMLHook(GoogleBaseHook):
             of the response through the `options` parameter.
         """
         client = self.get_conn()
-        parent = client.dataset_path(project=project_id, location=location, dataset=dataset_id)
+        parent = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
         result = client.list_table_specs(
-            parent=parent,
-            filter_=filter_,
-            page_size=page_size,
+            request={'parent': parent, 'filter': filter_, 'page_size': page_size},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result
 
@@ -644,8 +643,10 @@ class CloudAutoMLHook(GoogleBaseHook):
             of the response through the `options` parameter.
         """
         client = self.get_conn()
-        parent = client.location_path(project=project_id, location=location)
-        result = client.list_datasets(parent=parent, retry=retry, timeout=timeout, metadata=metadata)
+        parent = f"projects/{project_id}/locations/{location}"
+        result = client.list_datasets(
+            request={'parent': parent}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
         return result
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -680,6 +681,8 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
         """
         client = self.get_conn()
-        name = client.dataset_path(project=project_id, location=location, dataset=dataset_id)
-        result = client.delete_dataset(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
+        result = client.delete_dataset(
+            request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
         return result
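
As with the Pub/Sub hook, the AutoML hook now targets the 2.x client surface: resource paths are formatted directly and arguments move into a single request mapping, with metadata defaulting to an empty tuple. A minimal sketch of the same pattern outside the hook, using hypothetical project, location and model ids and assuming default Google credentials are configured:

    from google.cloud.automl_v1beta1 import AutoMlClient

    client = AutoMlClient()
    name = "projects/my-project/locations/us-central1/models/TBL123"  # hypothetical ids
    model = client.get_model(
        request={"name": name},
        retry=None,
        timeout=None,
        metadata=(),
    )
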
diff --git a/airflow/providers/google/cloud/operators/automl.py b/airflow/providers/google/cloud/operators/automl.py
index a1823cd..cdf79b0 100644
--- a/airflow/providers/google/cloud/operators/automl.py
+++ b/airflow/providers/google/cloud/operators/automl.py
@@ -22,7 +22,14 @@ import ast
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 from google.api_core.retry import Retry
-from google.protobuf.json_format import MessageToDict
+from google.cloud.automl_v1beta1 import (
+    BatchPredictResult,
+    ColumnSpec,
+    Dataset,
+    Model,
+    PredictResponse,
+    TableSpec,
+)
 
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
@@ -113,7 +120,7 @@ class AutoMLTrainModelOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(operation.result())
+        result = Model.to_dict(operation.result())
         model_id = hook.extract_object_id(result)
         self.log.info("Model created: %s", model_id)
 
@@ -212,7 +219,7 @@ class AutoMLPredictOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(result)
+        return PredictResponse.to_dict(result)
 
 
 class AutoMLBatchPredictOperator(BaseOperator):
@@ -324,7 +331,7 @@ class AutoMLBatchPredictOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(operation.result())
+        result = BatchPredictResult.to_dict(operation.result())
         self.log.info("Batch prediction ready.")
         return result
 
@@ -414,7 +421,7 @@ class AutoMLCreateDatasetOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(result)
+        result = Dataset.to_dict(result)
         dataset_id = hook.extract_object_id(result)
         self.log.info("Creating completed. Dataset id: %s", dataset_id)
 
@@ -513,9 +520,8 @@ class AutoMLImportDataOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(operation.result())
+        operation.result()
         self.log.info("Import completed")
-        return result
 
 
 class AutoMLTablesListColumnSpecsOperator(BaseOperator):
@@ -627,7 +633,7 @@ class AutoMLTablesListColumnSpecsOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = [MessageToDict(spec) for spec in page_iterator]
+        result = [ColumnSpec.to_dict(spec) for spec in page_iterator]
         self.log.info("Columns specs obtained.")
 
         return result
@@ -718,7 +724,7 @@ class AutoMLTablesUpdateDatasetOperator(BaseOperator):
             metadata=self.metadata,
         )
         self.log.info("Dataset updated.")
-        return MessageToDict(result)
+        return Dataset.to_dict(result)
 
 
 class AutoMLGetModelOperator(BaseOperator):
@@ -804,7 +810,7 @@ class AutoMLGetModelOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(result)
+        return Model.to_dict(result)
 
 
 class AutoMLDeleteModelOperator(BaseOperator):
@@ -890,8 +896,7 @@ class AutoMLDeleteModelOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(operation.result())
-        return result
+        operation.result()
 
 
 class AutoMLDeployModelOperator(BaseOperator):
@@ -991,9 +996,8 @@ class AutoMLDeployModelOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(operation.result())
+        operation.result()
         self.log.info("Model deployed.")
-        return result
 
 
 class AutoMLTablesListTableSpecsOperator(BaseOperator):
@@ -1092,7 +1096,7 @@ class AutoMLTablesListTableSpecsOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = [MessageToDict(spec) for spec in page_iterator]
+        result = [TableSpec.to_dict(spec) for spec in page_iterator]
         self.log.info(result)
         self.log.info("Table specs obtained.")
         return result
@@ -1173,7 +1177,7 @@ class AutoMLListDatasetOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = [MessageToDict(dataset) for dataset in page_iterator]
+        result = [Dataset.to_dict(dataset) for dataset in page_iterator]
         self.log.info("Datasets obtained.")
 
         self.xcom_push(
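
The operators above now serialize client results with the proto-plus to_dict class methods (Model, Dataset, PredictResponse, BatchPredictResult, TableSpec) and keep deriving object ids from the resource name. A standalone sketch of that pattern with an illustrative model name; it assumes, as the updated tests do, that CloudAutoMLHook.extract_object_id is a static helper returning the trailing segment of the "name" field:

    from google.cloud.automl_v1beta1 import Model

    from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook

    model = Model(name="projects/my-project/locations/us-central1/models/TBL9195602771183665152")
    result = Model.to_dict(model)  # proto-plus serialization, snake_case keys
    model_id = CloudAutoMLHook.extract_object_id(result)  # -> 'TBL9195602771183665152'
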
diff --git a/setup.py b/setup.py
index 5314814..ff9e65d 100644
--- a/setup.py
+++ b/setup.py
@@ -283,7 +283,7 @@ google = [
     'google-api-python-client>=1.6.0,<2.0.0',
     'google-auth>=1.0.0,<2.0.0',
     'google-auth-httplib2>=0.0.1',
-    'google-cloud-automl>=0.4.0,<2.0.0',
+    'google-cloud-automl>=2.1.0,<3.0.0',
     'google-cloud-bigquery-datatransfer>=3.0.0,<4.0.0',
     'google-cloud-bigtable>=1.0.0,<2.0.0',
     'google-cloud-container>=0.1.1,<2.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_automl.py b/tests/providers/google/cloud/hooks/test_automl.py
index 898001c..c9de712 100644
--- a/tests/providers/google/cloud/hooks/test_automl.py
+++ b/tests/providers/google/cloud/hooks/test_automl.py
@@ -19,7 +19,7 @@
 import unittest
 from unittest import mock
 
-from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
+from google.cloud.automl_v1beta1 import AutoMlClient
 
 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_no_default_project_id
@@ -38,9 +38,9 @@ MODEL = {
     "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
 }
 
-LOCATION_PATH = AutoMlClient.location_path(GCP_PROJECT_ID, GCP_LOCATION)
-MODEL_PATH = PredictionServiceClient.model_path(GCP_PROJECT_ID, GCP_LOCATION, MODEL_ID)
-DATASET_PATH = AutoMlClient.dataset_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID)
+LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}"
+MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}"
+DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}"
 
 INPUT_CONFIG = {"input": "value"}
 OUTPUT_CONFIG = {"output": "value"}
@@ -81,7 +81,7 @@ class TestAuoMLHook(unittest.TestCase):
         self.hook.create_model(model=MODEL, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
 
         mock_create_model.assert_called_once_with(
-            parent=LOCATION_PATH, model=MODEL, retry=None, timeout=None, metadata=None
+            request=dict(parent=LOCATION_PATH, model=MODEL), retry=None, timeout=None, metadata=()
         )
 
     @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.batch_predict")
@@ -95,13 +95,12 @@ class TestAuoMLHook(unittest.TestCase):
         )
 
         mock_batch_predict.assert_called_once_with(
-            name=MODEL_PATH,
-            input_config=INPUT_CONFIG,
-            output_config=OUTPUT_CONFIG,
-            params=None,
+            request=dict(
+                name=MODEL_PATH, input_config=INPUT_CONFIG, output_config=OUTPUT_CONFIG, params=None
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.predict")
@@ -114,12 +113,10 @@ class TestAuoMLHook(unittest.TestCase):
         )
 
         mock_predict.assert_called_once_with(
-            name=MODEL_PATH,
-            payload=PAYLOAD,
-            params=None,
+            request=dict(name=MODEL_PATH, payload=PAYLOAD, params=None),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_dataset")
@@ -127,11 +124,10 @@ class TestAuoMLHook(unittest.TestCase):
         self.hook.create_dataset(dataset=DATASET, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
 
         mock_create_dataset.assert_called_once_with(
-            parent=LOCATION_PATH,
-            dataset=DATASET,
+            request=dict(parent=LOCATION_PATH, dataset=DATASET),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.import_data")
@@ -144,11 +140,10 @@ class TestAuoMLHook(unittest.TestCase):
         )
 
         mock_import_data.assert_called_once_with(
-            name=DATASET_PATH,
-            input_config=INPUT_CONFIG,
+            request=dict(name=DATASET_PATH, input_config=INPUT_CONFIG),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_column_specs")
@@ -169,26 +164,27 @@ class TestAuoMLHook(unittest.TestCase):
 
         parent = AutoMlClient.table_spec_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID, table_spec)
         mock_list_column_specs.assert_called_once_with(
-            parent=parent,
-            field_mask=MASK,
-            filter_=filter_,
-            page_size=page_size,
+            request=dict(parent=parent, field_mask=MASK, filter=filter_, page_size=page_size),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.get_model")
     def test_get_model(self, mock_get_model):
         self.hook.get_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
 
-        mock_get_model.assert_called_once_with(name=MODEL_PATH, retry=None, timeout=None, metadata=None)
+        mock_get_model.assert_called_once_with(
+            request=dict(name=MODEL_PATH), retry=None, timeout=None, metadata=()
+        )
 
     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_model")
     def test_delete_model(self, mock_delete_model):
         self.hook.delete_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
 
-        mock_delete_model.assert_called_once_with(name=MODEL_PATH, retry=None, timeout=None, metadata=None)
+        mock_delete_model.assert_called_once_with(
+            request=dict(name=MODEL_PATH), retry=None, timeout=None, metadata=()
+        )
 
     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.update_dataset")
     def test_update_dataset(self, mock_update_dataset):
@@ -198,7 +194,7 @@ class TestAuoMLHook(unittest.TestCase):
         )
 
         mock_update_dataset.assert_called_once_with(
-            dataset=DATASET, update_mask=MASK, retry=None, timeout=None, metadata=None
+            request=dict(dataset=DATASET, update_mask=MASK), retry=None, timeout=None, metadata=()
         )
 
     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.deploy_model")
@@ -213,11 +209,13 @@ class TestAuoMLHook(unittest.TestCase):
         )
 
         mock_deploy_model.assert_called_once_with(
-            name=MODEL_PATH,
+            request=dict(
+                name=MODEL_PATH,
+                image_object_detection_model_deployment_metadata=image_detection_metadata,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
-            image_object_detection_model_deployment_metadata=image_detection_metadata,
+            metadata=(),
         )
 
     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_table_specs")
@@ -234,12 +232,10 @@ class TestAuoMLHook(unittest.TestCase):
         )
 
         mock_list_table_specs.assert_called_once_with(
-            parent=DATASET_PATH,
-            filter_=filter_,
-            page_size=page_size,
+            request=dict(parent=DATASET_PATH, filter=filter_, page_size=page_size),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_datasets")
@@ -247,7 +243,7 @@ class TestAuoMLHook(unittest.TestCase):
         self.hook.list_datasets(location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
 
         mock_list_datasets.assert_called_once_with(
-            parent=LOCATION_PATH, retry=None, timeout=None, metadata=None
+            request=dict(parent=LOCATION_PATH), retry=None, timeout=None, metadata=()
         )
 
     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_dataset")
@@ -255,5 +251,5 @@ class TestAuoMLHook(unittest.TestCase):
         self.hook.delete_dataset(dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
 
         mock_delete_dataset.assert_called_once_with(
-            name=DATASET_PATH, retry=None, timeout=None, metadata=None
+            request=dict(name=DATASET_PATH), retry=None, timeout=None, metadata=()
         )
diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py
index 903600b..4c80703 100644
--- a/tests/providers/google/cloud/operators/test_automl.py
+++ b/tests/providers/google/cloud/operators/test_automl.py
@@ -20,8 +20,9 @@ import copy
 import unittest
 from unittest import mock
 
-from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
+from google.cloud.automl_v1beta1 import BatchPredictResult, Dataset, Model, PredictResponse
 
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from airflow.providers.google.cloud.operators.automl import (
     AutoMLBatchPredictOperator,
     AutoMLCreateDatasetOperator,
@@ -43,7 +44,7 @@ TASK_ID = "test-automl-hook"
 GCP_PROJECT_ID = "test-project"
 GCP_LOCATION = "test-location"
 MODEL_NAME = "test_model"
-MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152"
+MODEL_ID = "TBL9195602771183665152"
 DATASET_ID = "TBL123456789"
 MODEL = {
     "display_name": MODEL_NAME,
@@ -51,8 +52,9 @@ MODEL = {
     "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
 }
 
-LOCATION_PATH = AutoMlClient.location_path(GCP_PROJECT_ID, GCP_LOCATION)
-MODEL_PATH = PredictionServiceClient.model_path(GCP_PROJECT_ID, GCP_LOCATION, MODEL_ID)
+LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}"
+MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}"
+DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}"
 
 INPUT_CONFIG = {"input": "value"}
 OUTPUT_CONFIG = {"output": "value"}
@@ -60,12 +62,15 @@ PAYLOAD = {"test": "payload"}
 DATASET = {"dataset_id": "data"}
 MASK = {"field": "mask"}
 
+extract_object_id = CloudAutoMLHook.extract_object_id
+
 
 class TestAutoMLTrainModelOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLTrainModelOperator.xcom_push")
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
     def test_execute(self, mock_hook, mock_xcom):
-        mock_hook.return_value.extract_object_id.return_value = MODEL_ID
+        mock_hook.return_value.create_model.return_value.result.return_value = Model(name=MODEL_PATH)
+        mock_hook.return_value.extract_object_id = extract_object_id
         op = AutoMLTrainModelOperator(
             model=MODEL,
             location=GCP_LOCATION,
@@ -87,6 +92,9 @@ class TestAutoMLTrainModelOperator(unittest.TestCase):
 class TestAutoMLBatchPredictOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
     def test_execute(self, mock_hook):
+        mock_hook.return_value.batch_predict.return_value.result.return_value = BatchPredictResult()
+        mock_hook.return_value.extract_object_id = extract_object_id
+
         op = AutoMLBatchPredictOperator(
             model_id=MODEL_ID,
             location=GCP_LOCATION,
@@ -113,6 +121,8 @@ class TestAutoMLBatchPredictOperator(unittest.TestCase):
 class TestAutoMLPredictOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
     def test_execute(self, mock_hook):
+        mock_hook.return_value.predict.return_value = PredictResponse()
+
         op = AutoMLPredictOperator(
             model_id=MODEL_ID,
             location=GCP_LOCATION,
@@ -137,7 +147,9 @@ class TestAutoMLCreateImportOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator.xcom_push")
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
     def test_execute(self, mock_hook, mock_xcom):
-        mock_hook.return_value.extract_object_id.return_value = DATASET_ID
+        mock_hook.return_value.create_dataset.return_value = Dataset(name=DATASET_PATH)
+        mock_hook.return_value.extract_object_id = extract_object_id
+
         op = AutoMLCreateDatasetOperator(
             dataset=DATASET,
             location=GCP_LOCATION,
@@ -191,6 +203,8 @@ class TestAutoMLListColumnsSpecsOperator(unittest.TestCase):
 class TestAutoMLUpdateDatasetOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
     def test_execute(self, mock_hook):
+        mock_hook.return_value.update_dataset.return_value = Dataset(name=DATASET_PATH)
+
         dataset = copy.deepcopy(DATASET)
         dataset["name"] = DATASET_ID
 
@@ -213,6 +227,9 @@ class TestAutoMLUpdateDatasetOperator(unittest.TestCase):
 class TestAutoMLGetModelOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
     def test_execute(self, mock_hook):
+        mock_hook.return_value.get_model.return_value = Model(name=MODEL_PATH)
+        mock_hook.return_value.extract_object_id = extract_object_id
+
         op = AutoMLGetModelOperator(
             model_id=MODEL_ID,
             location=GCP_LOCATION,

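For context, the tests above replace the previously mocked extract_object_id with the hook's real static helper, which turns a full resource path back into a bare object ID. A minimal sketch of that behaviour, assuming the helper simply keeps the last path segment (the authoritative implementation lives in CloudAutoMLHook):

    def extract_object_id(obj: dict) -> str:
        # Keep only the segment after the last "/" of the resource name, e.g.
        # "projects/p/locations/us-central1/models/TBL9195602771183665152" -> "TBL9195602771183665152"
        return obj["name"].rpartition("/")[-1]

    assert extract_object_id({"name": "projects/p/locations/us-central1/models/TBL123"}) == "TBL123"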

[airflow] 12/28: Add timeout option to gcs hook methods. (#13156)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 1a2428817e4e7ddd50737f3a3723d2505e970c1c
Author: Joshua Carp <jm...@gmail.com>
AuthorDate: Thu Dec 24 08:12:06 2020 -0500

    Add timeout option to gcs hook methods. (#13156)
    
    (cherry picked from commit 323084e97ddacbc5512709bf0cad8f53082d16b0)
---
 airflow/providers/google/cloud/hooks/gcs.py    | 30 ++++++++++++++++++++------
 setup.py                                       |  2 +-
 tests/providers/google/cloud/hooks/test_gcs.py | 14 ++++++------
 3 files changed, 32 insertions(+), 14 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index 0ca3961..72a23ea 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -40,6 +40,9 @@ from airflow.version import version
 RT = TypeVar('RT')  # pylint: disable=invalid-name
 T = TypeVar("T", bound=Callable)  # pylint: disable=invalid-name
 
+# Use default timeout from google-cloud-storage
+DEFAULT_TIMEOUT = 60
+
 
 def _fallback_object_url_to_object_name_and_bucket_name(
     object_url_keyword_arg_name='object_url',
@@ -257,7 +260,12 @@ class GCSHook(GoogleBaseHook):
         )
 
     def download(
-        self, object_name: str, bucket_name: Optional[str], filename: Optional[str] = None
+        self,
+        object_name: str,
+        bucket_name: Optional[str],
+        filename: Optional[str] = None,
+        chunk_size: Optional[int] = None,
+        timeout: Optional[int] = DEFAULT_TIMEOUT,
     ) -> Union[str, bytes]:
         """
         Downloads a file from Google Cloud Storage.
@@ -273,16 +281,20 @@ class GCSHook(GoogleBaseHook):
         :type object_name: str
         :param filename: If set, a local file path where the file should be written to.
         :type filename: str
+        :param chunk_size: Blob chunk size.
+        :type chunk_size: int
+        :param timeout: Request timeout in seconds.
+        :type timeout: int
         """
         # TODO: future improvement check file size before downloading,
         #  to check for local space availability
 
         client = self.get_conn()
         bucket = client.bucket(bucket_name)
-        blob = bucket.blob(blob_name=object_name)
+        blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)
 
         if filename:
-            blob.download_to_filename(filename)
+            blob.download_to_filename(filename, timeout=timeout)
             self.log.info('File downloaded to %s', filename)
             return filename
         else:
@@ -359,6 +371,8 @@ class GCSHook(GoogleBaseHook):
         mime_type: Optional[str] = None,
         gzip: bool = False,
         encoding: str = 'utf-8',
+        chunk_size: Optional[int] = None,
+        timeout: Optional[int] = DEFAULT_TIMEOUT,
     ) -> None:
         """
         Uploads a local file or file data as string or bytes to Google Cloud Storage.
@@ -377,10 +391,14 @@ class GCSHook(GoogleBaseHook):
         :type gzip: bool
         :param encoding: bytes encoding for file data if provided as string
         :type encoding: str
+        :param chunk_size: Blob chunk size.
+        :type chunk_size: int
+        :param timeout: Request timeout in seconds.
+        :type timeout: int
         """
         client = self.get_conn()
         bucket = client.bucket(bucket_name)
-        blob = bucket.blob(blob_name=object_name)
+        blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)
         if filename and data:
             raise ValueError(
                 "'filename' and 'data' parameter provided. Please "
@@ -398,7 +416,7 @@ class GCSHook(GoogleBaseHook):
                         shutil.copyfileobj(f_in, f_out)
                         filename = filename_gz
 
-            blob.upload_from_filename(filename=filename, content_type=mime_type)
+            blob.upload_from_filename(filename=filename, content_type=mime_type, timeout=timeout)
             if gzip:
                 os.remove(filename)
             self.log.info('File %s uploaded to %s in %s bucket', filename, object_name, bucket_name)
@@ -412,7 +430,7 @@ class GCSHook(GoogleBaseHook):
                 with gz.GzipFile(fileobj=out, mode="w") as f:
                     f.write(data)
                 data = out.getvalue()
-            blob.upload_from_string(data, content_type=mime_type)
+            blob.upload_from_string(data, content_type=mime_type, timeout=timeout)
             self.log.info('Data stream uploaded to %s in %s bucket', object_name, bucket_name)
         else:
             raise ValueError("'filename' and 'data' parameter missing. One is required to upload to gcs.")
diff --git a/setup.py b/setup.py
index ae18e57..3df9e47 100644
--- a/setup.py
+++ b/setup.py
@@ -301,7 +301,7 @@ google = [
     'google-cloud-secret-manager>=0.2.0,<2.0.0',
     'google-cloud-spanner>=1.10.0,<2.0.0',
     'google-cloud-speech>=0.36.3,<2.0.0',
-    'google-cloud-storage>=1.16,<2.0.0',
+    'google-cloud-storage>=1.30,<2.0.0',
     'google-cloud-tasks>=1.2.1,<2.0.0',
     'google-cloud-texttospeech>=0.4.0,<2.0.0',
     'google-cloud-translate>=1.5.0,<2.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py
index dffe5ad..1ce44bb 100644
--- a/tests/providers/google/cloud/hooks/test_gcs.py
+++ b/tests/providers/google/cloud/hooks/test_gcs.py
@@ -672,7 +672,7 @@ class TestGCSHook(unittest.TestCase):
         )
 
         self.assertEqual(response, test_file)
-        download_filename_method.assert_called_once_with(test_file)
+        download_filename_method.assert_called_once_with(test_file, timeout=60)
 
     @mock.patch(GCS_STRING.format('NamedTemporaryFile'))
     @mock.patch(GCS_STRING.format('GCSHook.get_conn'))
@@ -697,7 +697,7 @@ class TestGCSHook(unittest.TestCase):
         with self.gcs_hook.provide_file(bucket_name=test_bucket, object_name=test_object) as response:
 
             self.assertEqual(test_file, response.name)
-        download_filename_method.assert_called_once_with(test_file)
+        download_filename_method.assert_called_once_with(test_file, timeout=60)
         mock_temp_file.assert_has_calls(
             [
                 mock.call(suffix='test_object'),
@@ -762,7 +762,7 @@ class TestGCSHookUpload(unittest.TestCase):
         self.gcs_hook.upload(test_bucket, test_object, filename=self.testfile.name)
 
         upload_method.assert_called_once_with(
-            filename=self.testfile.name, content_type='application/octet-stream'
+            filename=self.testfile.name, content_type='application/octet-stream', timeout=60
         )
 
     @mock.patch(GCS_STRING.format('GCSHook.get_conn'))
@@ -782,7 +782,7 @@ class TestGCSHookUpload(unittest.TestCase):
 
         self.gcs_hook.upload(test_bucket, test_object, data=self.testdata_str)
 
-        upload_method.assert_called_once_with(self.testdata_str, content_type='text/plain')
+        upload_method.assert_called_once_with(self.testdata_str, content_type='text/plain', timeout=60)
 
     @mock.patch(GCS_STRING.format('GCSHook.get_conn'))
     def test_upload_data_bytes(self, mock_service):
@@ -793,7 +793,7 @@ class TestGCSHookUpload(unittest.TestCase):
 
         self.gcs_hook.upload(test_bucket, test_object, data=self.testdata_bytes)
 
-        upload_method.assert_called_once_with(self.testdata_bytes, content_type='text/plain')
+        upload_method.assert_called_once_with(self.testdata_bytes, content_type='text/plain', timeout=60)
 
     @mock.patch(GCS_STRING.format('BytesIO'))
     @mock.patch(GCS_STRING.format('gz.GzipFile'))
@@ -812,7 +812,7 @@ class TestGCSHookUpload(unittest.TestCase):
         byte_str = bytes(self.testdata_str, encoding)
         mock_gzip.assert_called_once_with(fileobj=mock_bytes_io.return_value, mode="w")
         gzip_ctx.write.assert_called_once_with(byte_str)
-        upload_method.assert_called_once_with(data, content_type='text/plain')
+        upload_method.assert_called_once_with(data, content_type='text/plain', timeout=60)
 
     @mock.patch(GCS_STRING.format('BytesIO'))
     @mock.patch(GCS_STRING.format('gz.GzipFile'))
@@ -829,7 +829,7 @@ class TestGCSHookUpload(unittest.TestCase):
 
         mock_gzip.assert_called_once_with(fileobj=mock_bytes_io.return_value, mode="w")
         gzip_ctx.write.assert_called_once_with(self.testdata_bytes)
-        upload_method.assert_called_once_with(data, content_type='text/plain')
+        upload_method.assert_called_once_with(data, content_type='text/plain', timeout=60)
 
     @mock.patch(GCS_STRING.format('GCSHook.get_conn'))
     def test_upload_exceptions(self, mock_service):

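A minimal usage sketch of the updated hook methods, assuming a configured "google_cloud_default" connection and a hypothetical bucket and object; the new chunk_size and timeout arguments are simply forwarded to the underlying google-cloud-storage blob calls:

    from airflow.providers.google.cloud.hooks.gcs import GCSHook

    hook = GCSHook(gcp_conn_id="google_cloud_default")

    # Upload with 5 MiB blob chunks and a 120 s per-request timeout
    hook.upload(
        bucket_name="my-bucket",          # hypothetical bucket
        object_name="data/report.csv",    # hypothetical object
        filename="/tmp/report.csv",
        chunk_size=5 * 1024 * 1024,
        timeout=120,
    )

    # Download with the same timeout; omitting it keeps the 60 s default
    hook.download(
        object_name="data/report.csv",
        bucket_name="my-bucket",
        filename="/tmp/report_copy.csv",
        timeout=120,
    )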

[airflow] 28/28: Fix failing docs build on Master (#14465)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 38b3548c25f82a952e3c3729ab86ddd840a2e6ab
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Thu Feb 25 18:56:55 2021 +0000

    Fix failing docs build on Master (#14465)
    
    https://github.com/apache/airflow/pull/14030 caused this issue
    (cherry picked from commit 4455f14732c207ec213703b8b8c68efeb8b6aebe)
---
 docs/apache-airflow-providers-tableau/index.rst | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/docs/apache-airflow-providers-tableau/index.rst b/docs/apache-airflow-providers-tableau/index.rst
index 47ace94..ce74925 100644
--- a/docs/apache-airflow-providers-tableau/index.rst
+++ b/docs/apache-airflow-providers-tableau/index.rst
@@ -24,12 +24,6 @@ Content
 
 .. toctree::
     :maxdepth: 1
-    :caption: Guides
-
-    Connection types <connections/tableau>
-
-.. toctree::
-    :maxdepth: 1
     :caption: References
 
     Python API <_api/airflow/providers/tableau/index>


[airflow] 23/28: Limits Sphinx to <3.5.0 (#14238)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 6118de45cbddd910e108267b24a5371f07c21b01
Author: Jarek Potiuk <ja...@potiuk.com>
AuthorDate: Mon Feb 15 14:05:49 2021 +0100

    Limits Sphinx to <3.5.0 (#14238)
    
    Sphinx 3.5.0 released on 14th of Feb introduced a problem in our
    doc builds.
    
    It is documented in https://github.com/sphinx-doc/sphinx/issues/8880
    
    Until this problem is solved we are limiting Sphinx.
    
    (cherry picked from commit da80b69812b12377efddf5ad9763ee09f89a9f31)
---
 setup.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index a752d82..92eb113 100644
--- a/setup.py
+++ b/setup.py
@@ -245,7 +245,8 @@ datadog = [
     'datadog>=0.14.0',
 ]
 doc = [
-    'sphinx>=2.1.2',
+    # Sphinx is limited to < 3.5.0 because of https://github.com/sphinx-doc/sphinx/issues/8880
+    'sphinx>=2.1.2, <3.5.0',
     f'sphinx-airflow-theme{get_sphinx_theme_version()}',
     'sphinx-argparse>=0.1.13',
     'sphinx-autoapi==1.0.0',


[airflow] 03/28: Fix grammar in production-deployment.rst (#14386)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 6e741499fa7a03d71ab92300869b7ff96453689b
Author: Jon Quinn <jo...@gmail.com>
AuthorDate: Tue Feb 23 13:31:38 2021 +0000

    Fix grammar in production-deployment.rst (#14386)
    
    (cherry picked from commit 4fb943c21425f055e555a95ef9e4f7ba4690ee8b)
---
 docs/apache-airflow/production-deployment.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/apache-airflow/production-deployment.rst b/docs/apache-airflow/production-deployment.rst
index 439afe1..042b655 100644
--- a/docs/apache-airflow/production-deployment.rst
+++ b/docs/apache-airflow/production-deployment.rst
@@ -56,9 +56,9 @@ Once that is done, you can run -
 Multi-Node Cluster
 ==================
 
-Airflow uses :class:`~airflow.executors.sequential_executor.SequentialExecutor` by default. However, by it
+Airflow uses :class:`~airflow.executors.sequential_executor.SequentialExecutor` by default. However, by its
 nature, the user is limited to executing at most one task at a time. ``Sequential Executor`` also pauses
-the scheduler when it runs a task, hence not recommended in a production setup. You should use the
+the scheduler when it runs a task, hence it is not recommended in a production setup. You should use the
 :class:`~airflow.executors.local_executor.LocalExecutor` for a single machine.
 For a multi-node setup, you should use the :doc:`Kubernetes executor <../executor/kubernetes>` or
 the :doc:`Celery executor <../executor/celery>`.

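As a quick check of which executor a deployment actually picked up after following the advice above, a small sketch (assuming a standard Airflow installation; the executor is switched via the [core] executor option in airflow.cfg or the AIRFLOW__CORE__EXECUTOR environment variable):

    from airflow.configuration import conf

    # Prints e.g. "SequentialExecutor" on a fresh install, or "LocalExecutor" /
    # "CeleryExecutor" once the configuration has been changed as described above.
    print(conf.get("core", "executor"))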

[airflow] 06/28: Add Google Cloud Workflows Operators (#13366)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 35679d27660e007a17f6c7faed4860b25df07d63
Author: Tomek Urbaszek <tu...@gmail.com>
AuthorDate: Thu Jan 28 20:35:09 2021 +0100

    Add Google Cloud Workflows Operators (#13366)
    
    Add Google Cloud Workflows Operators, system test, example and sensor
    
    Co-authored-by: Tobiasz Kędzierski <to...@polidea.com>
    (cherry picked from commit 6d6588fe2b8bb5fa33e930646d963df3e0530f23)
---
 .../google/cloud/example_dags/example_workflows.py | 197 ++++++
 airflow/providers/google/cloud/hooks/workflows.py  | 401 ++++++++++++
 .../providers/google/cloud/operators/workflows.py  | 714 +++++++++++++++++++++
 .../providers/google/cloud/sensors/workflows.py    | 123 ++++
 airflow/providers/google/provider.yaml             |  14 +
 .../operators/cloud/workflows.rst                  | 185 ++++++
 setup.py                                           |   2 +
 .../providers/google/cloud/hooks/test_workflows.py | 256 ++++++++
 .../google/cloud/operators/test_workflows.py       | 383 +++++++++++
 .../cloud/operators/test_workflows_system.py       |  29 +
 .../google/cloud/sensors/test_workflows.py         | 108 ++++
 .../google/cloud/utils/gcp_authenticator.py        |   1 +
 12 files changed, 2413 insertions(+)

diff --git a/airflow/providers/google/cloud/example_dags/example_workflows.py b/airflow/providers/google/cloud/example_dags/example_workflows.py
new file mode 100644
index 0000000..0fab435
--- /dev/null
+++ b/airflow/providers/google/cloud/example_dags/example_workflows.py
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+
+from airflow import DAG
+from airflow.providers.google.cloud.operators.workflows import (
+    WorkflowsCancelExecutionOperator,
+    WorkflowsCreateExecutionOperator,
+    WorkflowsCreateWorkflowOperator,
+    WorkflowsDeleteWorkflowOperator,
+    WorkflowsGetExecutionOperator,
+    WorkflowsGetWorkflowOperator,
+    WorkflowsListExecutionsOperator,
+    WorkflowsListWorkflowsOperator,
+    WorkflowsUpdateWorkflowOperator,
+)
+from airflow.providers.google.cloud.sensors.workflows import WorkflowExecutionSensor
+from airflow.utils.dates import days_ago
+
+LOCATION = os.environ.get("GCP_WORKFLOWS_LOCATION", "us-central1")
+PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id")
+
+WORKFLOW_ID = os.getenv("GCP_WORKFLOWS_WORKFLOW_ID", "airflow-test-workflow")
+
+# [START how_to_define_workflow]
+WORKFLOW_CONTENT = """
+- getCurrentTime:
+    call: http.get
+    args:
+        url: https://us-central1-workflowsample.cloudfunctions.net/datetime
+    result: currentTime
+- readWikipedia:
+    call: http.get
+    args:
+        url: https://en.wikipedia.org/w/api.php
+        query:
+            action: opensearch
+            search: ${currentTime.body.dayOfTheWeek}
+    result: wikiResult
+- returnResult:
+    return: ${wikiResult.body[1]}
+"""
+
+WORKFLOW = {
+    "description": "Test workflow",
+    "labels": {"airflow-version": "dev"},
+    "source_contents": WORKFLOW_CONTENT,
+}
+# [END how_to_define_workflow]
+
+EXECUTION = {"argument": ""}
+
+SLEEP_WORKFLOW_ID = os.getenv("GCP_WORKFLOWS_SLEEP_WORKFLOW_ID", "sleep_workflow")
+SLEEP_WORKFLOW_CONTENT = """
+- someSleep:
+    call: sys.sleep
+    args:
+        seconds: 120
+"""
+
+SLEEP_WORKFLOW = {
+    "description": "Test workflow",
+    "labels": {"airflow-version": "dev"},
+    "source_contents": SLEEP_WORKFLOW_CONTENT,
+}
+
+
+with DAG("example_cloud_workflows", start_date=days_ago(1), schedule_interval=None) as dag:
+    # [START how_to_create_workflow]
+    create_workflow = WorkflowsCreateWorkflowOperator(
+        task_id="create_workflow",
+        location=LOCATION,
+        project_id=PROJECT_ID,
+        workflow=WORKFLOW,
+        workflow_id=WORKFLOW_ID,
+    )
+    # [END how_to_create_workflow]
+
+    # [START how_to_update_workflow]
+    update_workflows = WorkflowsUpdateWorkflowOperator(
+        task_id="update_workflows",
+        location=LOCATION,
+        project_id=PROJECT_ID,
+        workflow_id=WORKFLOW_ID,
+        update_mask={"paths": ["name", "description"]},
+    )
+    # [END how_to_update_workflow]
+
+    # [START how_to_get_workflow]
+    get_workflow = WorkflowsGetWorkflowOperator(
+        task_id="get_workflow", location=LOCATION, project_id=PROJECT_ID, workflow_id=WORKFLOW_ID
+    )
+    # [END how_to_get_workflow]
+
+    # [START how_to_list_workflows]
+    list_workflows = WorkflowsListWorkflowsOperator(
+        task_id="list_workflows",
+        location=LOCATION,
+        project_id=PROJECT_ID,
+    )
+    # [END how_to_list_workflows]
+
+    # [START how_to_delete_workflow]
+    delete_workflow = WorkflowsDeleteWorkflowOperator(
+        task_id="delete_workflow", location=LOCATION, project_id=PROJECT_ID, workflow_id=WORKFLOW_ID
+    )
+    # [END how_to_delete_workflow]
+
+    # [START how_to_create_execution]
+    create_execution = WorkflowsCreateExecutionOperator(
+        task_id="create_execution",
+        location=LOCATION,
+        project_id=PROJECT_ID,
+        execution=EXECUTION,
+        workflow_id=WORKFLOW_ID,
+    )
+    # [END how_to_create_execution]
+
+    # [START how_to_wait_for_execution]
+    wait_for_execution = WorkflowExecutionSensor(
+        task_id="wait_for_execution",
+        location=LOCATION,
+        project_id=PROJECT_ID,
+        workflow_id=WORKFLOW_ID,
+        execution_id='{{ task_instance.xcom_pull("create_execution", key="execution_id") }}',
+    )
+    # [END how_to_wait_for_execution]
+
+    # [START how_to_get_execution]
+    get_execution = WorkflowsGetExecutionOperator(
+        task_id="get_execution",
+        location=LOCATION,
+        project_id=PROJECT_ID,
+        workflow_id=WORKFLOW_ID,
+        execution_id='{{ task_instance.xcom_pull("create_execution", key="execution_id") }}',
+    )
+    # [END how_to_get_execution]
+
+    # [START how_to_list_executions]
+    list_executions = WorkflowsListExecutionsOperator(
+        task_id="list_executions", location=LOCATION, project_id=PROJECT_ID, workflow_id=WORKFLOW_ID
+    )
+    # [END how_to_list_executions]
+
+    create_workflow_for_cancel = WorkflowsCreateWorkflowOperator(
+        task_id="create_workflow_for_cancel",
+        location=LOCATION,
+        project_id=PROJECT_ID,
+        workflow=SLEEP_WORKFLOW,
+        workflow_id=SLEEP_WORKFLOW_ID,
+    )
+
+    create_execution_for_cancel = WorkflowsCreateExecutionOperator(
+        task_id="create_execution_for_cancel",
+        location=LOCATION,
+        project_id=PROJECT_ID,
+        execution=EXECUTION,
+        workflow_id=SLEEP_WORKFLOW_ID,
+    )
+
+    # [START how_to_cancel_execution]
+    cancel_execution = WorkflowsCancelExecutionOperator(
+        task_id="cancel_execution",
+        location=LOCATION,
+        project_id=PROJECT_ID,
+        workflow_id=SLEEP_WORKFLOW_ID,
+        execution_id='{{ task_instance.xcom_pull("create_execution_for_cancel", key="execution_id") }}',
+    )
+    # [END how_to_cancel_execution]
+
+    create_workflow >> update_workflows >> [get_workflow, list_workflows]
+    update_workflows >> [create_execution, create_execution_for_cancel]
+
+    create_execution >> wait_for_execution >> [get_execution, list_executions]
+    create_workflow_for_cancel >> create_execution_for_cancel >> cancel_execution
+
+    [cancel_execution, list_executions] >> delete_workflow
+
+
+if __name__ == '__main__':
+    dag.clear(dag_run_state=None)
+    dag.run()
diff --git a/airflow/providers/google/cloud/hooks/workflows.py b/airflow/providers/google/cloud/hooks/workflows.py
new file mode 100644
index 0000000..6c78350
--- /dev/null
+++ b/airflow/providers/google/cloud/hooks/workflows.py
@@ -0,0 +1,401 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Dict, Optional, Sequence, Tuple, Union
+
+from google.api_core.operation import Operation
+from google.api_core.retry import Retry
+
+# pylint: disable=no-name-in-module
+from google.cloud.workflows.executions_v1beta import Execution, ExecutionsClient
+from google.cloud.workflows.executions_v1beta.services.executions.pagers import ListExecutionsPager
+from google.cloud.workflows_v1beta import Workflow, WorkflowsClient
+from google.cloud.workflows_v1beta.services.workflows.pagers import ListWorkflowsPager
+from google.protobuf.field_mask_pb2 import FieldMask
+
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+
+# pylint: enable=no-name-in-module
+
+
+class WorkflowsHook(GoogleBaseHook):
+    """
+    Hook for the Google Cloud Workflows APIs.
+
+    All the methods in the hook that use project_id must be called with
+    keyword arguments rather than positional ones.
+    """
+
+    def get_workflows_client(self) -> WorkflowsClient:
+        """Returns WorkflowsClient."""
+        return WorkflowsClient(credentials=self._get_credentials(), client_info=self.client_info)
+
+    def get_executions_client(self) -> ExecutionsClient:
+        """Returns ExecutionsClient."""
+        return ExecutionsClient(credentials=self._get_credentials(), client_info=self.client_info)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_workflow(
+        self,
+        workflow: Dict,
+        workflow_id: str,
+        location: str,
+        project_id: str,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> Operation:
+        """
+        Creates a new workflow. If a workflow with the specified name
+        already exists in the specified project and location, the long
+        running operation will return
+        [ALREADY_EXISTS][google.rpc.Code.ALREADY_EXISTS] error.
+
+        :param workflow: Required. Workflow to be created.
+        :type workflow: Dict
+        :param workflow_id: Required. The ID of the workflow to be created.
+        :type workflow_id: str
+        :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+        :type project_id: str
+        :param location: Required. The GCP region in which to handle the request.
+        :type location: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        metadata = metadata or ()
+        client = self.get_workflows_client()
+        parent = f"projects/{project_id}/locations/{location}"
+        return client.create_workflow(
+            request={"parent": parent, "workflow": workflow, "workflow_id": workflow_id},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_workflow(
+        self,
+        workflow_id: str,
+        location: str,
+        project_id: str,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> Workflow:
+        """
+        Gets details of a single Workflow.
+
+        :param workflow_id: Required. The ID of the workflow to be retrieved.
+        :type workflow_id: str
+        :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+        :type project_id: str
+        :param location: Required. The GCP region in which to handle the request.
+        :type location: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        metadata = metadata or ()
+        client = self.get_workflows_client()
+        name = f"projects/{project_id}/locations/{location}/workflows/{workflow_id}"
+        return client.get_workflow(request={"name": name}, retry=retry, timeout=timeout, metadata=metadata)
+
+    def update_workflow(
+        self,
+        workflow: Union[Dict, Workflow],
+        update_mask: Optional[FieldMask] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> Operation:
+        """
+        Updates an existing workflow.
+        Running this method has no impact on already running
+        executions of the workflow. A new revision of the
+        workflow may be created as a result of a successful
+        update operation. In that case, the new revision will be
+        used in new workflow executions.
+
+        :param workflow: Required. Workflow to be updated.
+        :type workflow: Dict
+        :param update_mask: List of fields to be updated. If not present,
+            the entire workflow will be updated.
+        :type update_mask: FieldMask
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        metadata = metadata or ()
+        client = self.get_workflows_client()
+        return client.update_workflow(
+            request={"workflow": workflow, "update_mask": update_mask},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_workflow(
+        self,
+        workflow_id: str,
+        location: str,
+        project_id: str,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> Operation:
+        """
+        Deletes a workflow with the specified name.
+        This method also cancels and deletes all running
+        executions of the workflow.
+
+        :param workflow_id: Required. The ID of the workflow to be deleted.
+        :type workflow_id: str
+        :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+        :type project_id: str
+        :param location: Required. The GCP region in which to handle the request.
+        :type location: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        metadata = metadata or ()
+        client = self.get_workflows_client()
+        name = f"projects/{project_id}/locations/{location}/workflows/{workflow_id}"
+        return client.delete_workflow(request={"name": name}, retry=retry, timeout=timeout, metadata=metadata)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def list_workflows(
+        self,
+        location: str,
+        project_id: str,
+        filter_: Optional[str] = None,
+        order_by: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> ListWorkflowsPager:
+        """
+        Lists Workflows in a given project and location.
+        The default order is not specified.
+
+        :param filter_: Filter to restrict results to specific workflows.
+        :type filter_: str
+        :param order_by: Comma-separated list of fields that
+            specify the order of the results. Default sorting order for a field is ascending.
+            To specify descending order for a field, append a "desc" suffix.
+            If not specified, the results will be returned in an unspecified order.
+        :type order_by: str
+        :param project_id: Required. The ID of the Google Cloud project the workflows belong to.
+        :type project_id: str
+        :param location: Required. The GCP region in which to handle the request.
+        :type location: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        metadata = metadata or ()
+        client = self.get_workflows_client()
+        parent = f"projects/{project_id}/locations/{location}"
+
+        return client.list_workflows(
+            request={"parent": parent, "filter": filter_, "order_by": order_by},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_execution(
+        self,
+        workflow_id: str,
+        location: str,
+        project_id: str,
+        execution: Dict,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> Execution:
+        """
+        Creates a new execution using the latest revision of
+        the given workflow.
+
+        :param execution: Required. Input parameters of the execution represented as a dictionary.
+        :type execution: Dict
+        :param workflow_id: Required. The ID of the workflow.
+        :type workflow_id: str
+        :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+        :type project_id: str
+        :param location: Required. The GCP region in which to handle the request.
+        :type location: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        metadata = metadata or ()
+        client = self.get_executions_client()
+        parent = f"projects/{project_id}/locations/{location}/workflows/{workflow_id}"
+        return client.create_execution(
+            request={"parent": parent, "execution": execution},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_execution(
+        self,
+        workflow_id: str,
+        execution_id: str,
+        location: str,
+        project_id: str,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> Execution:
+        """
+        Returns an execution for the given ``workflow_id`` and ``execution_id``.
+
+        :param workflow_id: Required. The ID of the workflow.
+        :type workflow_id: str
+        :param execution_id: Required. The ID of the execution.
+        :type execution_id: str
+        :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+        :type project_id: str
+        :param location: Required. The GCP region in which to handle the request.
+        :type location: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        metadata = metadata or ()
+        client = self.get_executions_client()
+        name = f"projects/{project_id}/locations/{location}/workflows/{workflow_id}/executions/{execution_id}"
+        return client.get_execution(request={"name": name}, retry=retry, timeout=timeout, metadata=metadata)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def cancel_execution(
+        self,
+        workflow_id: str,
+        execution_id: str,
+        location: str,
+        project_id: str,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> Execution:
+        """
+        Cancels an execution using the given ``workflow_id`` and ``execution_id``.
+
+        :param workflow_id: Required. The ID of the workflow.
+        :type workflow_id: str
+        :param execution_id: Required. The ID of the execution.
+        :type execution_id: str
+        :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+        :type project_id: str
+        :param location: Required. The GCP region in which to handle the request.
+        :type location: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        metadata = metadata or ()
+        client = self.get_executions_client()
+        name = f"projects/{project_id}/locations/{location}/workflows/{workflow_id}/executions/{execution_id}"
+        return client.cancel_execution(
+            request={"name": name}, retry=retry, timeout=timeout, metadata=metadata
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def list_executions(
+        self,
+        workflow_id: str,
+        location: str,
+        project_id: str,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ) -> ListExecutionsPager:
+        """
+        Returns a list of executions which belong to the
+        workflow with the given name. The method returns
+        executions of all workflow revisions. Returned
+        executions are ordered by their start time (newest
+        first).
+
+        :param workflow_id: Required. The ID of the workflow whose executions should be listed.
+        :type workflow_id: str
+        :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+        :type project_id: str
+        :param location: Required. The GCP region in which to handle the request.
+        :type location: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        metadata = metadata or ()
+        client = self.get_executions_client()
+        parent = f"projects/{project_id}/locations/{location}/workflows/{workflow_id}"
+        return client.list_executions(
+            request={"parent": parent}, retry=retry, timeout=timeout, metadata=metadata
+        )
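A minimal usage sketch of the hook defined above; the project, location, IDs and inline workflow body are hypothetical, and, as the class docstring notes, methods that take project_id are called with keyword arguments:

    from airflow.providers.google.cloud.hooks.workflows import WorkflowsHook

    hook = WorkflowsHook(gcp_conn_id="google_cloud_default")

    # create_workflow returns a long-running Operation; result() blocks until done
    operation = hook.create_workflow(
        workflow={"source_contents": "- step:\n    return: 'hello'"},
        workflow_id="my-workflow",
        location="us-central1",
        project_id="my-project",
    )
    workflow = operation.result()

    # Start an execution of the newly created workflow
    execution = hook.create_execution(
        workflow_id="my-workflow",
        location="us-central1",
        project_id="my-project",
        execution={"argument": ""},
    )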
diff --git a/airflow/providers/google/cloud/operators/workflows.py b/airflow/providers/google/cloud/operators/workflows.py
new file mode 100644
index 0000000..c7fc96d
--- /dev/null
+++ b/airflow/providers/google/cloud/operators/workflows.py
@@ -0,0 +1,714 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import hashlib
+import json
+import re
+import uuid
+from datetime import datetime, timedelta
+from typing import Dict, Optional, Sequence, Tuple, Union
+
+import pytz
+from google.api_core.exceptions import AlreadyExists
+from google.api_core.retry import Retry
+
+# pylint: disable=no-name-in-module
+from google.cloud.workflows.executions_v1beta import Execution
+from google.cloud.workflows_v1beta import Workflow
+
+# pylint: enable=no-name-in-module
+from google.protobuf.field_mask_pb2 import FieldMask
+
+from airflow.models import BaseOperator
+from airflow.providers.google.cloud.hooks.workflows import WorkflowsHook
+
+
+class WorkflowsCreateWorkflowOperator(BaseOperator):
+    """
+    Creates a new workflow. If a workflow with the specified name
+    already exists in the specified project and location, the long
+    running operation will return
+    [ALREADY_EXISTS][google.rpc.Code.ALREADY_EXISTS] error.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:WorkflowsCreateWorkflowOperator`
+
+    :param workflow: Required. Workflow to be created.
+    :type workflow: Dict
+    :param workflow_id: Required. The ID of the workflow to be created.
+    :type workflow_id: str
+    :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+    :type project_id: str
+    :param location: Required. The GCP region in which to handle the request.
+    :type location: str
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    """
+
+    template_fields = ("location", "workflow", "workflow_id")
+    template_fields_renderers = {"workflow": "json"}
+
+    def __init__(
+        self,
+        *,
+        workflow: Dict,
+        workflow_id: str,
+        location: str,
+        project_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        force_rerun: bool = False,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.workflow = workflow
+        self.workflow_id = workflow_id
+        self.location = location
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.force_rerun = force_rerun
+
+    def _workflow_id(self, context):
+        if self.workflow_id and not self.force_rerun:
+            # If the user provides a workflow_id, ensuring idempotency
+            # is their responsibility
+            return self.workflow_id
+
+        if self.force_rerun:
+            hash_base = str(uuid.uuid4())
+        else:
+            hash_base = json.dumps(self.workflow, sort_keys=True)
+
+        # The allowed length of workflow_id is limited, so we use a hash
+        # of the full identifying information instead
+        exec_date = context['execution_date'].isoformat()
+        base = f"airflow_{self.dag_id}_{self.task_id}_{exec_date}_{hash_base}"
+        workflow_id = hashlib.md5(base.encode()).hexdigest()
+        return re.sub(r"[:\-+.]", "_", workflow_id)
+
+    def execute(self, context):
+        hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        workflow_id = self._workflow_id(context)
+
+        self.log.info("Creating workflow")
+        try:
+            operation = hook.create_workflow(
+                workflow=self.workflow,
+                workflow_id=workflow_id,
+                location=self.location,
+                project_id=self.project_id,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            )
+            workflow = operation.result()
+        except AlreadyExists:
+            workflow = hook.get_workflow(
+                workflow_id=workflow_id,
+                location=self.location,
+                project_id=self.project_id,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            )
+        return Workflow.to_dict(workflow)
+
+
+class WorkflowsUpdateWorkflowOperator(BaseOperator):
+    """
+    Updates an existing workflow.
+    Running this method has no impact on already running
+    executions of the workflow. A new revision of the
+    workflow may be created as a result of a successful
+    update operation. In that case, the new revision will be
+    used in new workflow executions.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:WorkflowsUpdateWorkflowOperator`
+
+    :param workflow_id: Required. The ID of the workflow to be updated.
+    :type workflow_id: str
+    :param location: Required. The GCP region in which to handle the request.
+    :type location: str
+    :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+    :type project_id: str
+    :param update_mask: List of fields to be updated. If not present,
+        the entire workflow will be updated.
+    :type update_mask: FieldMask
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    """
+
+    template_fields = ("workflow_id", "update_mask")
+    template_fields_renderers = {"update_mask": "json"}
+
+    def __init__(
+        self,
+        *,
+        workflow_id: str,
+        location: str,
+        project_id: Optional[str] = None,
+        update_mask: Optional[FieldMask] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.workflow_id = workflow_id
+        self.location = location
+        self.project_id = project_id
+        self.update_mask = update_mask
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+
+        workflow = hook.get_workflow(
+            workflow_id=self.workflow_id,
+            project_id=self.project_id,
+            location=self.location,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        self.log.info("Updating workflow")
+        operation = hook.update_workflow(
+            workflow=workflow,
+            update_mask=self.update_mask,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        workflow = operation.result()
+        return Workflow.to_dict(workflow)
+
+
+class WorkflowsDeleteWorkflowOperator(BaseOperator):
+    """
+    Deletes a workflow with the specified name.
+    This method also cancels and deletes all running
+    executions of the workflow.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:WorkflowsDeleteWorkflowOperator`
+
+    :param workflow_id: Required. The ID of the workflow to be deleted.
+    :type workflow_id: str
+    :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+    :type project_id: str
+    :param location: Required. The GCP region in which to handle the request.
+    :type location: str
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    """
+
+    template_fields = ("location", "workflow_id")
+
+    def __init__(
+        self,
+        *,
+        workflow_id: str,
+        location: str,
+        project_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.workflow_id = workflow_id
+        self.location = location
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        self.log.info("Deleting workflow %s", self.workflow_id)
+        operation = hook.delete_workflow(
+            workflow_id=self.workflow_id,
+            location=self.location,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        operation.result()
+
+
+class WorkflowsListWorkflowsOperator(BaseOperator):
+    """
+    Lists Workflows in a given project and location.
+    The default order is not specified.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:WorkflowsListWorkflowsOperator`
+
+    :param filter_: Filter to restrict results to specific workflows.
+    :type filter_: str
+    :param order_by: Comma-separated list of fields that
+        specify the order of the results. Default sorting order for a field is ascending.
+        To specify descending order for a field, append a "desc" suffix.
+        If not specified, the results will be returned in an unspecified order.
+    :type order_by: str
+    :param project_id: Required. The ID of the Google Cloud project the workflows belong to.
+    :type project_id: str
+    :param location: Required. The GCP region in which to handle the request.
+    :type location: str
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    """
+
+    template_fields = ("location", "order_by", "filter_")
+
+    def __init__(
+        self,
+        *,
+        location: str,
+        project_id: Optional[str] = None,
+        filter_: Optional[str] = None,
+        order_by: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.filter_ = filter_
+        self.order_by = order_by
+        self.location = location
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        self.log.info("Retrieving workflows")
+        workflows_iter = hook.list_workflows(
+            filter_=self.filter_,
+            order_by=self.order_by,
+            location=self.location,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        return [Workflow.to_dict(w) for w in workflows_iter]
+
+
+class WorkflowsGetWorkflowOperator(BaseOperator):
+    """
+    Gets details of a single Workflow.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:WorkflowsGetWorkflowOperator`
+
+    :param workflow_id: Required. The ID of the workflow to be retrieved.
+    :type workflow_id: str
+    :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+    :type project_id: str
+    :param location: Required. The GCP region in which to handle the request.
+    :type location: str
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    """
+
+    template_fields = ("location", "workflow_id")
+
+    def __init__(
+        self,
+        *,
+        workflow_id: str,
+        location: str,
+        project_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.workflow_id = workflow_id
+        self.location = location
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        self.log.info("Retrieving workflow")
+        workflow = hook.get_workflow(
+            workflow_id=self.workflow_id,
+            location=self.location,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        return Workflow.to_dict(workflow)
+
+
+class WorkflowsCreateExecutionOperator(BaseOperator):
+    """
+    Creates a new execution using the latest revision of
+    the given workflow.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:WorkflowsCreateExecutionOperator`
+
+    :param execution: Required. Execution to be created.
+    :type execution: Dict
+    :param workflow_id: Required. The ID of the workflow.
+    :type workflow_id: str
+    :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+    :type project_id: str
+    :param location: Required. The GCP region in which to handle the request.
+    :type location: str
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    """
+
+    template_fields = ("location", "workflow_id", "execution")
+    template_fields_renderers = {"execution": "json"}
+
+    def __init__(
+        self,
+        *,
+        workflow_id: str,
+        execution: Dict,
+        location: str,
+        project_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.workflow_id = workflow_id
+        self.execution = execution
+        self.location = location
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        self.log.info("Creating execution")
+        execution = hook.create_execution(
+            workflow_id=self.workflow_id,
+            execution=self.execution,
+            location=self.location,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
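+        # The execution name has the form
+        # projects/<project>/locations/<location>/workflows/<workflow>/executions/<id>;
+        # push the trailing id to XCom so downstream tasks (e.g. a sensor) can reference it.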
+        execution_id = execution.name.split("/")[-1]
+        self.xcom_push(context, key="execution_id", value=execution_id)
+        return Execution.to_dict(execution)
+
+
+class WorkflowsCancelExecutionOperator(BaseOperator):
+    """
+    Cancels an execution using the given ``workflow_id`` and ``execution_id``.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:WorkflowsCancelExecutionOperator`
+
+    :param workflow_id: Required. The ID of the workflow.
+    :type workflow_id: str
+    :param execution_id: Required. The ID of the execution.
+    :type execution_id: str
+    :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+    :type project_id: str
+    :param location: Required. The GCP region in which to handle the request.
+    :type location: str
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    """
+
+    template_fields = ("location", "workflow_id", "execution_id")
+
+    def __init__(
+        self,
+        *,
+        workflow_id: str,
+        execution_id: str,
+        location: str,
+        project_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.workflow_id = workflow_id
+        self.execution_id = execution_id
+        self.location = location
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        self.log.info("Canceling execution %s", self.execution_id)
+        execution = hook.cancel_execution(
+            workflow_id=self.workflow_id,
+            execution_id=self.execution_id,
+            location=self.location,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        return Execution.to_dict(execution)
+
+
+class WorkflowsListExecutionsOperator(BaseOperator):
+    """
+    Returns a list of executions which belong to the
+    workflow with the given name. The method returns
+    executions of all workflow revisions. Returned
+    executions are ordered by their start time (newest
+    first).
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:WorkflowsListExecutionsOperator`
+
+    :param workflow_id: Required. The ID of the workflow whose executions should be listed.
+    :type workflow_id: str
+    :param start_date_filter: If passed, only executions started after this date will be returned.
+        By default, the operator returns executions from the last 60 minutes.
+    :type start_date_filter: datetime
+    :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+    :type project_id: str
+    :param location: Required. The GCP region in which to handle the request.
+    :type location: str
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    """
+
+    template_fields = ("location", "workflow_id")
+
+    def __init__(
+        self,
+        *,
+        workflow_id: str,
+        location: str,
+        start_date_filter: Optional[datetime] = None,
+        project_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.workflow_id = workflow_id
+        self.location = location
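+        # When no filter is given, default to executions started in the last 60 minutes.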
+        self.start_date_filter = start_date_filter or datetime.now(tz=pytz.UTC) - timedelta(minutes=60)
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        self.log.info("Retrieving executions for workflow %s", self.workflow_id)
+        execution_iter = hook.list_executions(
+            workflow_id=self.workflow_id,
+            location=self.location,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
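+        # Filtering by start time happens client-side: keep only executions that
+        # started after start_date_filter.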
+        return [Execution.to_dict(e) for e in execution_iter if e.start_time > self.start_date_filter]
+
+
+class WorkflowsGetExecutionOperator(BaseOperator):
+    """
+    Returns an execution for the given ``workflow_id`` and ``execution_id``.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:WorkflowsGetExecutionOperator`
+
+    :param workflow_id: Required. The ID of the workflow.
+    :type workflow_id: str
+    :param execution_id: Required. The ID of the execution.
+    :type execution_id: str
+    :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+    :type project_id: str
+    :param location: Required. The GCP region in which to handle the request.
+    :type location: str
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    """
+
+    template_fields = ("location", "workflow_id", "execution_id")
+
+    def __init__(
+        self,
+        *,
+        workflow_id: str,
+        execution_id: str,
+        location: str,
+        project_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.workflow_id = workflow_id
+        self.execution_id = execution_id
+        self.location = location
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context):
+        hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        self.log.info("Retrieving execution %s for workflow %s", self.execution_id, self.workflow_id)
+        execution = hook.get_execution(
+            workflow_id=self.workflow_id,
+            execution_id=self.execution_id,
+            location=self.location,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        return Execution.to_dict(execution)
diff --git a/airflow/providers/google/cloud/sensors/workflows.py b/airflow/providers/google/cloud/sensors/workflows.py
new file mode 100644
index 0000000..5950458
--- /dev/null
+++ b/airflow/providers/google/cloud/sensors/workflows.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Sequence, Set, Tuple, Union
+
+from google.api_core.retry import Retry
+from google.cloud.workflows.executions_v1beta import Execution
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.workflows import WorkflowsHook
+from airflow.sensors.base import BaseSensorOperator
+
+
+class WorkflowExecutionSensor(BaseSensorOperator):
+    """
+    Checks state of an execution for the given ``workflow_id`` and ``execution_id``.
+
+    :param workflow_id: Required. The ID of the workflow.
+    :type workflow_id: str
+    :param execution_id: Required. The ID of the execution.
+    :type execution_id: str
+    :param project_id: Required. The ID of the Google Cloud project the workflow belongs to.
+    :type project_id: str
+    :param location: Required. The GCP region in which to handle the request.
+    :type location: str
+    :param success_states: Execution states to be considered successful; by default
+        only the ``SUCCEEDED`` state is used.
+    :type success_states: Set[Execution.State]
+    :param failure_states: Execution states to be considered failures; by default
+        the ``FAILED`` and ``CANCELLED`` states are used.
+    :type failure_states: Set[Execution.State]
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :type retry: google.api_core.retry.Retry
+    :param request_timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :type request_timeout: float
+    :param metadata: Additional metadata that is provided to the method.
+    :type metadata: Sequence[Tuple[str, str]]
+    """
+
+    template_fields = ("location", "workflow_id", "execution_id")
+
+    def __init__(
+        self,
+        *,
+        workflow_id: str,
+        execution_id: str,
+        location: str,
+        project_id: str,
+        success_states: Optional[Set[Execution.State]] = None,
+        failure_states: Optional[Set[Execution.State]] = None,
+        retry: Optional[Retry] = None,
+        request_timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.success_states = success_states or {Execution.State.SUCCEEDED}
+        self.failure_states = failure_states or {Execution.State.FAILED, Execution.State.CANCELLED}
+        self.workflow_id = workflow_id
+        self.execution_id = execution_id
+        self.location = location
+        self.project_id = project_id
+        self.retry = retry
+        self.request_timeout = request_timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def poke(self, context):
+        hook = WorkflowsHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        self.log.info("Checking state of execution %s for workflow %s", self.execution_id, self.workflow_id)
+        execution: Execution = hook.get_execution(
+            workflow_id=self.workflow_id,
+            execution_id=self.execution_id,
+            location=self.location,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.request_timeout,
+            metadata=self.metadata,
+        )
+
+        state = execution.state
+        if state in self.failure_states:
+            raise AirflowException(
+                f"Execution {self.execution_id} for workflow {self.execution_id} "
+                f"failed and is in `{state}` state",
+            )
+
+        if state in self.success_states:
+            self.log.info(
+                "Execution %s for workflow %s completed with state: %s",
+                self.execution_id,
+                self.workflow_id,
+                state,
+            )
+            return True
+
+        self.log.info(
+            "Execution %s for workflow %s does not completed yet, current state: %s",
+            self.execution_id,
+            self.workflow_id,
+            state,
+        )
+        return False
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 9961b13..690eb00 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -277,6 +277,11 @@ integrations:
       - /docs/apache-airflow-providers-google/operators/cloud/natural_language.rst
     logo: /integration-logos/gcp/Cloud-NLP.png
     tags: [gcp]
+  - integration-name: Google Cloud Workflows
+    external-doc-url: https://cloud.google.com/workflows/
+    how-to-guide:
+      - /docs/apache-airflow-providers-google/operators/cloud/workflows.rst
+    tags: [gcp]
 
 operators:
   - integration-name: Google Ads
@@ -377,6 +382,9 @@ operators:
   - integration-name: Google Cloud Vision
     python-modules:
       - airflow.providers.google.cloud.operators.vision
+  - integration-name: Google Cloud Workflows
+    python-modules:
+      - airflow.providers.google.cloud.operators.workflows
   - integration-name: Google Cloud Firestore
     python-modules:
       - airflow.providers.google.firebase.operators.firestore
@@ -421,6 +429,9 @@ sensors:
   - integration-name: Google Cloud Pub/Sub
     python-modules:
       - airflow.providers.google.cloud.sensors.pubsub
+  - integration-name: Google Cloud Workflows
+    python-modules:
+      - airflow.providers.google.cloud.sensors.workflows
   - integration-name: Google Campaign Manager
     python-modules:
       - airflow.providers.google.marketing_platform.sensors.campaign_manager
@@ -541,6 +552,9 @@ hooks:
   - integration-name: Google Cloud Vision
     python-modules:
       - airflow.providers.google.cloud.hooks.vision
+  - integration-name: Google Cloud Workflows
+    python-modules:
+      - airflow.providers.google.cloud.hooks.workflows
   - integration-name: Google
     python-modules:
       - airflow.providers.google.common.hooks.base_google
diff --git a/docs/apache-airflow-providers-google/operators/cloud/workflows.rst b/docs/apache-airflow-providers-google/operators/cloud/workflows.rst
new file mode 100644
index 0000000..551a7ca
--- /dev/null
+++ b/docs/apache-airflow-providers-google/operators/cloud/workflows.rst
@@ -0,0 +1,185 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+Google Cloud Workflows Operators
+================================
+
+You can use Workflows to create serverless workflows that link a series of serverless tasks together
+in an order you define. Combine the power of Google Cloud's APIs, serverless products like Cloud
+Functions and Cloud Run, and calls to external APIs to create flexible serverless applications.
+
+For more information about the service, visit the
+`Workflows product documentation <https://cloud.google.com/workflows/docs/overview>`__.
+
+.. contents::
+  :depth: 1
+  :local:
+
+Prerequisite Tasks
+------------------
+
+.. include:: /operators/_partials/prerequisite_tasks.rst
+
+
+.. _howto/operator:WorkflowsCreateWorkflowOperator:
+
+Create workflow
+===============
+
+To create a workflow use
+:class:`~airflow.providers.google.cloud.operators.workflows.WorkflowsCreateWorkflowOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_workflows.py
+      :language: python
+      :dedent: 4
+      :start-after: [START how_to_create_workflow]
+      :end-before: [END how_to_create_workflow]
+
+The workflow should be defined in a way similar to this example:
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_workflows.py
+      :language: python
+      :dedent: 0
+      :start-after: [START how_to_define_workflow]
+      :end-before: [END how_to_define_workflow]
+
+For more information about authoring workflows, check the official
+`product documentation <https://cloud.google.com/workflows/docs/overview>`__.
+
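+A minimal inline sketch of using the operator inside a DAG (the region, project id, task id
+and workflow source below are illustrative assumptions, not values from the example DAG):
+
+.. code-block:: python
+
+    from airflow.providers.google.cloud.operators.workflows import WorkflowsCreateWorkflowOperator
+
+    WORKFLOW_SOURCE = """
+    main:
+      steps:
+        - say_hello:
+            return: "Hello"
+    """
+
+    create_workflow = WorkflowsCreateWorkflowOperator(
+        task_id="create_workflow",
+        location="europe-west1",  # assumed region
+        project_id="my-project",  # assumed project id
+        workflow_id="my-workflow",
+        workflow={"source_contents": WORKFLOW_SOURCE},
+    )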
+
+.. _howto/operator:WorkflowsUpdateWorkflowOperator:
+
+Update workflow
+===============
+
+To update a workflow use
+:class:`~airflow.providers.google.cloud.operators.workflows.WorkflowsUpdateWorkflowOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_workflows.py
+      :language: python
+      :dedent: 4
+      :start-after: [START how_to_update_workflow]
+      :end-before: [END how_to_update_workflow]
+
+.. _howto/operator:WorkflowsGetWorkflowOperator:
+
+Get workflow
+============
+
+To get a workflow use
+:class:`~airflow.providers.google.cloud.operators.workflows.WorkflowsGetWorkflowOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_workflows.py
+      :language: python
+      :dedent: 4
+      :start-after: [START how_to_get_workflow]
+      :end-before: [END how_to_get_workflow]
+
+.. _howto/operator:WorkflowsListWorkflowsOperator:
+
+List workflows
+==============
+
+To list workflows use
+:class:`~airflow.providers.google.cloud.operators.workflows.WorkflowsListWorkflowsOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_workflows.py
+      :language: python
+      :dedent: 4
+      :start-after: [START how_to_list_workflows]
+      :end-before: [END how_to_list_workflows]
+
+.. _howto/operator:WorkflowsDeleteWorkflowOperator:
+
+Delete workflow
+===============
+
+To delete a workflow use
+:class:`~airflow.providers.google.cloud.operators.workflows.WorkflowsDeleteWorkflowOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_workflows.py
+      :language: python
+      :dedent: 4
+      :start-after: [START how_to_delete_workflow]
+      :end-before: [END how_to_delete_workflow]
+
+.. _howto/operator:WorkflowsCreateExecutionOperator:
+
+Create execution
+================
+
+To create an execution use
+:class:`~airflow.providers.google.cloud.operators.workflows.WorkflowsCreateExecutionOperator`.
+This operator is not idempotent due to an API limitation.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_workflows.py
+      :language: python
+      :dedent: 4
+      :start-after: [START how_to_create_execution]
+      :end-before: [END how_to_create_execution]
+
+The create operator does not wait for the execution to complete. To wait for the execution result, use
+:class:`~airflow.providers.google.cloud.sensors.workflows.WorkflowExecutionSensor`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_workflows.py
+      :language: python
+      :dedent: 4
+      :start-after: [START how_to_wait_for_execution]
+      :end-before: [END how_to_wait_for_execution]
+
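+A hedged sketch of wiring the create operator and the sensor together in a DAG (task ids,
+region and project id below are illustrative assumptions):
+
+.. code-block:: python
+
+    from airflow.providers.google.cloud.operators.workflows import WorkflowsCreateExecutionOperator
+    from airflow.providers.google.cloud.sensors.workflows import WorkflowExecutionSensor
+
+    create_execution = WorkflowsCreateExecutionOperator(
+        task_id="create_execution",
+        location="europe-west1",
+        project_id="my-project",
+        workflow_id="my-workflow",
+        execution={},  # an empty body runs the latest revision with default arguments
+    )
+
+    wait_for_execution = WorkflowExecutionSensor(
+        task_id="wait_for_execution",
+        location="europe-west1",
+        project_id="my-project",
+        workflow_id="my-workflow",
+        # The create operator pushes the execution id to XCom under the "execution_id" key.
+        execution_id="{{ task_instance.xcom_pull('create_execution', key='execution_id') }}",
+    )
+
+    create_execution >> wait_for_execution
+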
+.. _howto/operator:WorkflowsGetExecutionOperator:
+
+Get execution
+=============
+
+To get an execution use
+:class:`~airflow.providers.google.cloud.operators.workflows.WorkflowsGetExecutionOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_workflows.py
+      :language: python
+      :dedent: 4
+      :start-after: [START how_to_get_execution]
+      :end-before: [END how_to_get_execution]
+
+.. _howto/operator:WorkflowsListExecutionsOperator:
+
+List executions
+===============
+
+To list executions use
+:class:`~airflow.providers.google.cloud.operators.workflows.WorkflowsListExecutionsOperator`.
+By default, this operator returns only executions from the last 60 minutes.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_workflows.py
+      :language: python
+      :dedent: 4
+      :start-after: [START how_to_list_executions]
+      :end-before: [END how_to_list_executions]
+
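+A short sketch narrowing the window with ``start_date_filter`` (the 12-hour cut-off and the
+ids below are illustrative assumptions):
+
+.. code-block:: python
+
+    from datetime import datetime, timedelta, timezone
+
+    from airflow.providers.google.cloud.operators.workflows import WorkflowsListExecutionsOperator
+
+    list_executions = WorkflowsListExecutionsOperator(
+        task_id="list_executions",
+        location="europe-west1",
+        project_id="my-project",
+        workflow_id="my-workflow",
+        # Only executions started after this timestamp are kept.
+        start_date_filter=datetime.now(tz=timezone.utc) - timedelta(hours=12),
+    )
+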
+.. _howto/operator:WorkflowsCancelExecutionOperator:
+
+Cancel execution
+================
+
+To cancel an execution use
+:class:`~airflow.providers.google.cloud.operators.workflows.WorkflowsCancelExecutionOperator`.
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_workflows.py
+      :language: python
+      :dedent: 4
+      :start-after: [START how_to_cancel_execution]
+      :end-before: [END how_to_cancel_execution]
diff --git a/setup.py b/setup.py
index 0689bd5..7071795 100644
--- a/setup.py
+++ b/setup.py
@@ -279,6 +279,7 @@ flask_oauth = [
 google = [
     'PyOpenSSL',
     'google-ads>=4.0.0,<8.0.0',
+    'google-api-core>=1.25.1,<2.0.0',
     'google-api-python-client>=1.6.0,<2.0.0',
     'google-auth>=1.0.0,<2.0.0',
     'google-auth-httplib2>=0.0.1',
@@ -306,6 +307,7 @@ google = [
     'google-cloud-translate>=1.5.0,<2.0.0',
     'google-cloud-videointelligence>=1.7.0,<2.0.0',
     'google-cloud-vision>=0.35.2,<2.0.0',
+    'google-cloud-workflows>=0.1.0,<2.0.0',
     'grpcio-gcp>=0.2.2',
     'json-merge-patch~=0.2',
     'pandas-gbq',
diff --git a/tests/providers/google/cloud/hooks/test_workflows.py b/tests/providers/google/cloud/hooks/test_workflows.py
new file mode 100644
index 0000000..4f3d4d0
--- /dev/null
+++ b/tests/providers/google/cloud/hooks/test_workflows.py
@@ -0,0 +1,256 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest import mock
+
+from airflow.providers.google.cloud.hooks.workflows import WorkflowsHook
+
+BASE_PATH = "airflow.providers.google.cloud.hooks.workflows.{}"
+LOCATION = "europe-west1"
+WORKFLOW_ID = "workflow_id"
+EXECUTION_ID = "execution_id"
+WORKFLOW = {"aa": "bb"}
+EXECUTION = {"ccc": "ddd"}
+PROJECT_ID = "airflow-testing"
+METADATA = ()
+TIMEOUT = None
+RETRY = None
+FILTER_ = "aaaa"
+ORDER_BY = "bbb"
+UPDATE_MASK = "aaa,bbb"
+
+WORKFLOW_PARENT = f"projects/{PROJECT_ID}/locations/{LOCATION}"
+WORKFLOW_NAME = f"projects/{PROJECT_ID}/locations/{LOCATION}/workflows/{WORKFLOW_ID}"
+EXECUTION_PARENT = f"projects/{PROJECT_ID}/locations/{LOCATION}/workflows/{WORKFLOW_ID}"
+EXECUTION_NAME = (
+    f"projects/{PROJECT_ID}/locations/{LOCATION}/workflows/{WORKFLOW_ID}/executions/{EXECUTION_ID}"
+)
+
+
+def mock_init(*args, **kwargs):
+    pass
+
+
+class TestWorkflowsHook:
+    def setup_method(self, _):
+        with mock.patch(BASE_PATH.format("GoogleBaseHook.__init__"), new=mock_init):
+            self.hook = WorkflowsHook(gcp_conn_id="test")  # pylint: disable=attribute-defined-outside-init
+
+    @mock.patch(BASE_PATH.format("WorkflowsHook._get_credentials"))
+    @mock.patch(BASE_PATH.format("WorkflowsHook.client_info"), new_callable=mock.PropertyMock)
+    @mock.patch(BASE_PATH.format("WorkflowsClient"))
+    def test_get_workflows_client(self, mock_client, mock_client_info, mock_get_credentials):
+        self.hook.get_workflows_client()
+        mock_client.assert_called_once_with(
+            credentials=mock_get_credentials.return_value,
+            client_info=mock_client_info.return_value,
+        )
+
+    @mock.patch(BASE_PATH.format("WorkflowsHook._get_credentials"))
+    @mock.patch(BASE_PATH.format("WorkflowsHook.client_info"), new_callable=mock.PropertyMock)
+    @mock.patch(BASE_PATH.format("ExecutionsClient"))
+    def test_get_executions_client(self, mock_client, mock_client_info, mock_get_credentials):
+        self.hook.get_executions_client()
+        mock_client.assert_called_once_with(
+            credentials=mock_get_credentials.return_value,
+            client_info=mock_client_info.return_value,
+        )
+
+    @mock.patch(BASE_PATH.format("WorkflowsHook.get_workflows_client"))
+    def test_create_workflow(self, mock_client):
+        result = self.hook.create_workflow(
+            workflow=WORKFLOW,
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert mock_client.return_value.create_workflow.return_value == result
+        mock_client.return_value.create_workflow.assert_called_once_with(
+            request=dict(workflow=WORKFLOW, workflow_id=WORKFLOW_ID, parent=WORKFLOW_PARENT),
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch(BASE_PATH.format("WorkflowsHook.get_workflows_client"))
+    def test_get_workflow(self, mock_client):
+        result = self.hook.get_workflow(
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert mock_client.return_value.get_workflow.return_value == result
+        mock_client.return_value.get_workflow.assert_called_once_with(
+            request=dict(name=WORKFLOW_NAME),
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch(BASE_PATH.format("WorkflowsHook.get_workflows_client"))
+    def test_update_workflow(self, mock_client):
+        result = self.hook.update_workflow(
+            workflow=WORKFLOW,
+            update_mask=UPDATE_MASK,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert mock_client.return_value.update_workflow.return_value == result
+        mock_client.return_value.update_workflow.assert_called_once_with(
+            request=dict(
+                workflow=WORKFLOW,
+                update_mask=UPDATE_MASK,
+            ),
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch(BASE_PATH.format("WorkflowsHook.get_workflows_client"))
+    def test_delete_workflow(self, mock_client):
+        result = self.hook.delete_workflow(
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert mock_client.return_value.delete_workflow.return_value == result
+        mock_client.return_value.delete_workflow.assert_called_once_with(
+            request=dict(name=WORKFLOW_NAME),
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch(BASE_PATH.format("WorkflowsHook.get_workflows_client"))
+    def test_list_workflows(self, mock_client):
+        result = self.hook.list_workflows(
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            filter_=FILTER_,
+            order_by=ORDER_BY,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert mock_client.return_value.list_workflows.return_value == result
+        mock_client.return_value.list_workflows.assert_called_once_with(
+            request=dict(
+                parent=WORKFLOW_PARENT,
+                filter=FILTER_,
+                order_by=ORDER_BY,
+            ),
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch(BASE_PATH.format("WorkflowsHook.get_executions_client"))
+    def test_create_execution(self, mock_client):
+        result = self.hook.create_execution(
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            execution=EXECUTION,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert mock_client.return_value.create_execution.return_value == result
+        mock_client.return_value.create_execution.assert_called_once_with(
+            request=dict(
+                parent=EXECUTION_PARENT,
+                execution=EXECUTION,
+            ),
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch(BASE_PATH.format("WorkflowsHook.get_executions_client"))
+    def test_get_execution(self, mock_client):
+        result = self.hook.get_execution(
+            workflow_id=WORKFLOW_ID,
+            execution_id=EXECUTION_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert mock_client.return_value.get_execution.return_value == result
+        mock_client.return_value.get_execution.assert_called_once_with(
+            request=dict(name=EXECUTION_NAME),
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch(BASE_PATH.format("WorkflowsHook.get_executions_client"))
+    def test_cancel_execution(self, mock_client):
+        result = self.hook.cancel_execution(
+            workflow_id=WORKFLOW_ID,
+            execution_id=EXECUTION_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert mock_client.return_value.cancel_execution.return_value == result
+        mock_client.return_value.cancel_execution.assert_called_once_with(
+            request=dict(name=EXECUTION_NAME),
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+    @mock.patch(BASE_PATH.format("WorkflowsHook.get_executions_client"))
+    def test_list_execution(self, mock_client):
+        result = self.hook.list_executions(
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert mock_client.return_value.list_executions.return_value == result
+        mock_client.return_value.list_executions.assert_called_once_with(
+            request=dict(parent=EXECUTION_PARENT),
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
diff --git a/tests/providers/google/cloud/operators/test_workflows.py b/tests/providers/google/cloud/operators/test_workflows.py
new file mode 100644
index 0000000..5578548
--- /dev/null
+++ b/tests/providers/google/cloud/operators/test_workflows.py
@@ -0,0 +1,383 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import datetime
+from unittest import mock
+
+import pytz
+
+from airflow.providers.google.cloud.operators.workflows import (
+    WorkflowsCancelExecutionOperator,
+    WorkflowsCreateExecutionOperator,
+    WorkflowsCreateWorkflowOperator,
+    WorkflowsDeleteWorkflowOperator,
+    WorkflowsGetExecutionOperator,
+    WorkflowsGetWorkflowOperator,
+    WorkflowsListExecutionsOperator,
+    WorkflowsListWorkflowsOperator,
+    WorkflowsUpdateWorkflowOperator,
+)
+
+BASE_PATH = "airflow.providers.google.cloud.operators.workflows.{}"
+LOCATION = "europe-west1"
+WORKFLOW_ID = "workflow_id"
+EXECUTION_ID = "execution_id"
+WORKFLOW = {"aa": "bb"}
+EXECUTION = {"ccc": "ddd"}
+PROJECT_ID = "airflow-testing"
+METADATA = None
+TIMEOUT = None
+RETRY = None
+FILTER_ = "aaaa"
+ORDER_BY = "bbb"
+UPDATE_MASK = "aaa,bbb"
+GCP_CONN_ID = "test-conn"
+IMPERSONATION_CHAIN = None
+
+
+class TestWorkflowsCreateWorkflowOperator:
+    @mock.patch(BASE_PATH.format("Workflow"))
+    @mock.patch(BASE_PATH.format("WorkflowsHook"))
+    def test_execute(self, mock_hook, mock_object):
+        op = WorkflowsCreateWorkflowOperator(
+            task_id="test_task",
+            workflow=WORKFLOW,
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        result = op.execute({})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.create_workflow.assert_called_once_with(
+            workflow=WORKFLOW,
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert result == mock_object.to_dict.return_value
+
+
+class TestWorkflowsUpdateWorkflowOperator:
+    @mock.patch(BASE_PATH.format("Workflow"))
+    @mock.patch(BASE_PATH.format("WorkflowsHook"))
+    def test_execute(self, mock_hook, mock_object):
+        op = WorkflowsUpdateWorkflowOperator(
+            task_id="test_task",
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            update_mask=UPDATE_MASK,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        result = op.execute({})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.get_workflow.assert_called_once_with(
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        mock_hook.return_value.update_workflow.assert_called_once_with(
+            workflow=mock_hook.return_value.get_workflow.return_value,
+            update_mask=UPDATE_MASK,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert result == mock_object.to_dict.return_value
+
+
+class TestWorkflowsDeleteWorkflowOperator:
+    @mock.patch(BASE_PATH.format("WorkflowsHook"))
+    def test_execute(
+        self,
+        mock_hook,
+    ):
+        op = WorkflowsDeleteWorkflowOperator(
+            task_id="test_task",
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute({})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.delete_workflow.assert_called_once_with(
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+
+class TestWorkflowsListWorkflowsOperator:
+    @mock.patch(BASE_PATH.format("Workflow"))
+    @mock.patch(BASE_PATH.format("WorkflowsHook"))
+    def test_execute(self, mock_hook, mock_object):
+        workflow_mock = mock.MagicMock()
+        workflow_mock.start_time = datetime.datetime.now(tz=pytz.UTC) + datetime.timedelta(minutes=5)
+        mock_hook.return_value.list_workflows.return_value = [workflow_mock]
+
+        op = WorkflowsListWorkflowsOperator(
+            task_id="test_task",
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            filter_=FILTER_,
+            order_by=ORDER_BY,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        result = op.execute({})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.list_workflows.assert_called_once_with(
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            filter_=FILTER_,
+            order_by=ORDER_BY,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert result == [mock_object.to_dict.return_value]
+
+
+class TestWorkflowsGetWorkflowOperator:
+    @mock.patch(BASE_PATH.format("Workflow"))
+    @mock.patch(BASE_PATH.format("WorkflowsHook"))
+    def test_execute(self, mock_hook, mock_object):
+        op = WorkflowsGetWorkflowOperator(
+            task_id="test_task",
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        result = op.execute({})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.get_workflow.assert_called_once_with(
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert result == mock_object.to_dict.return_value
+
+
+class TestWorkflowExecutionsCreateExecutionOperator:
+    @mock.patch(BASE_PATH.format("Execution"))
+    @mock.patch(BASE_PATH.format("WorkflowsHook"))
+    @mock.patch(BASE_PATH.format("WorkflowsCreateExecutionOperator.xcom_push"))
+    def test_execute(self, mock_xcom, mock_hook, mock_object):
+        mock_hook.return_value.create_execution.return_value.name = "name/execution_id"
+        op = WorkflowsCreateExecutionOperator(
+            task_id="test_task",
+            workflow_id=WORKFLOW_ID,
+            execution=EXECUTION,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        result = op.execute({})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.create_execution.assert_called_once_with(
+            workflow_id=WORKFLOW_ID,
+            execution=EXECUTION,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_xcom.assert_called_once_with({}, key="execution_id", value="execution_id")
+        assert result == mock_object.to_dict.return_value
+
+
+class TestWorkflowExecutionsCancelExecutionOperator:
+    @mock.patch(BASE_PATH.format("Execution"))
+    @mock.patch(BASE_PATH.format("WorkflowsHook"))
+    def test_execute(self, mock_hook, mock_object):
+        op = WorkflowsCancelExecutionOperator(
+            task_id="test_task",
+            workflow_id=WORKFLOW_ID,
+            execution_id=EXECUTION_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        result = op.execute({})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.cancel_execution.assert_called_once_with(
+            workflow_id=WORKFLOW_ID,
+            execution_id=EXECUTION_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert result == mock_object.to_dict.return_value
+
+
+class TestWorkflowExecutionsListExecutionsOperator:
+    @mock.patch(BASE_PATH.format("Execution"))
+    @mock.patch(BASE_PATH.format("WorkflowsHook"))
+    def test_execute(self, mock_hook, mock_object):
+        execution_mock = mock.MagicMock()
+        execution_mock.start_time = datetime.datetime.now(tz=pytz.UTC) + datetime.timedelta(minutes=5)
+        mock_hook.return_value.list_executions.return_value = [execution_mock]
+
+        op = WorkflowsListExecutionsOperator(
+            task_id="test_task",
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        result = op.execute({})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.list_executions.assert_called_once_with(
+            workflow_id=WORKFLOW_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert result == [mock_object.to_dict.return_value]
+
+
+class TestWorkflowExecutionsGetExecutionOperator:
+    @mock.patch(BASE_PATH.format("Execution"))
+    @mock.patch(BASE_PATH.format("WorkflowsHook"))
+    def test_execute(self, mock_hook, mock_object):
+        op = WorkflowsGetExecutionOperator(
+            task_id="test_task",
+            workflow_id=WORKFLOW_ID,
+            execution_id=EXECUTION_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        result = op.execute({})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.get_execution.assert_called_once_with(
+            workflow_id=WORKFLOW_ID,
+            execution_id=EXECUTION_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert result == mock_object.to_dict.return_value
diff --git a/tests/providers/google/cloud/operators/test_workflows_system.py b/tests/providers/google/cloud/operators/test_workflows_system.py
new file mode 100644
index 0000000..0a768ed
--- /dev/null
+++ b/tests/providers/google/cloud/operators/test_workflows_system.py
@@ -0,0 +1,29 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from tests.providers.google.cloud.utils.gcp_authenticator import GCP_WORKFLOWS_KEY
+from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
+
+
+@pytest.mark.system("google.cloud")
+@pytest.mark.credential_file(GCP_WORKFLOWS_KEY)
+class CloudWorkflowsExampleDagsSystemTest(GoogleSystemTest):
+    @provide_gcp_context(GCP_WORKFLOWS_KEY)
+    def test_run_example_workflow_dag(self):
+        self.run_dag('example_cloud_workflows', CLOUD_DAG_FOLDER)
diff --git a/tests/providers/google/cloud/sensors/test_workflows.py b/tests/providers/google/cloud/sensors/test_workflows.py
new file mode 100644
index 0000000..56ad958
--- /dev/null
+++ b/tests/providers/google/cloud/sensors/test_workflows.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest import mock
+
+import pytest
+from google.cloud.workflows.executions_v1beta import Execution
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.sensors.workflows import WorkflowExecutionSensor
+
+BASE_PATH = "airflow.providers.google.cloud.sensors.workflows.{}"
+LOCATION = "europe-west1"
+WORKFLOW_ID = "workflow_id"
+EXECUTION_ID = "execution_id"
+PROJECT_ID = "airflow-testing"
+METADATA = None
+TIMEOUT = None
+RETRY = None
+GCP_CONN_ID = "test-conn"
+IMPERSONATION_CHAIN = None
+
+
+class TestWorkflowExecutionSensor:
+    @mock.patch(BASE_PATH.format("WorkflowsHook"))
+    def test_poke_success(self, mock_hook):
+        mock_hook.return_value.get_execution.return_value = mock.MagicMock(state=Execution.State.SUCCEEDED)
+        op = WorkflowExecutionSensor(
+            task_id="test_task",
+            workflow_id=WORKFLOW_ID,
+            execution_id=EXECUTION_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            request_timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        result = op.poke({})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.get_execution.assert_called_once_with(
+            workflow_id=WORKFLOW_ID,
+            execution_id=EXECUTION_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        assert result is True
+
+    @mock.patch(BASE_PATH.format("WorkflowsHook"))
+    def test_poke_wait(self, mock_hook):
+        mock_hook.return_value.get_execution.return_value = mock.MagicMock(state=Execution.State.ACTIVE)
+        op = WorkflowExecutionSensor(
+            task_id="test_task",
+            workflow_id=WORKFLOW_ID,
+            execution_id=EXECUTION_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            request_timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        result = op.poke({})
+
+        assert result is False
+
+    @mock.patch(BASE_PATH.format("WorkflowsHook"))
+    def test_poke_failure(self, mock_hook):
+        mock_hook.return_value.get_execution.return_value = mock.MagicMock(state=Execution.State.FAILED)
+        op = WorkflowExecutionSensor(
+            task_id="test_task",
+            workflow_id=WORKFLOW_ID,
+            execution_id=EXECUTION_ID,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            request_timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        with pytest.raises(AirflowException):
+            op.poke({})
diff --git a/tests/providers/google/cloud/utils/gcp_authenticator.py b/tests/providers/google/cloud/utils/gcp_authenticator.py
index bf36ead..2fad48c 100644
--- a/tests/providers/google/cloud/utils/gcp_authenticator.py
+++ b/tests/providers/google/cloud/utils/gcp_authenticator.py
@@ -54,6 +54,7 @@ GCP_SECRET_MANAGER_KEY = 'gcp_secret_manager.json'
 GCP_SPANNER_KEY = 'gcp_spanner.json'
 GCP_STACKDDRIVER = 'gcp_stackdriver.json'
 GCP_TASKS_KEY = 'gcp_tasks.json'
+GCP_WORKFLOWS_KEY = "gcp_workflows.json"
 GMP_KEY = 'gmp.json'
 G_FIREBASE_KEY = 'g_firebase.json'
 GCP_AWS_KEY = 'gcp_aws.json'
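
For reference, a minimal usage sketch of the WorkflowExecutionSensor exercised by the tests above. This is not part of the patch: the dag_id, schedule and the workflow/execution ids are illustrative assumptions, and only the location/project/workflow/execution parameters seen in the tests are passed, assuming the remaining ones (retry, request_timeout, metadata, gcp_conn_id, impersonation_chain) have defaults.

# Hedged sketch, not taken from the patch: wait for a Workflows execution to finish.
from airflow import models
from airflow.providers.google.cloud.sensors.workflows import WorkflowExecutionSensor
from airflow.utils.dates import days_ago

with models.DAG(
    "example_workflow_execution_sensor",  # assumed dag_id
    schedule_interval=None,
    start_date=days_ago(1),
) as dag:
    wait_for_execution = WorkflowExecutionSensor(
        task_id="wait_for_execution",
        location="europe-west1",
        project_id="airflow-testing",
        workflow_id="workflow_id",
        # in practice this would usually come from XCom of a create-execution task
        execution_id="execution_id",
    )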


[airflow] 21/28: Update to Pytest 6.0 (#14065)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit c34898b7886264439d667acfc9d003352e90e16f
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Thu Feb 4 12:57:51 2021 +0000

    Update to Pytest 6.0 (#14065)
    
    Pytest 6 removed a class that the pytest-rerunfailures plugin relied on, so
    that plugin has to be upgraded as well.
    
    (cherry picked from commit 10c026cb7a7189d9573f30f2f2242f0f76842a72)
---
 setup.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index 7beb684..cd38ef2 100644
--- a/setup.py
+++ b/setup.py
@@ -506,10 +506,10 @@ devel = [
     'pre-commit',
     'pylint',
     'pysftp',
-    'pytest',
+    'pytest~=6.0',
     'pytest-cov',
     'pytest-instafail',
-    'pytest-rerunfailures',
+    'pytest-rerunfailures~=9.1',
     'pytest-timeouts',
     'pytest-xdist',
     'pywinrm',
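
The pytest-rerunfailures pin above is what keeps the flaky-test rerun marker working under pytest 6. As a hedged illustration (the test body and rerun counts are made up, not from the Airflow test suite), a test opts into automatic reruns like this:

# Hedged sketch of the pytest-rerunfailures marker kept compatible by the ~=9.1 pin.
import random

import pytest


@pytest.mark.flaky(reruns=3, reruns_delay=2)  # rerun up to 3 times, waiting 2s between tries
def test_occasionally_flaky():
    # passes most of the time; reruns absorb the occasional failure
    assert random.random() > 0.1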


[airflow] 19/28: Support google-cloud-monitoring>=2.0.0 (#13769)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 988a2a55c3be7c1121af322753ef12b341a73583
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Tue Feb 2 07:01:55 2021 +0100

    Support google-cloud-monitoring>=2.0.0 (#13769)
    
    (cherry picked from commit d2efb33239d36e58fb69066fd23779724cb11a90)
---
 airflow/providers/google/ADDITIONAL_INFO.md        |   1 +
 .../cloud/example_dags/example_stackdriver.py      |  82 +++++--
 .../providers/google/cloud/hooks/stackdriver.py    | 133 +++++------
 .../google/cloud/operators/stackdriver.py          |  12 +-
 setup.py                                           |   2 +-
 .../google/cloud/hooks/test_stackdriver.py         | 242 +++++++++++----------
 .../google/cloud/operators/test_stackdriver.py     |  49 ++++-
 7 files changed, 302 insertions(+), 219 deletions(-)

diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md
index 16a6683..9cf9853 100644
--- a/airflow/providers/google/ADDITIONAL_INFO.md
+++ b/airflow/providers/google/ADDITIONAL_INFO.md
@@ -34,6 +34,7 @@ Details are covered in the UPDATING.md files for each library, but there are som
 | [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
 | [``google-cloud-dataproc``](https://pypi.org/project/google-cloud-dataproc/) | ``>=1.0.1,<2.0.0`` | ``>=2.2.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-dataproc/blob/master/UPGRADING.md) |
 | [``google-cloud-kms``](https://pypi.org/project/google-cloud-kms/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
+| [``google-cloud-monitoring``](https://pypi.org/project/google-cloud-monitoring/) | ``>=0.34.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-monitoring/blob/master/UPGRADING.md) |
 | [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
 | [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
 | [``google-cloud-tasks``](https://pypi.org/project/google-cloud-tasks/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-tasks/blob/master/UPGRADING.md) |
diff --git a/airflow/providers/google/cloud/example_dags/example_stackdriver.py b/airflow/providers/google/cloud/example_dags/example_stackdriver.py
index 68ac978..9c418b7 100644
--- a/airflow/providers/google/cloud/example_dags/example_stackdriver.py
+++ b/airflow/providers/google/cloud/example_dags/example_stackdriver.py
@@ -21,6 +21,7 @@ Example Airflow DAG for Google Cloud Stackdriver service.
 """
 
 import json
+import os
 
 from airflow import models
 from airflow.providers.google.cloud.operators.stackdriver import (
@@ -37,56 +38,80 @@ from airflow.providers.google.cloud.operators.stackdriver import (
 )
 from airflow.utils.dates import days_ago
 
+PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
+
 TEST_ALERT_POLICY_1 = {
     "combiner": "OR",
-    "name": "projects/sd-project/alertPolicies/12345",
-    "creationRecord": {"mutatedBy": "user123", "mutateTime": "2020-01-01T00:00:00.000000Z"},
     "enabled": True,
-    "displayName": "test alert 1",
+    "display_name": "test alert 1",
     "conditions": [
         {
-            "conditionThreshold": {
+            "condition_threshold": {
+                "filter": (
+                    'metric.label.state="blocked" AND '
+                    'metric.type="agent.googleapis.com/processes/count_by_state" '
+                    'AND resource.type="gce_instance"'
+                ),
                 "comparison": "COMPARISON_GT",
-                "aggregations": [{"alignmentPeriod": "60s", "perSeriesAligner": "ALIGN_RATE"}],
+                "threshold_value": 100,
+                "duration": {'seconds': 900},
+                "trigger": {"percent": 0},
+                "aggregations": [
+                    {
+                        "alignment_period": {'seconds': 60},
+                        "per_series_aligner": "ALIGN_MEAN",
+                        "cross_series_reducer": "REDUCE_MEAN",
+                        "group_by_fields": ["project", "resource.label.instance_id", "resource.label.zone"],
+                    }
+                ],
             },
-            "displayName": "Condition display",
-            "name": "projects/sd-project/alertPolicies/123/conditions/456",
+            "display_name": "test_alert_policy_1",
         }
     ],
 }
 
 TEST_ALERT_POLICY_2 = {
     "combiner": "OR",
-    "name": "projects/sd-project/alertPolicies/6789",
-    "creationRecord": {"mutatedBy": "user123", "mutateTime": "2020-01-01T00:00:00.000000Z"},
     "enabled": False,
-    "displayName": "test alert 2",
+    "display_name": "test alert 2",
     "conditions": [
         {
-            "conditionThreshold": {
+            "condition_threshold": {
+                "filter": (
+                    'metric.label.state="blocked" AND '
+                    'metric.type="agent.googleapis.com/processes/count_by_state" AND '
+                    'resource.type="gce_instance"'
+                ),
                 "comparison": "COMPARISON_GT",
-                "aggregations": [{"alignmentPeriod": "60s", "perSeriesAligner": "ALIGN_RATE"}],
+                "threshold_value": 100,
+                "duration": {'seconds': 900},
+                "trigger": {"percent": 0},
+                "aggregations": [
+                    {
+                        "alignment_period": {'seconds': 60},
+                        "per_series_aligner": "ALIGN_MEAN",
+                        "cross_series_reducer": "REDUCE_MEAN",
+                        "group_by_fields": ["project", "resource.label.instance_id", "resource.label.zone"],
+                    }
+                ],
             },
-            "displayName": "Condition display",
-            "name": "projects/sd-project/alertPolicies/456/conditions/789",
+            "display_name": "test_alert_policy_2",
         }
     ],
 }
 
 TEST_NOTIFICATION_CHANNEL_1 = {
-    "displayName": "channel1",
+    "display_name": "channel1",
     "enabled": True,
     "labels": {"auth_token": "top-secret", "channel_name": "#channel"},
-    "name": "projects/sd-project/notificationChannels/12345",
-    "type": "slack",
+    "type_": "slack",
 }
 
 TEST_NOTIFICATION_CHANNEL_2 = {
-    "displayName": "channel2",
+    "display_name": "channel2",
     "enabled": False,
     "labels": {"auth_token": "top-secret", "channel_name": "#channel"},
-    "name": "projects/sd-project/notificationChannels/6789",
-    "type": "slack",
+    "type_": "slack",
 }
 
 with models.DAG(
@@ -150,18 +175,29 @@ with models.DAG(
     # [START howto_operator_gcp_stackdriver_delete_notification_channel]
     delete_notification_channel = StackdriverDeleteNotificationChannelOperator(
         task_id='delete-notification-channel',
-        name='test-channel',
+        name="{{ task_instance.xcom_pull('list-notification-channel')[0]['name'] }}",
     )
     # [END howto_operator_gcp_stackdriver_delete_notification_channel]
 
+    delete_notification_channel_2 = StackdriverDeleteNotificationChannelOperator(
+        task_id='delete-notification-channel-2',
+        name="{{ task_instance.xcom_pull('list-notification-channel')[1]['name'] }}",
+    )
+
     # [START howto_operator_gcp_stackdriver_delete_alert_policy]
     delete_alert_policy = StackdriverDeleteAlertOperator(
         task_id='delete-alert-policy',
-        name='test-alert',
+        name="{{ task_instance.xcom_pull('list-alert-policies')[0]['name'] }}",
     )
     # [END howto_operator_gcp_stackdriver_delete_alert_policy]
 
+    delete_alert_policy_2 = StackdriverDeleteAlertOperator(
+        task_id='delete-alert-policy-2',
+        name="{{ task_instance.xcom_pull('list-alert-policies')[1]['name'] }}",
+    )
+
     create_notification_channel >> enable_notification_channel >> disable_notification_channel
     disable_notification_channel >> list_notification_channel >> create_alert_policy
     create_alert_policy >> enable_alert_policy >> disable_alert_policy >> list_alert_policies
-    list_alert_policies >> delete_notification_channel >> delete_alert_policy
+    list_alert_policies >> delete_notification_channel >> delete_notification_channel_2
+    delete_notification_channel_2 >> delete_alert_policy >> delete_alert_policy_2
diff --git a/airflow/providers/google/cloud/hooks/stackdriver.py b/airflow/providers/google/cloud/hooks/stackdriver.py
index 9da1afa..04dc329 100644
--- a/airflow/providers/google/cloud/hooks/stackdriver.py
+++ b/airflow/providers/google/cloud/hooks/stackdriver.py
@@ -24,7 +24,8 @@ from typing import Any, Optional, Sequence, Union
 from google.api_core.exceptions import InvalidArgument
 from google.api_core.gapic_v1.method import DEFAULT
 from google.cloud import monitoring_v3
-from google.protobuf.json_format import MessageToDict, MessageToJson, Parse
+from google.cloud.monitoring_v3 import AlertPolicy, NotificationChannel
+from google.protobuf.field_mask_pb2 import FieldMask
 from googleapiclient.errors import HttpError
 
 from airflow.exceptions import AirflowException
@@ -110,18 +111,20 @@ class StackdriverHook(GoogleBaseHook):
         """
         client = self._get_policy_client()
         policies_ = client.list_alert_policies(
-            name=f'projects/{project_id}',
-            filter_=filter_,
-            order_by=order_by,
-            page_size=page_size,
+            request={
+                'name': f'projects/{project_id}',
+                'filter': filter_,
+                'order_by': order_by,
+                'page_size': page_size,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         if format_ == "dict":
-            return [MessageToDict(policy) for policy in policies_]
+            return [AlertPolicy.to_dict(policy) for policy in policies_]
         elif format_ == "json":
-            return [MessageToJson(policy) for policy in policies_]
+            return [AlertPolicy.to_json(policy) for policy in policies_]
         else:
             return policies_
 
@@ -138,12 +141,14 @@ class StackdriverHook(GoogleBaseHook):
         client = self._get_policy_client()
         policies_ = self.list_alert_policies(project_id=project_id, filter_=filter_)
         for policy in policies_:
-            if policy.enabled.value != bool(new_state):
-                policy.enabled.value = bool(new_state)
-                mask = monitoring_v3.types.field_mask_pb2.FieldMask()
-                mask.paths.append('enabled')  # pylint: disable=no-member
+            if policy.enabled != bool(new_state):
+                policy.enabled = bool(new_state)
+                mask = FieldMask(paths=['enabled'])
                 client.update_alert_policy(
-                    alert_policy=policy, update_mask=mask, retry=retry, timeout=timeout, metadata=metadata
+                    request={'alert_policy': policy, 'update_mask': mask},
+                    retry=retry,
+                    timeout=timeout,
+                    metadata=metadata or (),
                 )
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -265,40 +270,39 @@ class StackdriverHook(GoogleBaseHook):
         ]
         policies_ = []
         channels = []
-
-        for channel in record["channels"]:
-            channel_json = json.dumps(channel)
-            channels.append(Parse(channel_json, monitoring_v3.types.notification_pb2.NotificationChannel()))
-        for policy in record["policies"]:
-            policy_json = json.dumps(policy)
-            policies_.append(Parse(policy_json, monitoring_v3.types.alert_pb2.AlertPolicy()))
+        for channel in record.get("channels", []):
+            channels.append(NotificationChannel(**channel))
+        for policy in record.get("policies", []):
+            policies_.append(AlertPolicy(**policy))
 
         channel_name_map = {}
 
         for channel in channels:
             channel.verification_status = (
-                monitoring_v3.enums.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
+                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
             )
 
             if channel.name in existing_channels:
                 channel_client.update_notification_channel(
-                    notification_channel=channel, retry=retry, timeout=timeout, metadata=metadata
+                    request={'notification_channel': channel},
+                    retry=retry,
+                    timeout=timeout,
+                    metadata=metadata or (),
                 )
             else:
                 old_name = channel.name
-                channel.ClearField('name')
+                channel.name = None
                 new_channel = channel_client.create_notification_channel(
-                    name=f'projects/{project_id}',
-                    notification_channel=channel,
+                    request={'name': f'projects/{project_id}', 'notification_channel': channel},
                     retry=retry,
                     timeout=timeout,
-                    metadata=metadata,
+                    metadata=metadata or (),
                 )
                 channel_name_map[old_name] = new_channel.name
 
         for policy in policies_:
-            policy.ClearField('creation_record')
-            policy.ClearField('mutation_record')
+            policy.creation_record = None
+            policy.mutation_record = None
 
             for i, channel in enumerate(policy.notification_channels):
                 new_channel = channel_name_map.get(channel)
@@ -308,20 +312,22 @@ class StackdriverHook(GoogleBaseHook):
             if policy.name in existing_policies:
                 try:
                     policy_client.update_alert_policy(
-                        alert_policy=policy, retry=retry, timeout=timeout, metadata=metadata
+                        request={'alert_policy': policy},
+                        retry=retry,
+                        timeout=timeout,
+                        metadata=metadata or (),
                     )
                 except InvalidArgument:
                     pass
             else:
-                policy.ClearField('name')
+                policy.name = None
                 for condition in policy.conditions:
-                    condition.ClearField('name')
+                    condition.name = None
                 policy_client.create_alert_policy(
-                    name=f'projects/{project_id}',
-                    alert_policy=policy,
+                    request={'name': f'projects/{project_id}', 'alert_policy': policy},
                     retry=retry,
                     timeout=timeout,
-                    metadata=None,
+                    metadata=metadata or (),
                 )
 
     def delete_alert_policy(
@@ -349,7 +355,9 @@ class StackdriverHook(GoogleBaseHook):
         """
         policy_client = self._get_policy_client()
         try:
-            policy_client.delete_alert_policy(name=name, retry=retry, timeout=timeout, metadata=metadata)
+            policy_client.delete_alert_policy(
+                request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+            )
         except HttpError as err:
             raise AirflowException(f'Delete alerting policy failed. Error was {err.content}')
 
@@ -405,18 +413,20 @@ class StackdriverHook(GoogleBaseHook):
         """
         client = self._get_channel_client()
         channels = client.list_notification_channels(
-            name=f'projects/{project_id}',
-            filter_=filter_,
-            order_by=order_by,
-            page_size=page_size,
+            request={
+                'name': f'projects/{project_id}',
+                'filter': filter_,
+                'order_by': order_by,
+                'page_size': page_size,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         if format_ == "dict":
-            return [MessageToDict(channel) for channel in channels]
+            return [NotificationChannel.to_dict(channel) for channel in channels]
         elif format_ == "json":
-            return [MessageToJson(channel) for channel in channels]
+            return [NotificationChannel.to_json(channel) for channel in channels]
         else:
             return channels
 
@@ -431,18 +441,18 @@ class StackdriverHook(GoogleBaseHook):
         metadata: Optional[str] = None,
     ) -> None:
         client = self._get_channel_client()
-        channels = client.list_notification_channels(name=f'projects/{project_id}', filter_=filter_)
+        channels = client.list_notification_channels(
+            request={'name': f'projects/{project_id}', 'filter': filter_}
+        )
         for channel in channels:
-            if channel.enabled.value != bool(new_state):
-                channel.enabled.value = bool(new_state)
-                mask = monitoring_v3.types.field_mask_pb2.FieldMask()
-                mask.paths.append('enabled')  # pylint: disable=no-member
+            if channel.enabled != bool(new_state):
+                channel.enabled = bool(new_state)
+                mask = FieldMask(paths=['enabled'])
                 client.update_notification_channel(
-                    notification_channel=channel,
-                    update_mask=mask,
+                    request={'notification_channel': channel, 'update_mask': mask},
                     retry=retry,
                     timeout=timeout,
-                    metadata=metadata,
+                    metadata=metadata or (),
                 )
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -518,7 +528,7 @@ class StackdriverHook(GoogleBaseHook):
             new_state=False,
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -562,29 +572,28 @@ class StackdriverHook(GoogleBaseHook):
         channel_name_map = {}
 
         for channel in record["channels"]:
-            channel_json = json.dumps(channel)
-            channels_list.append(
-                Parse(channel_json, monitoring_v3.types.notification_pb2.NotificationChannel())
-            )
+            channels_list.append(NotificationChannel(**channel))
 
         for channel in channels_list:
             channel.verification_status = (
-                monitoring_v3.enums.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
+                monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
             )
 
             if channel.name in existing_channels:
                 channel_client.update_notification_channel(
-                    notification_channel=channel, retry=retry, timeout=timeout, metadata=metadata
+                    request={'notification_channel': channel},
+                    retry=retry,
+                    timeout=timeout,
+                    metadata=metadata or (),
                 )
             else:
                 old_name = channel.name
-                channel.ClearField('name')
+                channel.name = None
                 new_channel = channel_client.create_notification_channel(
-                    name=f'projects/{project_id}',
-                    notification_channel=channel,
+                    request={'name': f'projects/{project_id}', 'notification_channel': channel},
                     retry=retry,
                     timeout=timeout,
-                    metadata=metadata,
+                    metadata=metadata or (),
                 )
                 channel_name_map[old_name] = new_channel.name
 
@@ -616,7 +625,7 @@ class StackdriverHook(GoogleBaseHook):
         channel_client = self._get_channel_client()
         try:
             channel_client.delete_notification_channel(
-                name=name, retry=retry, timeout=timeout, metadata=metadata
+                request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
             )
         except HttpError as err:
             raise AirflowException(f'Delete notification channel failed. Error was {err.content}')
diff --git a/airflow/providers/google/cloud/operators/stackdriver.py b/airflow/providers/google/cloud/operators/stackdriver.py
index dc86466..7289b12 100644
--- a/airflow/providers/google/cloud/operators/stackdriver.py
+++ b/airflow/providers/google/cloud/operators/stackdriver.py
@@ -19,6 +19,7 @@
 from typing import Optional, Sequence, Union
 
 from google.api_core.gapic_v1.method import DEFAULT
+from google.cloud.monitoring_v3 import AlertPolicy, NotificationChannel
 
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.stackdriver import StackdriverHook
@@ -125,7 +126,7 @@ class StackdriverListAlertPoliciesOperator(BaseOperator):
 
     def execute(self, context):
         self.log.info(
-            'List Alert Policies: Project id: %s Format: %s Filter: %s Order By: %s Page Size: %d',
+            'List Alert Policies: Project id: %s Format: %s Filter: %s Order By: %s Page Size: %s',
             self.project_id,
             self.format_,
             self.filter_,
@@ -139,7 +140,7 @@ class StackdriverListAlertPoliciesOperator(BaseOperator):
                 impersonation_chain=self.impersonation_chain,
             )
 
-        return self.hook.list_alert_policies(
+        result = self.hook.list_alert_policies(
             project_id=self.project_id,
             format_=self.format_,
             filter_=self.filter_,
@@ -149,6 +150,7 @@ class StackdriverListAlertPoliciesOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
+        return [AlertPolicy.to_dict(policy) for policy in result]
 
 
 class StackdriverEnableAlertPoliciesOperator(BaseOperator):
@@ -614,7 +616,7 @@ class StackdriverListNotificationChannelsOperator(BaseOperator):
 
     def execute(self, context):
         self.log.info(
-            'List Notification Channels: Project id: %s Format: %s Filter: %s Order By: %s Page Size: %d',
+            'List Notification Channels: Project id: %s Format: %s Filter: %s Order By: %s Page Size: %s',
             self.project_id,
             self.format_,
             self.filter_,
@@ -627,7 +629,7 @@ class StackdriverListNotificationChannelsOperator(BaseOperator):
                 delegate_to=self.delegate_to,
                 impersonation_chain=self.impersonation_chain,
             )
-        return self.hook.list_notification_channels(
+        channels = self.hook.list_notification_channels(
             format_=self.format_,
             project_id=self.project_id,
             filter_=self.filter_,
@@ -637,6 +639,8 @@ class StackdriverListNotificationChannelsOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
+        result = [NotificationChannel.to_dict(channel) for channel in channels]
+        return result
 
 
 class StackdriverEnableNotificationChannelsOperator(BaseOperator):
diff --git a/setup.py b/setup.py
index 0f40d88..fa1e73a 100644
--- a/setup.py
+++ b/setup.py
@@ -294,7 +294,7 @@ google = [
     'google-cloud-language>=1.1.1,<2.0.0',
     'google-cloud-logging>=1.14.0,<2.0.0',
     'google-cloud-memcache>=0.2.0',
-    'google-cloud-monitoring>=0.34.0,<2.0.0',
+    'google-cloud-monitoring>=2.0.0,<3.0.0',
     'google-cloud-os-login>=2.0.0,<3.0.0',
     'google-cloud-pubsub>=2.0.0,<3.0.0',
     'google-cloud-redis>=2.0.0,<3.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_stackdriver.py b/tests/providers/google/cloud/hooks/test_stackdriver.py
index 6892d05..10a3097 100644
--- a/tests/providers/google/cloud/hooks/test_stackdriver.py
+++ b/tests/providers/google/cloud/hooks/test_stackdriver.py
@@ -21,8 +21,8 @@ import unittest
 from unittest import mock
 
 from google.api_core.gapic_v1.method import DEFAULT
-from google.cloud import monitoring_v3
-from google.protobuf.json_format import ParseDict
+from google.cloud.monitoring_v3 import AlertPolicy, NotificationChannel
+from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow.providers.google.cloud.hooks import stackdriver
 
@@ -32,16 +32,15 @@ TEST_FILTER = "filter"
 TEST_ALERT_POLICY_1 = {
     "combiner": "OR",
     "name": "projects/sd-project/alertPolicies/12345",
-    "creationRecord": {"mutatedBy": "user123", "mutateTime": "2020-01-01T00:00:00.000000Z"},
     "enabled": True,
-    "displayName": "test display",
+    "display_name": "test display",
     "conditions": [
         {
-            "conditionThreshold": {
+            "condition_threshold": {
                 "comparison": "COMPARISON_GT",
-                "aggregations": [{"alignmentPeriod": "60s", "perSeriesAligner": "ALIGN_RATE"}],
+                "aggregations": [{"alignment_period": {'seconds': 60}, "per_series_aligner": "ALIGN_RATE"}],
             },
-            "displayName": "Condition display",
+            "display_name": "Condition display",
             "name": "projects/sd-project/alertPolicies/123/conditions/456",
         }
     ],
@@ -50,35 +49,34 @@ TEST_ALERT_POLICY_1 = {
 TEST_ALERT_POLICY_2 = {
     "combiner": "OR",
     "name": "projects/sd-project/alertPolicies/6789",
-    "creationRecord": {"mutatedBy": "user123", "mutateTime": "2020-01-01T00:00:00.000000Z"},
     "enabled": False,
-    "displayName": "test display",
+    "display_name": "test display",
     "conditions": [
         {
-            "conditionThreshold": {
+            "condition_threshold": {
                 "comparison": "COMPARISON_GT",
-                "aggregations": [{"alignmentPeriod": "60s", "perSeriesAligner": "ALIGN_RATE"}],
+                "aggregations": [{"alignment_period": {'seconds': 60}, "per_series_aligner": "ALIGN_RATE"}],
             },
-            "displayName": "Condition display",
+            "display_name": "Condition display",
             "name": "projects/sd-project/alertPolicies/456/conditions/789",
         }
     ],
 }
 
 TEST_NOTIFICATION_CHANNEL_1 = {
-    "displayName": "sd",
+    "display_name": "sd",
     "enabled": True,
     "labels": {"auth_token": "top-secret", "channel_name": "#channel"},
     "name": "projects/sd-project/notificationChannels/12345",
-    "type": "slack",
+    "type_": "slack",
 }
 
 TEST_NOTIFICATION_CHANNEL_2 = {
-    "displayName": "sd",
+    "display_name": "sd",
     "enabled": False,
     "labels": {"auth_token": "top-secret", "channel_name": "#channel"},
     "name": "projects/sd-project/notificationChannels/6789",
-    "type": "slack",
+    "type_": "slack",
 }
 
 
@@ -96,13 +94,10 @@ class TestStackdriverHookMethods(unittest.TestCase):
             project_id=PROJECT_ID,
         )
         method.assert_called_once_with(
-            name=f'projects/{PROJECT_ID}',
-            filter_=TEST_FILTER,
+            request=dict(name=f'projects/{PROJECT_ID}', filter=TEST_FILTER, order_by=None, page_size=None),
             retry=DEFAULT,
             timeout=DEFAULT,
-            order_by=None,
-            page_size=None,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
@@ -113,8 +108,8 @@ class TestStackdriverHookMethods(unittest.TestCase):
     def test_stackdriver_enable_alert_policy(self, mock_policy_client, mock_get_creds_and_project_id):
         hook = stackdriver.StackdriverHook()
 
-        alert_policy_enabled = ParseDict(TEST_ALERT_POLICY_1, monitoring_v3.types.alert_pb2.AlertPolicy())
-        alert_policy_disabled = ParseDict(TEST_ALERT_POLICY_2, monitoring_v3.types.alert_pb2.AlertPolicy())
+        alert_policy_enabled = AlertPolicy(**TEST_ALERT_POLICY_1)
+        alert_policy_disabled = AlertPolicy(**TEST_ALERT_POLICY_2)
 
         alert_policies = [alert_policy_enabled, alert_policy_disabled]
 
@@ -124,23 +119,18 @@ class TestStackdriverHookMethods(unittest.TestCase):
             project_id=PROJECT_ID,
         )
         mock_policy_client.return_value.list_alert_policies.assert_called_once_with(
-            name=f'projects/{PROJECT_ID}',
-            filter_=TEST_FILTER,
+            request=dict(name=f'projects/{PROJECT_ID}', filter=TEST_FILTER, order_by=None, page_size=None),
             retry=DEFAULT,
             timeout=DEFAULT,
-            order_by=None,
-            page_size=None,
-            metadata=None,
+            metadata=(),
         )
-        mask = monitoring_v3.types.field_mask_pb2.FieldMask()
-        alert_policy_disabled.enabled.value = True  # pylint: disable=no-member
-        mask.paths.append('enabled')  # pylint: disable=no-member
+        mask = FieldMask(paths=["enabled"])
+        alert_policy_disabled.enabled = True  # pylint: disable=no-member
         mock_policy_client.return_value.update_alert_policy.assert_called_once_with(
-            alert_policy=alert_policy_disabled,
-            update_mask=mask,
+            request=dict(alert_policy=alert_policy_disabled, update_mask=mask),
             retry=DEFAULT,
             timeout=DEFAULT,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
@@ -150,8 +140,8 @@ class TestStackdriverHookMethods(unittest.TestCase):
     @mock.patch('airflow.providers.google.cloud.hooks.stackdriver.StackdriverHook._get_policy_client')
     def test_stackdriver_disable_alert_policy(self, mock_policy_client, mock_get_creds_and_project_id):
         hook = stackdriver.StackdriverHook()
-        alert_policy_enabled = ParseDict(TEST_ALERT_POLICY_1, monitoring_v3.types.alert_pb2.AlertPolicy())
-        alert_policy_disabled = ParseDict(TEST_ALERT_POLICY_2, monitoring_v3.types.alert_pb2.AlertPolicy())
+        alert_policy_enabled = AlertPolicy(**TEST_ALERT_POLICY_1)
+        alert_policy_disabled = AlertPolicy(**TEST_ALERT_POLICY_2)
 
         mock_policy_client.return_value.list_alert_policies.return_value = [
             alert_policy_enabled,
@@ -162,23 +152,18 @@ class TestStackdriverHookMethods(unittest.TestCase):
             project_id=PROJECT_ID,
         )
         mock_policy_client.return_value.list_alert_policies.assert_called_once_with(
-            name=f'projects/{PROJECT_ID}',
-            filter_=TEST_FILTER,
+            request=dict(name=f'projects/{PROJECT_ID}', filter=TEST_FILTER, order_by=None, page_size=None),
             retry=DEFAULT,
             timeout=DEFAULT,
-            order_by=None,
-            page_size=None,
-            metadata=None,
+            metadata=(),
         )
-        mask = monitoring_v3.types.field_mask_pb2.FieldMask()
-        alert_policy_enabled.enabled.value = False  # pylint: disable=no-member
-        mask.paths.append('enabled')  # pylint: disable=no-member
+        mask = FieldMask(paths=["enabled"])
+        alert_policy_enabled.enabled = False  # pylint: disable=no-member
         mock_policy_client.return_value.update_alert_policy.assert_called_once_with(
-            alert_policy=alert_policy_enabled,
-            update_mask=mask,
+            request=dict(alert_policy=alert_policy_enabled, update_mask=mask),
             retry=DEFAULT,
             timeout=DEFAULT,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
@@ -191,8 +176,8 @@ class TestStackdriverHookMethods(unittest.TestCase):
         self, mock_channel_client, mock_policy_client, mock_get_creds_and_project_id
     ):
         hook = stackdriver.StackdriverHook()
-        existing_alert_policy = ParseDict(TEST_ALERT_POLICY_1, monitoring_v3.types.alert_pb2.AlertPolicy())
-        alert_policy_to_create = ParseDict(TEST_ALERT_POLICY_2, monitoring_v3.types.alert_pb2.AlertPolicy())
+        existing_alert_policy = AlertPolicy(**TEST_ALERT_POLICY_1)
+        alert_policy_to_create = AlertPolicy(**TEST_ALERT_POLICY_2)
 
         mock_policy_client.return_value.list_alert_policies.return_value = [existing_alert_policy]
         mock_channel_client.return_value.list_notification_channels.return_value = []
@@ -202,38 +187,77 @@ class TestStackdriverHookMethods(unittest.TestCase):
             project_id=PROJECT_ID,
         )
         mock_channel_client.return_value.list_notification_channels.assert_called_once_with(
-            name=f'projects/{PROJECT_ID}',
-            filter_=None,
+            request=dict(
+                name=f'projects/{PROJECT_ID}',
+                filter=None,
+                order_by=None,
+                page_size=None,
+            ),
             retry=DEFAULT,
             timeout=DEFAULT,
-            order_by=None,
-            page_size=None,
-            metadata=None,
+            metadata=(),
         )
         mock_policy_client.return_value.list_alert_policies.assert_called_once_with(
-            name=f'projects/{PROJECT_ID}',
-            filter_=None,
+            request=dict(name=f'projects/{PROJECT_ID}', filter=None, order_by=None, page_size=None),
             retry=DEFAULT,
             timeout=DEFAULT,
-            order_by=None,
-            page_size=None,
-            metadata=None,
+            metadata=(),
         )
-        alert_policy_to_create.ClearField('name')
-        alert_policy_to_create.ClearField('creation_record')
-        alert_policy_to_create.ClearField('mutation_record')
-        alert_policy_to_create.conditions[0].ClearField('name')  # pylint: disable=no-member
+        alert_policy_to_create.name = None
+        alert_policy_to_create.creation_record = None
+        alert_policy_to_create.mutation_record = None
+        alert_policy_to_create.conditions[0].name = None
         mock_policy_client.return_value.create_alert_policy.assert_called_once_with(
-            name=f'projects/{PROJECT_ID}',
-            alert_policy=alert_policy_to_create,
+            request=dict(
+                name=f'projects/{PROJECT_ID}',
+                alert_policy=alert_policy_to_create,
+            ),
             retry=DEFAULT,
             timeout=DEFAULT,
-            metadata=None,
+            metadata=(),
         )
-        existing_alert_policy.ClearField('creation_record')
-        existing_alert_policy.ClearField('mutation_record')
+        existing_alert_policy.creation_record = None
+        existing_alert_policy.mutation_record = None
         mock_policy_client.return_value.update_alert_policy.assert_called_once_with(
-            alert_policy=existing_alert_policy, retry=DEFAULT, timeout=DEFAULT, metadata=None
+            request=dict(alert_policy=existing_alert_policy), retry=DEFAULT, timeout=DEFAULT, metadata=()
+        )
+
+    @mock.patch(
+        'airflow.providers.google.common.hooks.base_google.GoogleBaseHook._get_credentials_and_project_id',
+        return_value=(CREDENTIALS, PROJECT_ID),
+    )
+    @mock.patch('airflow.providers.google.cloud.hooks.stackdriver.StackdriverHook._get_policy_client')
+    @mock.patch('airflow.providers.google.cloud.hooks.stackdriver.StackdriverHook._get_channel_client')
+    def test_stackdriver_upsert_alert_policy_without_channel(
+        self, mock_channel_client, mock_policy_client, mock_get_creds_and_project_id
+    ):
+        hook = stackdriver.StackdriverHook()
+        existing_alert_policy = AlertPolicy(**TEST_ALERT_POLICY_1)
+
+        mock_policy_client.return_value.list_alert_policies.return_value = [existing_alert_policy]
+        mock_channel_client.return_value.list_notification_channels.return_value = []
+
+        hook.upsert_alert(
+            alerts=json.dumps({"policies": [TEST_ALERT_POLICY_1, TEST_ALERT_POLICY_2]}),
+            project_id=PROJECT_ID,
+        )
+        mock_channel_client.return_value.list_notification_channels.assert_called_once_with(
+            request=dict(name=f'projects/{PROJECT_ID}', filter=None, order_by=None, page_size=None),
+            metadata=(),
+            retry=DEFAULT,
+            timeout=DEFAULT,
+        )
+        mock_policy_client.return_value.list_alert_policies.assert_called_once_with(
+            request=dict(name=f'projects/{PROJECT_ID}', filter=None, order_by=None, page_size=None),
+            retry=DEFAULT,
+            timeout=DEFAULT,
+            metadata=(),
+        )
+
+        existing_alert_policy.creation_record = None
+        existing_alert_policy.mutation_record = None
+        mock_policy_client.return_value.update_alert_policy.assert_called_once_with(
+            request=dict(alert_policy=existing_alert_policy), retry=DEFAULT, timeout=DEFAULT, metadata=()
         )
 
     @mock.patch(
@@ -247,10 +271,10 @@ class TestStackdriverHookMethods(unittest.TestCase):
             name='test-alert',
         )
         mock_policy_client.return_value.delete_alert_policy.assert_called_once_with(
-            name='test-alert',
+            request=dict(name='test-alert'),
             retry=DEFAULT,
             timeout=DEFAULT,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
@@ -265,13 +289,10 @@ class TestStackdriverHookMethods(unittest.TestCase):
             project_id=PROJECT_ID,
         )
         mock_channel_client.return_value.list_notification_channels.assert_called_once_with(
-            name=f'projects/{PROJECT_ID}',
-            filter_=TEST_FILTER,
-            order_by=None,
-            page_size=None,
+            request=dict(name=f'projects/{PROJECT_ID}', filter=TEST_FILTER, order_by=None, page_size=None),
             retry=DEFAULT,
             timeout=DEFAULT,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
@@ -283,12 +304,9 @@ class TestStackdriverHookMethods(unittest.TestCase):
         self, mock_channel_client, mock_get_creds_and_project_id
     ):
         hook = stackdriver.StackdriverHook()
-        notification_channel_enabled = ParseDict(
-            TEST_NOTIFICATION_CHANNEL_1, monitoring_v3.types.notification_pb2.NotificationChannel()
-        )
-        notification_channel_disabled = ParseDict(
-            TEST_NOTIFICATION_CHANNEL_2, monitoring_v3.types.notification_pb2.NotificationChannel()
-        )
+        notification_channel_enabled = NotificationChannel(**TEST_NOTIFICATION_CHANNEL_1)
+        notification_channel_disabled = NotificationChannel(**TEST_NOTIFICATION_CHANNEL_2)
+
         mock_channel_client.return_value.list_notification_channels.return_value = [
             notification_channel_enabled,
             notification_channel_disabled,
@@ -299,15 +317,13 @@ class TestStackdriverHookMethods(unittest.TestCase):
             project_id=PROJECT_ID,
         )
 
-        notification_channel_disabled.enabled.value = True  # pylint: disable=no-member
-        mask = monitoring_v3.types.field_mask_pb2.FieldMask()
-        mask.paths.append('enabled')  # pylint: disable=no-member
+        notification_channel_disabled.enabled = True  # pylint: disable=no-member
+        mask = FieldMask(paths=['enabled'])
         mock_channel_client.return_value.update_notification_channel.assert_called_once_with(
-            notification_channel=notification_channel_disabled,
-            update_mask=mask,
+            request=dict(notification_channel=notification_channel_disabled, update_mask=mask),
             retry=DEFAULT,
             timeout=DEFAULT,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
@@ -319,12 +335,8 @@ class TestStackdriverHookMethods(unittest.TestCase):
         self, mock_channel_client, mock_get_creds_and_project_id
     ):
         hook = stackdriver.StackdriverHook()
-        notification_channel_enabled = ParseDict(
-            TEST_NOTIFICATION_CHANNEL_1, monitoring_v3.types.notification_pb2.NotificationChannel()
-        )
-        notification_channel_disabled = ParseDict(
-            TEST_NOTIFICATION_CHANNEL_2, monitoring_v3.types.notification_pb2.NotificationChannel()
-        )
+        notification_channel_enabled = NotificationChannel(**TEST_NOTIFICATION_CHANNEL_1)
+        notification_channel_disabled = NotificationChannel(**TEST_NOTIFICATION_CHANNEL_2)
         mock_channel_client.return_value.list_notification_channels.return_value = [
             notification_channel_enabled,
             notification_channel_disabled,
@@ -335,15 +347,13 @@ class TestStackdriverHookMethods(unittest.TestCase):
             project_id=PROJECT_ID,
         )
 
-        notification_channel_enabled.enabled.value = False  # pylint: disable=no-member
-        mask = monitoring_v3.types.field_mask_pb2.FieldMask()
-        mask.paths.append('enabled')  # pylint: disable=no-member
+        notification_channel_enabled.enabled = False  # pylint: disable=no-member
+        mask = FieldMask(paths=['enabled'])
         mock_channel_client.return_value.update_notification_channel.assert_called_once_with(
-            notification_channel=notification_channel_enabled,
-            update_mask=mask,
+            request=dict(notification_channel=notification_channel_enabled, update_mask=mask),
             retry=DEFAULT,
             timeout=DEFAULT,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
@@ -353,12 +363,9 @@ class TestStackdriverHookMethods(unittest.TestCase):
     @mock.patch('airflow.providers.google.cloud.hooks.stackdriver.StackdriverHook._get_channel_client')
     def test_stackdriver_upsert_channel(self, mock_channel_client, mock_get_creds_and_project_id):
         hook = stackdriver.StackdriverHook()
-        existing_notification_channel = ParseDict(
-            TEST_NOTIFICATION_CHANNEL_1, monitoring_v3.types.notification_pb2.NotificationChannel()
-        )
-        notification_channel_to_be_created = ParseDict(
-            TEST_NOTIFICATION_CHANNEL_2, monitoring_v3.types.notification_pb2.NotificationChannel()
-        )
+        existing_notification_channel = NotificationChannel(**TEST_NOTIFICATION_CHANNEL_1)
+        notification_channel_to_be_created = NotificationChannel(**TEST_NOTIFICATION_CHANNEL_2)
+
         mock_channel_client.return_value.list_notification_channels.return_value = [
             existing_notification_channel
         ]
@@ -367,24 +374,25 @@ class TestStackdriverHookMethods(unittest.TestCase):
             project_id=PROJECT_ID,
         )
         mock_channel_client.return_value.list_notification_channels.assert_called_once_with(
-            name=f'projects/{PROJECT_ID}',
-            filter_=None,
-            order_by=None,
-            page_size=None,
+            request=dict(name=f'projects/{PROJECT_ID}', filter=None, order_by=None, page_size=None),
             retry=DEFAULT,
             timeout=DEFAULT,
-            metadata=None,
+            metadata=(),
         )
         mock_channel_client.return_value.update_notification_channel.assert_called_once_with(
-            notification_channel=existing_notification_channel, retry=DEFAULT, timeout=DEFAULT, metadata=None
+            request=dict(notification_channel=existing_notification_channel),
+            retry=DEFAULT,
+            timeout=DEFAULT,
+            metadata=(),
         )
-        notification_channel_to_be_created.ClearField('name')
+        notification_channel_to_be_created.name = None
         mock_channel_client.return_value.create_notification_channel.assert_called_once_with(
-            name=f'projects/{PROJECT_ID}',
-            notification_channel=notification_channel_to_be_created,
+            request=dict(
+                name=f'projects/{PROJECT_ID}', notification_channel=notification_channel_to_be_created
+            ),
             retry=DEFAULT,
             timeout=DEFAULT,
-            metadata=None,
+            metadata=(),
         )
 
     @mock.patch(
@@ -400,5 +408,5 @@ class TestStackdriverHookMethods(unittest.TestCase):
             name='test-channel',
         )
         mock_channel_client.return_value.delete_notification_channel.assert_called_once_with(
-            name='test-channel', retry=DEFAULT, timeout=DEFAULT, metadata=None
+            request=dict(name='test-channel'), retry=DEFAULT, timeout=DEFAULT, metadata=()
         )
diff --git a/tests/providers/google/cloud/operators/test_stackdriver.py b/tests/providers/google/cloud/operators/test_stackdriver.py
index 28901b4..50dd997 100644
--- a/tests/providers/google/cloud/operators/test_stackdriver.py
+++ b/tests/providers/google/cloud/operators/test_stackdriver.py
@@ -21,6 +21,7 @@ import unittest
 from unittest import mock
 
 from google.api_core.gapic_v1.method import DEFAULT
+from google.cloud.monitoring_v3 import AlertPolicy, NotificationChannel
 
 from airflow.providers.google.cloud.operators.stackdriver import (
     StackdriverDeleteAlertOperator,
@@ -40,16 +41,15 @@ TEST_FILTER = 'filter'
 TEST_ALERT_POLICY_1 = {
     "combiner": "OR",
     "name": "projects/sd-project/alertPolicies/12345",
-    "creationRecord": {"mutatedBy": "user123", "mutateTime": "2020-01-01T00:00:00.000000Z"},
     "enabled": True,
-    "displayName": "test display",
+    "display_name": "test display",
     "conditions": [
         {
-            "conditionThreshold": {
+            "condition_threshold": {
                 "comparison": "COMPARISON_GT",
-                "aggregations": [{"alignmentPeriod": "60s", "perSeriesAligner": "ALIGN_RATE"}],
+                "aggregations": [{"alignment_eriod": {'seconds': 60}, "per_series_aligner": "ALIGN_RATE"}],
             },
-            "displayName": "Condition display",
+            "display_name": "Condition display",
             "name": "projects/sd-project/alertPolicies/123/conditions/456",
         }
     ],
@@ -58,16 +58,15 @@ TEST_ALERT_POLICY_1 = {
 TEST_ALERT_POLICY_2 = {
     "combiner": "OR",
     "name": "projects/sd-project/alertPolicies/6789",
-    "creationRecord": {"mutatedBy": "user123", "mutateTime": "2020-01-01T00:00:00.000000Z"},
     "enabled": False,
-    "displayName": "test display",
+    "display_name": "test display",
     "conditions": [
         {
-            "conditionThreshold": {
+            "condition_threshold": {
                 "comparison": "COMPARISON_GT",
-                "aggregations": [{"alignmentPeriod": "60s", "perSeriesAligner": "ALIGN_RATE"}],
+                "aggregations": [{"alignment_period": {'seconds': 60}, "per_series_aligner": "ALIGN_RATE"}],
             },
-            "displayName": "Condition display",
+            "display_name": "Condition display",
             "name": "projects/sd-project/alertPolicies/456/conditions/789",
         }
     ],
@@ -94,7 +93,8 @@ class TestStackdriverListAlertPoliciesOperator(unittest.TestCase):
     @mock.patch('airflow.providers.google.cloud.operators.stackdriver.StackdriverHook')
     def test_execute(self, mock_hook):
         operator = StackdriverListAlertPoliciesOperator(task_id=TEST_TASK_ID, filter_=TEST_FILTER)
-        operator.execute(None)
+        mock_hook.return_value.list_alert_policies.return_value = [AlertPolicy(name="test-name")]
+        result = operator.execute(None)
         mock_hook.return_value.list_alert_policies.assert_called_once_with(
             project_id=None,
             filter_=TEST_FILTER,
@@ -105,6 +105,16 @@ class TestStackdriverListAlertPoliciesOperator(unittest.TestCase):
             timeout=DEFAULT,
             metadata=None,
         )
+        assert [
+            {
+                'combiner': 0,
+                'conditions': [],
+                'display_name': '',
+                'name': 'test-name',
+                'notification_channels': [],
+                'user_labels': {},
+            }
+        ] == result
 
 
 class TestStackdriverEnableAlertPoliciesOperator(unittest.TestCase):
@@ -160,7 +170,11 @@ class TestStackdriverListNotificationChannelsOperator(unittest.TestCase):
     @mock.patch('airflow.providers.google.cloud.operators.stackdriver.StackdriverHook')
     def test_execute(self, mock_hook):
         operator = StackdriverListNotificationChannelsOperator(task_id=TEST_TASK_ID, filter_=TEST_FILTER)
-        operator.execute(None)
+        mock_hook.return_value.list_notification_channels.return_value = [
+            NotificationChannel(name="test-123")
+        ]
+
+        result = operator.execute(None)
         mock_hook.return_value.list_notification_channels.assert_called_once_with(
             project_id=None,
             filter_=TEST_FILTER,
@@ -171,6 +185,17 @@ class TestStackdriverListNotificationChannelsOperator(unittest.TestCase):
             timeout=DEFAULT,
             metadata=None,
         )
+        assert [
+            {
+                'description': '',
+                'display_name': '',
+                'labels': {},
+                'name': 'test-123',
+                'type_': '',
+                'user_labels': {},
+                'verification_status': 0,
+            }
+        ] == result
 
 
 class TestStackdriverEnableNotificationChannelsOperator(unittest.TestCase):
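
To summarise the call convention this patch migrates to: the google-cloud-monitoring>=2.0.0 clients take a single request mapping plus keyword retry/timeout/metadata, and the proto-plus messages expose helpers such as AlertPolicy.to_dict(). A hedged sketch follows; the project id and filter are illustrative assumptions and credentials are taken from the environment.

# Hedged sketch of the monitoring_v3 2.x style used throughout this commit.
from google.cloud import monitoring_v3
from google.cloud.monitoring_v3 import AlertPolicy

client = monitoring_v3.AlertPolicyServiceClient()
policies = client.list_alert_policies(
    request={
        "name": "projects/example-project",          # assumed project
        "filter": 'display_name="test alert 1"',     # assumed filter
    },
    metadata=(),
)
# proto-plus messages convert to plain dicts, as the operators above now do
policy_dicts = [AlertPolicy.to_dict(policy) for policy in policies]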


[airflow] 09/28: Update compatibility with google-cloud-kms>=2.0 (#13124)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit fae32c36bdb0f56587743d5ecfce02af92990042
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Tue Dec 22 12:59:27 2020 +0100

    Update compatibility with google-cloud-kms>=2.0 (#13124)
    
    (cherry picked from commit b26b0df5b03c4cd826fd7b2dff5771d64e18e6b7)
---
 airflow/providers/google/cloud/hooks/kms.py    | 20 +++++++------
 setup.py                                       |  2 +-
 tests/providers/google/cloud/hooks/test_kms.py | 40 +++++++++++++++-----------
 3 files changed, 37 insertions(+), 25 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/kms.py b/airflow/providers/google/cloud/hooks/kms.py
index e63c2f1..3fd1433 100644
--- a/airflow/providers/google/cloud/hooks/kms.py
+++ b/airflow/providers/google/cloud/hooks/kms.py
@@ -118,12 +118,14 @@ class CloudKMSHook(GoogleBaseHook):
         :rtype: str
         """
         response = self.get_conn().encrypt(
-            name=key_name,
-            plaintext=plaintext,
-            additional_authenticated_data=authenticated_data,
+            request={
+                'name': key_name,
+                'plaintext': plaintext,
+                'additional_authenticated_data': authenticated_data,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
         ciphertext = _b64encode(response.ciphertext)
@@ -161,12 +163,14 @@ class CloudKMSHook(GoogleBaseHook):
         :rtype: bytes
         """
         response = self.get_conn().decrypt(
-            name=key_name,
-            ciphertext=_b64decode(ciphertext),
-            additional_authenticated_data=authenticated_data,
+            request={
+                'name': key_name,
+                'ciphertext': _b64decode(ciphertext),
+                'additional_authenticated_data': authenticated_data,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
         return response.plaintext
diff --git a/setup.py b/setup.py
index 63dd6d7..1ec4f5d 100644
--- a/setup.py
+++ b/setup.py
@@ -290,7 +290,7 @@ google = [
     'google-cloud-datacatalog>=1.0.0,<2.0.0',
     'google-cloud-dataproc>=1.0.1,<2.0.0',
     'google-cloud-dlp>=0.11.0,<2.0.0',
-    'google-cloud-kms>=1.2.1,<2.0.0',
+    'google-cloud-kms>=2.0.0,<3.0.0',
     'google-cloud-language>=1.1.1,<2.0.0',
     'google-cloud-logging>=1.14.0,<2.0.0',
     'google-cloud-memcache>=0.2.0',
diff --git a/tests/providers/google/cloud/hooks/test_kms.py b/tests/providers/google/cloud/hooks/test_kms.py
index 6b87e3c..4de1dfb 100644
--- a/tests/providers/google/cloud/hooks/test_kms.py
+++ b/tests/providers/google/cloud/hooks/test_kms.py
@@ -82,12 +82,14 @@ class TestCloudKMSHook(unittest.TestCase):
         result = self.kms_hook.encrypt(TEST_KEY_ID, PLAINTEXT)
         mock_get_conn.assert_called_once_with()
         mock_get_conn.return_value.encrypt.assert_called_once_with(
-            name=TEST_KEY_ID,
-            plaintext=PLAINTEXT,
-            additional_authenticated_data=None,
+            request=dict(
+                name=TEST_KEY_ID,
+                plaintext=PLAINTEXT,
+                additional_authenticated_data=None,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert PLAINTEXT_b64 == result
 
@@ -97,12 +99,14 @@ class TestCloudKMSHook(unittest.TestCase):
         result = self.kms_hook.encrypt(TEST_KEY_ID, PLAINTEXT, AUTH_DATA)
         mock_get_conn.assert_called_once_with()
         mock_get_conn.return_value.encrypt.assert_called_once_with(
-            name=TEST_KEY_ID,
-            plaintext=PLAINTEXT,
-            additional_authenticated_data=AUTH_DATA,
+            request=dict(
+                name=TEST_KEY_ID,
+                plaintext=PLAINTEXT,
+                additional_authenticated_data=AUTH_DATA,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert PLAINTEXT_b64 == result
 
@@ -112,12 +116,14 @@ class TestCloudKMSHook(unittest.TestCase):
         result = self.kms_hook.decrypt(TEST_KEY_ID, CIPHERTEXT_b64)
         mock_get_conn.assert_called_once_with()
         mock_get_conn.return_value.decrypt.assert_called_once_with(
-            name=TEST_KEY_ID,
-            ciphertext=CIPHERTEXT,
-            additional_authenticated_data=None,
+            request=dict(
+                name=TEST_KEY_ID,
+                ciphertext=CIPHERTEXT,
+                additional_authenticated_data=None,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert PLAINTEXT == result
 
@@ -127,11 +133,13 @@ class TestCloudKMSHook(unittest.TestCase):
         result = self.kms_hook.decrypt(TEST_KEY_ID, CIPHERTEXT_b64, AUTH_DATA)
         mock_get_conn.assert_called_once_with()
         mock_get_conn.return_value.decrypt.assert_called_once_with(
-            name=TEST_KEY_ID,
-            ciphertext=CIPHERTEXT,
-            additional_authenticated_data=AUTH_DATA,
+            request=dict(
+                name=TEST_KEY_ID,
+                ciphertext=CIPHERTEXT,
+                additional_authenticated_data=AUTH_DATA,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )
         assert PLAINTEXT == result
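
For context, a hedged sketch of an encrypt/decrypt round trip with CloudKMSHook after this migration, matching the hook signatures exercised by the tests above. The key resource name and connection id are illustrative assumptions, and running it requires real GCP credentials.

# Hedged sketch, not part of the patch: round-trip through Cloud KMS via the hook.
from airflow.providers.google.cloud.hooks.kms import CloudKMSHook

hook = CloudKMSHook(gcp_conn_id="google_cloud_default")  # assumed connection id
key_name = (
    "projects/example-project/locations/global/"
    "keyRings/example-ring/cryptoKeys/example-key"       # assumed key path
)
ciphertext = hook.encrypt(key_name, b"secret-value")  # returns a base64-encoded str
plaintext = hook.decrypt(key_name, ciphertext)        # returns the original bytes
assert plaintext == b"secret-value"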


[airflow] 04/28: Add Apache Beam operators (#12814)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit c6ccaa5d4b4ce55e4469c2aa20ef434979747504
Author: Tobiasz Kędzierski <to...@polidea.com>
AuthorDate: Wed Feb 3 21:34:01 2021 +0100

    Add Apache Beam operators (#12814)
    
    (cherry picked from commit 1872d8719d24f94aeb1dcba9694837070b9884ca)
---
 CONTRIBUTING.rst                                   |  17 +-
 INSTALL                                            |  14 +-
 .../apache/beam/BACKPORT_PROVIDER_README.md        |  99 +++
 airflow/providers/apache/beam/CHANGELOG.rst        |  25 +
 airflow/providers/apache/beam/README.md            |  97 +++
 airflow/providers/apache/beam/__init__.py          |  17 +
 .../providers/apache/beam/example_dags/__init__.py |  17 +
 .../apache/beam/example_dags/example_beam.py       | 315 +++++++++
 airflow/providers/apache/beam/hooks/__init__.py    |  17 +
 airflow/providers/apache/beam/hooks/beam.py        | 289 ++++++++
 .../providers/apache/beam/operators/__init__.py    |  17 +
 airflow/providers/apache/beam/operators/beam.py    | 446 ++++++++++++
 airflow/providers/apache/beam/provider.yaml        |  45 ++
 airflow/providers/dependencies.json                |   4 +
 airflow/providers/google/cloud/hooks/dataflow.py   | 330 ++++-----
 .../providers/google/cloud/operators/dataflow.py   | 331 +++++++--
 .../copy_provider_package_sources.py               |  62 ++
 dev/provider_packages/prepare_provider_packages.py |   4 +-
 .../apache-airflow-providers-apache-beam/index.rst |  36 +
 .../operators.rst                                  | 116 ++++
 docs/apache-airflow/extra-packages-ref.rst         |   2 +
 docs/spelling_wordlist.txt                         |   2 +
 .../run_install_and_test_provider_packages.sh      |   2 +-
 setup.py                                           |   1 +
 tests/core/test_providers_manager.py               |   1 +
 tests/providers/apache/beam/__init__.py            |  16 +
 tests/providers/apache/beam/hooks/__init__.py      |  16 +
 tests/providers/apache/beam/hooks/test_beam.py     | 271 ++++++++
 tests/providers/apache/beam/operators/__init__.py  |  16 +
 tests/providers/apache/beam/operators/test_beam.py | 274 ++++++++
 .../apache/beam/operators/test_beam_system.py      |  47 ++
 .../providers/google/cloud/hooks/test_dataflow.py  | 760 ++++++++++++---------
 .../google/cloud/operators/test_dataflow.py        | 223 ++++--
 .../google/cloud/operators/test_mlengine_utils.py  |  30 +-
 34 files changed, 3263 insertions(+), 696 deletions(-)

diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 6d0e224..0a6f381 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -572,13 +572,13 @@ This is the full list of those extras:
 
   .. START EXTRAS HERE
 
-all, all_dbs, amazon, apache.atlas, apache.cassandra, apache.druid, apache.hdfs, apache.hive,
-apache.kylin, apache.livy, apache.pig, apache.pinot, apache.spark, apache.sqoop, apache.webhdfs,
-async, atlas, aws, azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes, crypto, dask,
-databricks, datadog, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker,
-druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github_enterprise, google, google_auth,
-grpc, hashicorp, hdfs, hive, http, imap, jdbc, jenkins, jira, kerberos, kubernetes, ldap,
-microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas,
+all, all_dbs, amazon, apache.atlas, apache.beam, apache.cassandra, apache.druid, apache.hdfs,
+apache.hive, apache.kylin, apache.livy, apache.pig, apache.pinot, apache.spark, apache.sqoop,
+apache.webhdfs, async, atlas, aws, azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes,
+crypto, dask, databricks, datadog, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc,
+docker, druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github_enterprise, google,
+google_auth, grpc, hashicorp, hdfs, hive, http, imap, jdbc, jenkins, jira, kerberos, kubernetes,
+ldap, microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas,
 opsgenie, oracle, pagerduty, papermill, password, pinot, plexus, postgres, presto, qds, qubole,
 rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry, sftp, singularity, slack,
 snowflake, spark, sqlite, ssh, statsd, tableau, telegram, vertica, virtualenv, webhdfs, winrm,
@@ -641,12 +641,13 @@ Here is the list of packages and their extras:
 Package                    Extras
 ========================== ===========================
 amazon                     apache.hive,google,imap,mongo,mysql,postgres,ssh
+apache.beam                google
 apache.druid               apache.hive
 apache.hive                amazon,microsoft.mssql,mysql,presto,samba,vertica
 apache.livy                http
 dingding                   http
 discord                    http
-google                     amazon,apache.cassandra,cncf.kubernetes,facebook,microsoft.azure,microsoft.mssql,mysql,postgres,presto,salesforce,sftp,ssh
+google                     amazon,apache.beam,apache.cassandra,cncf.kubernetes,facebook,microsoft.azure,microsoft.mssql,mysql,postgres,presto,salesforce,sftp,ssh
 hashicorp                  google
 microsoft.azure            google,oracle
 microsoft.mssql            odbc
diff --git a/INSTALL b/INSTALL
index e1ef456..d175aa1 100644
--- a/INSTALL
+++ b/INSTALL
@@ -97,13 +97,13 @@ The list of available extras:
 
 # START EXTRAS HERE
 
-all, all_dbs, amazon, apache.atlas, apache.cassandra, apache.druid, apache.hdfs, apache.hive,
-apache.kylin, apache.livy, apache.pig, apache.pinot, apache.spark, apache.sqoop, apache.webhdfs,
-async, atlas, aws, azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes, crypto, dask,
-databricks, datadog, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker,
-druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github_enterprise, google, google_auth,
-grpc, hashicorp, hdfs, hive, http, imap, jdbc, jenkins, jira, kerberos, kubernetes, ldap,
-microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas,
+all, all_dbs, amazon, apache.atlas, apache.beam, apache.cassandra, apache.druid, apache.hdfs,
+apache.hive, apache.kylin, apache.livy, apache.pig, apache.pinot, apache.spark, apache.sqoop,
+apache.webhdfs, async, atlas, aws, azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes,
+crypto, dask, databricks, datadog, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc,
+docker, druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github_enterprise, google,
+google_auth, grpc, hashicorp, hdfs, hive, http, imap, jdbc, jenkins, jira, kerberos, kubernetes,
+ldap, microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas,
 opsgenie, oracle, pagerduty, papermill, password, pinot, plexus, postgres, presto, qds, qubole,
 rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry, sftp, singularity, slack,
 snowflake, spark, sqlite, ssh, statsd, tableau, telegram, vertica, virtualenv, webhdfs, winrm,
diff --git a/airflow/providers/apache/beam/BACKPORT_PROVIDER_README.md b/airflow/providers/apache/beam/BACKPORT_PROVIDER_README.md
new file mode 100644
index 0000000..d0908b6
--- /dev/null
+++ b/airflow/providers/apache/beam/BACKPORT_PROVIDER_README.md
@@ -0,0 +1,99 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied.  See the License for the
+ specific language governing permissions and limitations
+ under the License.
+ -->
+
+
+# Package apache-airflow-backport-providers-apache-beam
+
+Release:
+
+**Table of contents**
+
+- [Backport package](#backport-package)
+- [Installation](#installation)
+- [PIP requirements](#pip-requirements)
+- [Cross provider package dependencies](#cross-provider-package-dependencies)
+- [Provider class summary](#provider-classes-summary)
+    - [Operators](#operators)
+        - [Moved operators](#moved-operators)
+    - [Transfer operators](#transfer-operators)
+        - [Moved transfer operators](#moved-transfer-operators)
+    - [Hooks](#hooks)
+        - [Moved hooks](#moved-hooks)
+- [Releases](#releases)
+    - [Release](#release)
+
+## Backport package
+
+This is a backport providers package for the `apache.beam` provider. All classes for this provider package
+are in the `airflow.providers.apache.beam` Python package.
+
+**Only Python 3.6+ is supported for this backport package.**
+
+While Airflow 1.10.* continues to support Python 2.7+, you need to upgrade Python to 3.6+ if you
+want to use this backport package.
+
+
+## Installation
+
+You can install this package on top of an existing airflow 1.10.* installation via
+`pip install apache-airflow-backport-providers-apache-beam`
+
+## Cross provider package dependencies
+
+These are dependencies that might be needed in order to use all the features of the package.
+You need to install the specified backport providers packages in order to use them.
+
+You can install such cross-provider dependencies when installing from PyPI. For example:
+
+```bash
+pip install apache-airflow-backport-providers-apache-beam[google]
+```
+
+| Dependent package                                                                                         | Extra       |
+|:----------------------------------------------------------------------------------------------------------|:------------|
+| [apache-airflow-backport-providers-google](https://pypi.org/project/apache-airflow-backport-providers-google) | google      |
+
+
+# Provider classes summary
+
+In Airflow 2.0, all operators, transfers, hooks, sensors, secrets for the `apache.beam` provider
+are in the `airflow.providers.apache.beam` package. You can read more about the naming conventions used
+in [Naming conventions for provider packages](https://github.com/apache/airflow/blob/master/CONTRIBUTING.rst#naming-conventions-for-provider-packages)
+
+
+## Operators
+
+### New operators
+
+| New Airflow 2.0 operators: `airflow.providers.apache.beam` package                                                                                                                 |
+|:-----------------------------------------------------------------------------------------------------------------------------------------------|
+| [operators.beam.BeamRunJavaPipelineOperator](https://github.com/apache/airflow/blob/master/airflow/providers/apache/beam/operators/beam.py)    |
+| [operators.beam.BeamRunPythonPipelineOperator](https://github.com/apache/airflow/blob/master/airflow/providers/apache/beam/operators/beam.py)  |
+
+
+## Hooks
+
+### New hooks
+
+| New Airflow 2.0 hooks: `airflow.providers.apache.beam` package                                                   |
+|:-----------------------------------------------------------------------------------------------------------------|
+| [hooks.beam.BeamHook](https://github.com/apache/airflow/blob/master/airflow/providers/apache/beam/hooks/beam.py) |
+
+
+## Releases
diff --git a/airflow/providers/apache/beam/CHANGELOG.rst b/airflow/providers/apache/beam/CHANGELOG.rst
new file mode 100644
index 0000000..cef7dda
--- /dev/null
+++ b/airflow/providers/apache/beam/CHANGELOG.rst
@@ -0,0 +1,25 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+
+Changelog
+---------
+
+1.0.0
+.....
+
+Initial version of the provider.
diff --git a/airflow/providers/apache/beam/README.md b/airflow/providers/apache/beam/README.md
new file mode 100644
index 0000000..3aa0ead
--- /dev/null
+++ b/airflow/providers/apache/beam/README.md
@@ -0,0 +1,97 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied.  See the License for the
+ specific language governing permissions and limitations
+ under the License.
+ -->
+
+
+# Package apache-airflow-providers-apache-beam
+
+Release: 0.0.1
+
+**Table of contents**
+
+- [Provider package](#provider-package)
+- [Installation](#installation)
+- [PIP requirements](#pip-requirements)
+- [Cross provider package dependencies](#cross-provider-package-dependencies)
+- [Provider class summary](#provider-classes-summary)
+    - [Operators](#operators)
+    - [Transfer operators](#transfer-operators)
+    - [Hooks](#hooks)
+- [Releases](#releases)
+
+## Provider package
+
+This is a provider package for the `apache.beam` provider. All classes for this provider package
+are in the `airflow.providers.apache.beam` Python package.
+
+## Installation
+
+NOTE!
+
+In November 2020, a new version of pip (20.3) was released with a new, 2020 resolver. This resolver
+does not yet work with Apache Airflow and might lead to errors during installation, depending on your
+choice of extras. In order to install Airflow you need to either downgrade pip to version 20.2.4
+(`pip install --upgrade pip==20.2.4`) or, if you use pip 20.3, add the option
+`--use-deprecated legacy-resolver` to your pip install command.
+
+You can install this package on top of an existing airflow 2.* installation via
+`pip install apache-airflow-providers-apache-beam`
+
+## Cross provider package dependencies
+
+These are dependencies that might be needed in order to use all the features of the package.
+You need to install the specified provider packages in order to use them.
+
+You can install such cross-provider dependencies when installing from PyPI. For example:
+
+```bash
+pip install apache-airflow-providers-apache-beam[google]
+```
+
+| Dependent package                                                                           | Extra       |
+|:--------------------------------------------------------------------------------------------|:------------|
+| [apache-airflow-providers-google](https://pypi.org/project/apache-airflow-providers-google) | google      |
+
+
+# Provider classes summary
+
+In Airflow 2.0, all operators, transfers, hooks, sensors, secrets for the `apache.beam` provider
+are in the `airflow.providers.apache.beam` package. You can read more about the naming conventions used
+in [Naming conventions for provider packages](https://github.com/apache/airflow/blob/master/CONTRIBUTING.rst#naming-conventions-for-provider-packages)
+
+
+## Operators
+
+### New operators
+
+| New Airflow 2.0 operators: `airflow.providers.apache.beam` package                                                                                                                 |
+|:-----------------------------------------------------------------------------------------------------------------------------------------------|
+| [operators.beam.BeamRunJavaPipelineOperator](https://github.com/apache/airflow/blob/master/airflow/providers/apache/beam/operators/beam.py)    |
+| [operators.beam.BeamRunPythonPipelineOperator](https://github.com/apache/airflow/blob/master/airflow/providers/apache/beam/operators/beam.py)  |
+
+
+## Hooks
+
+### New hooks
+
+| New Airflow 2.0 hooks: `airflow.providers.apache.beam` package                                                   |
+|:-----------------------------------------------------------------------------------------------------------------|
+| [hooks.beam.BeamHook](https://github.com/apache/airflow/blob/master/airflow/providers/apache/beam/hooks/beam.py) |
+
+
+## Releases
diff --git a/airflow/providers/apache/beam/__init__.py b/airflow/providers/apache/beam/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/airflow/providers/apache/beam/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/apache/beam/example_dags/__init__.py b/airflow/providers/apache/beam/example_dags/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/airflow/providers/apache/beam/example_dags/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/apache/beam/example_dags/example_beam.py b/airflow/providers/apache/beam/example_dags/example_beam.py
new file mode 100644
index 0000000..d20c4ce
--- /dev/null
+++ b/airflow/providers/apache/beam/example_dags/example_beam.py
@@ -0,0 +1,315 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG for Apache Beam operators
+"""
+import os
+from urllib.parse import urlparse
+
+from airflow import models
+from airflow.providers.apache.beam.operators.beam import (
+    BeamRunJavaPipelineOperator,
+    BeamRunPythonPipelineOperator,
+)
+from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
+from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration
+from airflow.providers.google.cloud.sensors.dataflow import DataflowJobStatusSensor
+from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator
+from airflow.utils.dates import days_ago
+
+GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project')
+GCS_INPUT = os.environ.get('APACHE_BEAM_PYTHON', 'gs://apache-beam-samples/shakespeare/kinglear.txt')
+GCS_TMP = os.environ.get('APACHE_BEAM_GCS_TMP', 'gs://test-dataflow-example/temp/')
+GCS_STAGING = os.environ.get('APACHE_BEAM_GCS_STAGING', 'gs://test-dataflow-example/staging/')
+GCS_OUTPUT = os.environ.get('APACHE_BEAM_GCS_OUTPUT', 'gs://test-dataflow-example/output')
+GCS_PYTHON = os.environ.get('APACHE_BEAM_PYTHON', 'gs://test-dataflow-example/wordcount_debugging.py')
+GCS_PYTHON_DATAFLOW_ASYNC = os.environ.get(
+    'APACHE_BEAM_PYTHON_DATAFLOW_ASYNC', 'gs://test-dataflow-example/wordcount_debugging.py'
+)
+
+GCS_JAR_DIRECT_RUNNER = os.environ.get(
+    'APACHE_BEAM_DIRECT_RUNNER_JAR',
+    'gs://test-dataflow-example/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-DirectRunner.jar',
+)
+GCS_JAR_DATAFLOW_RUNNER = os.environ.get(
+    'APACHE_BEAM_DATAFLOW_RUNNER_JAR', 'gs://test-dataflow-example/word-count-beam-bundled-0.1.jar'
+)
+GCS_JAR_SPARK_RUNNER = os.environ.get(
+    'APACHE_BEAM_SPARK_RUNNER_JAR',
+    'gs://test-dataflow-example/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-SparkRunner.jar',
+)
+GCS_JAR_FLINK_RUNNER = os.environ.get(
+    'APACHE_BEAM_FLINK_RUNNER_JAR',
+    'gs://test-dataflow-example/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-FlinkRunner.jar',
+)
+
+GCS_JAR_DIRECT_RUNNER_PARTS = urlparse(GCS_JAR_DIRECT_RUNNER)
+GCS_JAR_DIRECT_RUNNER_BUCKET_NAME = GCS_JAR_DIRECT_RUNNER_PARTS.netloc
+GCS_JAR_DIRECT_RUNNER_OBJECT_NAME = GCS_JAR_DIRECT_RUNNER_PARTS.path[1:]
+GCS_JAR_DATAFLOW_RUNNER_PARTS = urlparse(GCS_JAR_DATAFLOW_RUNNER)
+GCS_JAR_DATAFLOW_RUNNER_BUCKET_NAME = GCS_JAR_DATAFLOW_RUNNER_PARTS.netloc
+GCS_JAR_DATAFLOW_RUNNER_OBJECT_NAME = GCS_JAR_DATAFLOW_RUNNER_PARTS.path[1:]
+GCS_JAR_SPARK_RUNNER_PARTS = urlparse(GCS_JAR_SPARK_RUNNER)
+GCS_JAR_SPARK_RUNNER_BUCKET_NAME = GCS_JAR_SPARK_RUNNER_PARTS.netloc
+GCS_JAR_SPARK_RUNNER_OBJECT_NAME = GCS_JAR_SPARK_RUNNER_PARTS.path[1:]
+GCS_JAR_FLINK_RUNNER_PARTS = urlparse(GCS_JAR_FLINK_RUNNER)
+GCS_JAR_FLINK_RUNNER_BUCKET_NAME = GCS_JAR_FLINK_RUNNER_PARTS.netloc
+GCS_JAR_FLINK_RUNNER_OBJECT_NAME = GCS_JAR_FLINK_RUNNER_PARTS.path[1:]
+
+
+default_args = {
+    'default_pipeline_options': {
+        'output': '/tmp/example_beam',
+    },
+    "trigger_rule": "all_done",
+}
+
+
+with models.DAG(
+    "example_beam_native_java_direct_runner",
+    schedule_interval=None,  # Override to match your needs
+    start_date=days_ago(1),
+    tags=['example'],
+) as dag_native_java_direct_runner:
+
+    # [START howto_operator_start_java_direct_runner_pipeline]
+    jar_to_local_direct_runner = GCSToLocalFilesystemOperator(
+        task_id="jar_to_local_direct_runner",
+        bucket=GCS_JAR_DIRECT_RUNNER_BUCKET_NAME,
+        object_name=GCS_JAR_DIRECT_RUNNER_OBJECT_NAME,
+        filename="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar",
+    )
+
+    start_java_pipeline_direct_runner = BeamRunJavaPipelineOperator(
+        task_id="start_java_pipeline_direct_runner",
+        jar="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar",
+        pipeline_options={
+            'output': '/tmp/start_java_pipeline_direct_runner',
+            'inputFile': GCS_INPUT,
+        },
+        job_class='org.apache.beam.examples.WordCount',
+    )
+
+    jar_to_local_direct_runner >> start_java_pipeline_direct_runner
+    # [END howto_operator_start_java_direct_runner_pipeline]
+
+with models.DAG(
+    "example_beam_native_java_dataflow_runner",
+    schedule_interval=None,  # Override to match your needs
+    start_date=days_ago(1),
+    tags=['example'],
+) as dag_native_java_dataflow_runner:
+    # [START howto_operator_start_java_dataflow_runner_pipeline]
+    jar_to_local_dataflow_runner = GCSToLocalFilesystemOperator(
+        task_id="jar_to_local_dataflow_runner",
+        bucket=GCS_JAR_DATAFLOW_RUNNER_BUCKET_NAME,
+        object_name=GCS_JAR_DATAFLOW_RUNNER_OBJECT_NAME,
+        filename="/tmp/beam_wordcount_dataflow_runner_{{ ds_nodash }}.jar",
+    )
+
+    start_java_pipeline_dataflow = BeamRunJavaPipelineOperator(
+        task_id="start_java_pipeline_dataflow",
+        runner="DataflowRunner",
+        jar="/tmp/beam_wordcount_dataflow_runner_{{ ds_nodash }}.jar",
+        pipeline_options={
+            'tempLocation': GCS_TMP,
+            'stagingLocation': GCS_STAGING,
+            'output': GCS_OUTPUT,
+        },
+        job_class='org.apache.beam.examples.WordCount',
+        dataflow_config={"job_name": "{{task.task_id}}", "location": "us-central1"},
+    )
+
+    jar_to_local_dataflow_runner >> start_java_pipeline_dataflow
+    # [END howto_operator_start_java_dataflow_runner_pipeline]
+
+with models.DAG(
+    "example_beam_native_java_spark_runner",
+    schedule_interval=None,  # Override to match your needs
+    start_date=days_ago(1),
+    tags=['example'],
+) as dag_native_java_spark_runner:
+
+    jar_to_local_spark_runner = GCSToLocalFilesystemOperator(
+        task_id="jar_to_local_spark_runner",
+        bucket=GCS_JAR_SPARK_RUNNER_BUCKET_NAME,
+        object_name=GCS_JAR_SPARK_RUNNER_OBJECT_NAME,
+        filename="/tmp/beam_wordcount_spark_runner_{{ ds_nodash }}.jar",
+    )
+
+    start_java_pipeline_spark_runner = BeamRunJavaPipelineOperator(
+        task_id="start_java_pipeline_spark_runner",
+        runner="SparkRunner",
+        jar="/tmp/beam_wordcount_spark_runner_{{ ds_nodash }}.jar",
+        pipeline_options={
+            'output': '/tmp/start_java_pipeline_spark_runner',
+            'inputFile': GCS_INPUT,
+        },
+        job_class='org.apache.beam.examples.WordCount',
+    )
+
+    jar_to_local_spark_runner >> start_java_pipeline_spark_runner
+
+with models.DAG(
+    "example_beam_native_java_flink_runner",
+    schedule_interval=None,  # Override to match your needs
+    start_date=days_ago(1),
+    tags=['example'],
+) as dag_native_java_flink_runner:
+
+    jar_to_local_flink_runner = GCSToLocalFilesystemOperator(
+        task_id="jar_to_local_flink_runner",
+        bucket=GCS_JAR_FLINK_RUNNER_BUCKET_NAME,
+        object_name=GCS_JAR_FLINK_RUNNER_OBJECT_NAME,
+        filename="/tmp/beam_wordcount_flink_runner_{{ ds_nodash }}.jar",
+    )
+
+    start_java_pipeline_flink_runner = BeamRunJavaPipelineOperator(
+        task_id="start_java_pipeline_flink_runner",
+        runner="FlinkRunner",
+        jar="/tmp/beam_wordcount_flink_runner_{{ ds_nodash }}.jar",
+        pipeline_options={
+            'output': '/tmp/start_java_pipeline_flink_runner',
+            'inputFile': GCS_INPUT,
+        },
+        job_class='org.apache.beam.examples.WordCount',
+    )
+
+    jar_to_local_flink_runner >> start_java_pipeline_flink_runner
+
+
+with models.DAG(
+    "example_beam_native_python",
+    default_args=default_args,
+    start_date=days_ago(1),
+    schedule_interval=None,  # Override to match your needs
+    tags=['example'],
+) as dag_native_python:
+
+    # [START howto_operator_start_python_direct_runner_pipeline_local_file]
+    start_python_pipeline_local_direct_runner = BeamRunPythonPipelineOperator(
+        task_id="start_python_pipeline_local_direct_runner",
+        py_file='apache_beam.examples.wordcount',
+        py_options=['-m'],
+        py_requirements=['apache-beam[gcp]==2.26.0'],
+        py_interpreter='python3',
+        py_system_site_packages=False,
+    )
+    # [END howto_operator_start_python_direct_runner_pipeline_local_file]
+
+    # [START howto_operator_start_python_direct_runner_pipeline_gcs_file]
+    start_python_pipeline_direct_runner = BeamRunPythonPipelineOperator(
+        task_id="start_python_pipeline_direct_runner",
+        py_file=GCS_PYTHON,
+        py_options=[],
+        pipeline_options={"output": GCS_OUTPUT},
+        py_requirements=['apache-beam[gcp]==2.26.0'],
+        py_interpreter='python3',
+        py_system_site_packages=False,
+    )
+    # [END howto_operator_start_python_direct_runner_pipeline_gcs_file]
+
+    # [START howto_operator_start_python_dataflow_runner_pipeline_gcs_file]
+    start_python_pipeline_dataflow_runner = BeamRunPythonPipelineOperator(
+        task_id="start_python_pipeline_dataflow_runner",
+        runner="DataflowRunner",
+        py_file=GCS_PYTHON,
+        pipeline_options={
+            'tempLocation': GCS_TMP,
+            'stagingLocation': GCS_STAGING,
+            'output': GCS_OUTPUT,
+        },
+        py_options=[],
+        py_requirements=['apache-beam[gcp]==2.26.0'],
+        py_interpreter='python3',
+        py_system_site_packages=False,
+        dataflow_config=DataflowConfiguration(
+            job_name='{{task.task_id}}', project_id=GCP_PROJECT_ID, location="us-central1"
+        ),
+    )
+    # [END howto_operator_start_python_dataflow_runner_pipeline_gcs_file]
+
+    start_python_pipeline_local_spark_runner = BeamRunPythonPipelineOperator(
+        task_id="start_python_pipeline_local_spark_runner",
+        py_file='apache_beam.examples.wordcount',
+        runner="SparkRunner",
+        py_options=['-m'],
+        py_requirements=['apache-beam[gcp]==2.26.0'],
+        py_interpreter='python3',
+        py_system_site_packages=False,
+    )
+
+    start_python_pipeline_local_flink_runner = BeamRunPythonPipelineOperator(
+        task_id="start_python_pipeline_local_flink_runner",
+        py_file='apache_beam.examples.wordcount',
+        runner="FlinkRunner",
+        py_options=['-m'],
+        pipeline_options={
+            'output': '/tmp/start_python_pipeline_local_flink_runner',
+        },
+        py_requirements=['apache-beam[gcp]==2.26.0'],
+        py_interpreter='python3',
+        py_system_site_packages=False,
+    )
+
+    [
+        start_python_pipeline_local_direct_runner,
+        start_python_pipeline_direct_runner,
+    ] >> start_python_pipeline_local_flink_runner >> start_python_pipeline_local_spark_runner
+
+
+with models.DAG(
+    "example_beam_native_python_dataflow_async",
+    default_args=default_args,
+    start_date=days_ago(1),
+    schedule_interval=None,  # Override to match your needs
+    tags=['example'],
+) as dag_native_python_dataflow_async:
+    # [START howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file]
+    start_python_job_dataflow_runner_async = BeamRunPythonPipelineOperator(
+        task_id="start_python_job_dataflow_runner_async",
+        runner="DataflowRunner",
+        py_file=GCS_PYTHON_DATAFLOW_ASYNC,
+        pipeline_options={
+            'tempLocation': GCS_TMP,
+            'stagingLocation': GCS_STAGING,
+            'output': GCS_OUTPUT,
+        },
+        py_options=[],
+        py_requirements=['apache-beam[gcp]==2.26.0'],
+        py_interpreter='python3',
+        py_system_site_packages=False,
+        dataflow_config=DataflowConfiguration(
+            job_name='{{task.task_id}}',
+            project_id=GCP_PROJECT_ID,
+            location="us-central1",
+            wait_until_finished=False,
+        ),
+    )
+
+    wait_for_python_job_dataflow_runner_async_done = DataflowJobStatusSensor(
+        task_id="wait-for-python-job-async-done",
+        job_id="{{task_instance.xcom_pull('start_python_job_dataflow_runner_async')['dataflow_job_id']}}",
+        expected_statuses={DataflowJobStatus.JOB_STATE_DONE},
+        project_id=GCP_PROJECT_ID,
+        location='us-central1',
+    )
+
+    start_python_job_dataflow_runner_async >> wait_for_python_job_dataflow_runner_async_done
+    # [END howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file]
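
The example DAGs above exercise several runners end to end. As a condensed, hypothetical sketch of the
core pattern (the DAG id and option values are illustrative and not part of this commit), a single
BeamRunPythonPipelineOperator merges `default_pipeline_options` with the task-level `pipeline_options`,
with the task-level values taking precedence:

    from airflow import models
    from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator
    from airflow.utils.dates import days_ago

    with models.DAG(
        "example_beam_options_merge",  # illustrative DAG id
        start_date=days_ago(1),
        schedule_interval=None,
    ) as dag:
        wordcount = BeamRunPythonPipelineOperator(
            task_id="wordcount",
            py_file="apache_beam.examples.wordcount",
            py_options=["-m"],
            default_pipeline_options={"output": "/tmp/example_beam"},
            pipeline_options={"output": "/tmp/wordcount_merged"},  # overrides the default
            py_requirements=["apache-beam[gcp]==2.26.0"],
        )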
diff --git a/airflow/providers/apache/beam/hooks/__init__.py b/airflow/providers/apache/beam/hooks/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/airflow/providers/apache/beam/hooks/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/apache/beam/hooks/beam.py b/airflow/providers/apache/beam/hooks/beam.py
new file mode 100644
index 0000000..8e188b0
--- /dev/null
+++ b/airflow/providers/apache/beam/hooks/beam.py
@@ -0,0 +1,289 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains a Apache Beam Hook."""
+import json
+import select
+import shlex
+import subprocess
+import textwrap
+from tempfile import TemporaryDirectory
+from typing import Callable, List, Optional
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.python_virtualenv import prepare_virtualenv
+
+
+class BeamRunnerType:
+    """
+    Helper class for listing runner types.
+    For more information about runners see:
+    https://beam.apache.org/documentation/
+    """
+
+    DataflowRunner = "DataflowRunner"
+    DirectRunner = "DirectRunner"
+    SparkRunner = "SparkRunner"
+    FlinkRunner = "FlinkRunner"
+    SamzaRunner = "SamzaRunner"
+    NemoRunner = "NemoRunner"
+    JetRunner = "JetRunner"
+    Twister2Runner = "Twister2Runner"
+
+
+def beam_options_to_args(options: dict) -> List[str]:
+    """
+    Returns formatted pipeline options built from a dictionary of arguments.
+
+    The logic of this method should be compatible with Apache Beam:
+    https://github.com/apache/beam/blob/b56740f0e8cd80c2873412847d0b336837429fb9/sdks/python/
+    apache_beam/options/pipeline_options.py#L230-L251
+
+    :param options: Dictionary with options
+    :type options: dict
+    :return: List of arguments
+    :rtype: List[str]
+    """
+    if not options:
+        return []
+
+    args: List[str] = []
+    for attr, value in options.items():
+        if value is None or (isinstance(value, bool) and value):
+            args.append(f"--{attr}")
+        elif isinstance(value, list):
+            args.extend([f"--{attr}={v}" for v in value])
+        else:
+            args.append(f"--{attr}={value}")
+    return args
+
+
+class BeamCommandRunner(LoggingMixin):
+    """
+    Class responsible for running the pipeline command in a subprocess.
+
+    :param cmd: Parts of the command to be run in subprocess
+    :type cmd: List[str]
+    :param process_line_callback: Optional callback which can be used to process
+        stdout and stderr to detect job id
+    :type process_line_callback: Optional[Callable[[str], None]]
+    """
+
+    def __init__(
+        self,
+        cmd: List[str],
+        process_line_callback: Optional[Callable[[str], None]] = None,
+    ) -> None:
+        super().__init__()
+        self.log.info("Running command: %s", " ".join(shlex.quote(c) for c in cmd))
+        self.process_line_callback = process_line_callback
+        self.job_id: Optional[str] = None
+        self._proc = subprocess.Popen(
+            cmd,
+            shell=False,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            close_fds=True,
+        )
+
+    def _process_fd(self, fd):
+        """
+        Prints output to logs.
+
+        :param fd: File descriptor.
+        """
+        if fd not in (self._proc.stdout, self._proc.stderr):
+            raise Exception("No data in stderr or in stdout.")
+
+        fd_to_log = {self._proc.stderr: self.log.warning, self._proc.stdout: self.log.info}
+        func_log = fd_to_log[fd]
+
+        while True:
+            line = fd.readline().decode()
+            if not line:
+                return
+            if self.process_line_callback:
+                self.process_line_callback(line)
+            func_log(line.rstrip("\n"))
+
+    def wait_for_done(self) -> None:
+        """Waits for Apache Beam pipeline to complete."""
+        self.log.info("Start waiting for Apache Beam process to complete.")
+        reads = [self._proc.stderr, self._proc.stdout]
+        while True:
+            # Wait for at least one available fd.
+            readable_fds, _, _ = select.select(reads, [], [], 5)
+            if readable_fds is None:
+                self.log.info("Waiting for Apache Beam process to complete.")
+                continue
+
+            for readable_fd in readable_fds:
+                self._process_fd(readable_fd)
+
+            if self._proc.poll() is not None:
+                break
+
+        # Corner case: check if more output was created between the last read and the process termination
+        for readable_fd in reads:
+            self._process_fd(readable_fd)
+
+        self.log.info("Process exited with return code: %s", self._proc.returncode)
+
+        if self._proc.returncode != 0:
+            raise AirflowException(f"Apache Beam process failed with return code {self._proc.returncode}")
+
+
+class BeamHook(BaseHook):
+    """
+    Hook for Apache Beam.
+
+    All the methods in the hook where project_id is used must be called with
+    keyword arguments rather than positional.
+
+    :param runner: Runner type
+    :type runner: str
+    """
+
+    def __init__(
+        self,
+        runner: str,
+    ) -> None:
+        self.runner = runner
+        super().__init__()
+
+    def _start_pipeline(
+        self,
+        variables: dict,
+        command_prefix: List[str],
+        process_line_callback: Optional[Callable[[str], None]] = None,
+    ) -> None:
+        cmd = command_prefix + [
+            f"--runner={self.runner}",
+        ]
+        if variables:
+            cmd.extend(beam_options_to_args(variables))
+        cmd_runner = BeamCommandRunner(
+            cmd=cmd,
+            process_line_callback=process_line_callback,
+        )
+        cmd_runner.wait_for_done()
+
+    def start_python_pipeline(  # pylint: disable=too-many-arguments
+        self,
+        variables: dict,
+        py_file: str,
+        py_options: List[str],
+        py_interpreter: str = "python3",
+        py_requirements: Optional[List[str]] = None,
+        py_system_site_packages: bool = False,
+        process_line_callback: Optional[Callable[[str], None]] = None,
+    ):
+        """
+        Starts an Apache Beam Python pipeline.
+
+        :param variables: Variables passed to the pipeline.
+        :type variables: Dict
+        :param py_options: Additional options.
+        :type py_options: List[str]
+        :param py_interpreter: Python version of the Apache Beam pipeline.
+            If None, this defaults to python3.
+            To track python versions supported by beam and related
+            issues check: https://issues.apache.org/jira/browse/BEAM-1251
+        :type py_interpreter: str
+        :param py_requirements: Additional python package(s) to install.
+            If a value is passed to this parameter, a new virtual environment will be created with
+            the additional packages installed.
+
+            You could also install the apache-beam package if it is not installed on your system or you want
+            to use a different version.
+        :type py_requirements: List[str]
+        :param py_system_site_packages: Whether to include system_site_packages in your virtualenv.
+            See virtualenv documentation for more information.
+
+            This option is only relevant if the ``py_requirements`` parameter is not None.
+        :type py_system_site_packages: bool
+        :param process_line_callback: Optional callback which can be used to process
+            stdout and stderr to detect the job ID.
+        :type process_line_callback: Optional[Callable[[str], None]]
+        """
+        if "labels" in variables:
+            variables["labels"] = [f"{key}={value}" for key, value in variables["labels"].items()]
+
+        if py_requirements is not None:
+            if not py_requirements and not py_system_site_packages:
+                warning_invalid_environment = textwrap.dedent(
+                    """\
+                    Invalid method invocation. You have disabled inclusion of system packages and provided
+                    an empty list of packages to install, so it is not possible to create a valid virtual environment.
+                    In the virtual environment, apache-beam package must be installed for your job to be \
+                    executed. To fix this problem:
+                    * install apache-beam on the system, then set parameter py_system_site_packages to True,
+                    * add apache-beam to the list of required packages in parameter py_requirements.
+                    """
+                )
+                raise AirflowException(warning_invalid_environment)
+
+            with TemporaryDirectory(prefix="apache-beam-venv") as tmp_dir:
+                py_interpreter = prepare_virtualenv(
+                    venv_directory=tmp_dir,
+                    python_bin=py_interpreter,
+                    system_site_packages=py_system_site_packages,
+                    requirements=py_requirements,
+                )
+                command_prefix = [py_interpreter] + py_options + [py_file]
+
+                self._start_pipeline(
+                    variables=variables,
+                    command_prefix=command_prefix,
+                    process_line_callback=process_line_callback,
+                )
+        else:
+            command_prefix = [py_interpreter] + py_options + [py_file]
+
+            self._start_pipeline(
+                variables=variables,
+                command_prefix=command_prefix,
+                process_line_callback=process_line_callback,
+            )
+
+    def start_java_pipeline(
+        self,
+        variables: dict,
+        jar: str,
+        job_class: Optional[str] = None,
+        process_line_callback: Optional[Callable[[str], None]] = None,
+    ) -> None:
+        """
+        Starts an Apache Beam Java pipeline.
+
+        :param variables: Variables passed to the job.
+        :type variables: dict
+        :param jar: Name of the jar for the pipeline.
+        :type jar: str
+        :param job_class: Name of the java class for the pipeline.
+        :type job_class: str
+        """
+        if "labels" in variables:
+            variables["labels"] = json.dumps(variables["labels"], separators=(",", ":"))
+
+        command_prefix = ["java", "-cp", jar, job_class] if job_class else ["java", "-jar", jar]
+        self._start_pipeline(
+            variables=variables,
+            command_prefix=command_prefix,
+            process_line_callback=process_line_callback,
+        )
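
The `beam_options_to_args` helper added above is what turns the `pipeline_options` /
`default_pipeline_options` dictionaries into command-line flags for the subprocess. A quick
illustrative call (the option names and values below are made up for the example):

    from airflow.providers.apache.beam.hooks.beam import beam_options_to_args

    args = beam_options_to_args(
        {
            "output": "gs://my-bucket/results",  # plain value -> --output=gs://my-bucket/results
            "streaming": True,  # True (or None) -> bare flag --streaming
            "experiments": ["use_runner_v2", "shuffle_mode=service"],  # list -> repeated flag
        }
    )
    # args == ['--output=gs://my-bucket/results', '--streaming',
    #          '--experiments=use_runner_v2', '--experiments=shuffle_mode=service']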
diff --git a/airflow/providers/apache/beam/operators/__init__.py b/airflow/providers/apache/beam/operators/__init__.py
new file mode 100644
index 0000000..217e5db
--- /dev/null
+++ b/airflow/providers/apache/beam/operators/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py
new file mode 100644
index 0000000..849298e
--- /dev/null
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -0,0 +1,446 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains Apache Beam operators."""
+from contextlib import ExitStack
+from typing import Callable, List, Optional, Union
+
+from airflow.models import BaseOperator
+from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType
+from airflow.providers.google.cloud.hooks.dataflow import (
+    DataflowHook,
+    process_line_and_extract_dataflow_job_id_callback,
+)
+from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.cloud.operators.dataflow import CheckJobRunning, DataflowConfiguration
+from airflow.utils.decorators import apply_defaults
+from airflow.utils.helpers import convert_camel_to_snake
+from airflow.version import version
+
+
+class BeamRunPythonPipelineOperator(BaseOperator):
+    """
+    Launching Apache Beam pipelines written in Python. Note that both
+    ``default_pipeline_options`` and ``pipeline_options`` will be merged to specify pipeline
+    execution parameters, and ``default_pipeline_options`` is expected to hold
+    high-level options, for instance, project and zone information, which
+    apply to all beam operators in the DAG.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BeamRunPythonPipelineOperator`
+
+    .. seealso::
+        For more detail on Apache Beam have a look at the reference:
+        https://beam.apache.org/documentation/
+
+    :param py_file: Reference to the python Apache Beam pipeline file.py, e.g.,
+        /some/local/file/path/to/your/python/pipeline/file. (templated)
+    :type py_file: str
+    :param runner: Runner on which the pipeline will be run. By default, "DirectRunner" is used.
+        Other possible options: DataflowRunner, SparkRunner, FlinkRunner.
+        See: :class:`~providers.apache.beam.hooks.beam.BeamRunnerType`
+        See: https://beam.apache.org/documentation/runners/capability-matrix/
+
+        If you use the Dataflow runner, check the dedicated operator:
+        :class:`~providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator`
+    :type runner: str
+    :param py_options: Additional python options, e.g., ["-m", "-v"].
+    :type py_options: list[str]
+    :param default_pipeline_options: Map of default pipeline options.
+    :type default_pipeline_options: dict
+    :param pipeline_options: Map of pipeline options. The parameter must be a dictionary;
+        the value for each key can be of different types:
+
+        * If the value is None, the single option - ``--key`` (without value) will be added.
+        * If the value is False, this option will be skipped
+        * If the value is True, the single option - ``--key`` (without value) will be added.
+        * If the value is a list, the option will be repeated for each element.
+          If the value is ``['A', 'B']`` and the key is ``key`` then the ``--key=A --key=B`` options
+          will be passed
+        * Other value types will be replaced with the Python textual representation.
+
+        When defining labels (``labels`` option), you can also provide a dictionary.
+    :type pipeline_options: dict
+    :param py_interpreter: Python version of the beam pipeline.
+        If None, this defaults to python3.
+        To track python versions supported by beam and related
+        issues check: https://issues.apache.org/jira/browse/BEAM-1251
+    :type py_interpreter: str
+    :param py_requirements: Additional python package(s) to install.
+        If a value is passed to this parameter, a new virtual environment will be created with
+        the additional packages installed.
+
+        You could also install the apache_beam package if it is not installed on your system or you want
+        to use a different version.
+    :type py_requirements: List[str]
+    :param py_system_site_packages: Whether to include system_site_packages in your virtualenv.
+        See virtualenv documentation for more information.
+
+        This option is only relevant if the ``py_requirements`` parameter is not None.
+    :param gcp_conn_id: Optional.
+        The connection ID to use when connecting to Google Cloud Storage if the Python file is on GCS.
+    :type gcp_conn_id: str
+    :param delegate_to:  Optional.
+        The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param dataflow_config: Dataflow configuration, used when runner type is set to DataflowRunner
+    :type dataflow_config: Union[dict, providers.google.cloud.operators.dataflow.DataflowConfiguration]
+    """
+
+    template_fields = ["py_file", "runner", "pipeline_options", "default_pipeline_options", "dataflow_config"]
+    template_fields_renderers = {'dataflow_config': 'json', 'pipeline_options': 'json'}
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        py_file: str,
+        runner: str = "DirectRunner",
+        default_pipeline_options: Optional[dict] = None,
+        pipeline_options: Optional[dict] = None,
+        py_interpreter: str = "python3",
+        py_options: Optional[List[str]] = None,
+        py_requirements: Optional[List[str]] = None,
+        py_system_site_packages: bool = False,
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.py_file = py_file
+        self.runner = runner
+        self.py_options = py_options or []
+        self.default_pipeline_options = default_pipeline_options or {}
+        self.pipeline_options = pipeline_options or {}
+        self.pipeline_options.setdefault("labels", {}).update(
+            {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
+        )
+        self.py_interpreter = py_interpreter
+        self.py_requirements = py_requirements
+        self.py_system_site_packages = py_system_site_packages
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.dataflow_config = dataflow_config or {}
+        self.beam_hook: Optional[BeamHook] = None
+        self.dataflow_hook: Optional[DataflowHook] = None
+        self.dataflow_job_id: Optional[str] = None
+
+        if self.dataflow_config and self.runner.lower() != BeamRunnerType.DataflowRunner.lower():
+            self.log.warning(
+                "dataflow_config is defined but runner is different than DataflowRunner (%s)", self.runner
+            )
+
+    def execute(self, context):
+        """Execute the Apache Beam Pipeline."""
+        self.beam_hook = BeamHook(runner=self.runner)
+        pipeline_options = self.default_pipeline_options.copy()
+        process_line_callback: Optional[Callable] = None
+        is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()
+
+        if isinstance(self.dataflow_config, dict):
+            self.dataflow_config = DataflowConfiguration(**self.dataflow_config)
+
+        if is_dataflow:
+            self.dataflow_hook = DataflowHook(
+                gcp_conn_id=self.dataflow_config.gcp_conn_id or self.gcp_conn_id,
+                delegate_to=self.dataflow_config.delegate_to or self.delegate_to,
+                poll_sleep=self.dataflow_config.poll_sleep,
+                impersonation_chain=self.dataflow_config.impersonation_chain,
+                drain_pipeline=self.dataflow_config.drain_pipeline,
+                cancel_timeout=self.dataflow_config.cancel_timeout,
+                wait_until_finished=self.dataflow_config.wait_until_finished,
+            )
+            self.dataflow_config.project_id = self.dataflow_config.project_id or self.dataflow_hook.project_id
+
+            dataflow_job_name = DataflowHook.build_dataflow_job_name(
+                self.dataflow_config.job_name, self.dataflow_config.append_job_name
+            )
+            pipeline_options["job_name"] = dataflow_job_name
+            pipeline_options["project"] = self.dataflow_config.project_id
+            pipeline_options["region"] = self.dataflow_config.location
+            pipeline_options.setdefault("labels", {}).update(
+                {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
+            )
+
+            def set_current_dataflow_job_id(job_id):
+                self.dataflow_job_id = job_id
+
+            process_line_callback = process_line_and_extract_dataflow_job_id_callback(
+                on_new_job_id_callback=set_current_dataflow_job_id
+            )
+
+        pipeline_options.update(self.pipeline_options)
+
+        # Convert argument names from lowerCamelCase to snake case.
+        formatted_pipeline_options = {
+            convert_camel_to_snake(key): pipeline_options[key] for key in pipeline_options
+        }
+
+        with ExitStack() as exit_stack:
+            if self.py_file.lower().startswith("gs://"):
+                gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
+                tmp_gcs_file = exit_stack.enter_context(  # pylint: disable=no-member
+                    gcs_hook.provide_file(object_url=self.py_file)
+                )
+                self.py_file = tmp_gcs_file.name
+
+            self.beam_hook.start_python_pipeline(
+                variables=formatted_pipeline_options,
+                py_file=self.py_file,
+                py_options=self.py_options,
+                py_interpreter=self.py_interpreter,
+                py_requirements=self.py_requirements,
+                py_system_site_packages=self.py_system_site_packages,
+                process_line_callback=process_line_callback,
+            )
+
+            if is_dataflow:
+                self.dataflow_hook.wait_for_done(  # pylint: disable=no-value-for-parameter
+                    job_name=dataflow_job_name,
+                    location=self.dataflow_config.location,
+                    job_id=self.dataflow_job_id,
+                    multiple_jobs=False,
+                )
+
+        return {"dataflow_job_id": self.dataflow_job_id}
+
+    def on_kill(self) -> None:
+        if self.dataflow_hook and self.dataflow_job_id:
+            self.log.info('Dataflow job with id: `%s` was requested to be cancelled.', self.dataflow_job_id)
+            self.dataflow_hook.cancel_job(
+                job_id=self.dataflow_job_id,
+                project_id=self.dataflow_config.project_id,
+            )
+
+
+# pylint: disable=too-many-instance-attributes
+class BeamRunJavaPipelineOperator(BaseOperator):
+    """
+    Launches Apache Beam pipelines written in Java.
+
+    Note that ``default_pipeline_options`` and ``pipeline_options`` will be merged to specify the
+    pipeline execution parameters, and ``default_pipeline_options`` is expected to hold
+    high-level pipeline_options, for instance project and zone information, which
+    apply to all Apache Beam operators in the DAG.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BeamRunJavaPipelineOperator`
+
+    .. seealso::
+        For more detail on Apache Beam have a look at the reference:
+        https://beam.apache.org/documentation/
+
+    You need to pass the path to your jar file as a file reference with the ``jar``
+    parameter; the jar needs to be a self-executing jar (see the documentation here:
+    https://beam.apache.org/documentation/runners/dataflow/#self-executing-jar).
+    Use ``pipeline_options`` to pass pipeline_options on to your job.
+
+    :param jar: The reference to a self-executing Apache Beam jar (templated).
+    :type jar: str
+    :param runner: Runner on which the pipeline will be run. By default, "DirectRunner" is used.
+        See:
+        https://beam.apache.org/documentation/runners/capability-matrix/
+        If you use the Dataflow runner, check the dedicated operator:
+        :class:`~providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator`
+    :type runner: str
+    :param job_class: The name of the Apache Beam pipeline class to be executed; it
+        is often not the main class configured in the pipeline jar file.
+    :type job_class: str
+    :param default_pipeline_options: Map of default job pipeline_options.
+    :type default_pipeline_options: dict
+    :param pipeline_options: Map of job-specific pipeline_options. The value for each key can be
+        of different types:
+
+        * If the value is None, the single option ``--key`` (without a value) will be added.
+        * If the value is False, this option will be skipped.
+        * If the value is True, the single option ``--key`` (without a value) will be added.
+        * If the value is a list, the option will be repeated for each element.
+          If the value is ``['A', 'B']`` and the key is ``key``, then the ``--key=A --key=B``
+          pipeline_options will be passed.
+        * Other value types will be replaced with their Python textual representation.
+
+        When defining labels (``labels`` option), you can also provide a dictionary.
+    :type pipeline_options: dict
+    :param gcp_conn_id: The connection ID to use when connecting to Google Cloud Storage if the jar is on GCS
+    :type gcp_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param dataflow_config: Dataflow configuration, used when runner type is set to DataflowRunner
+    :type dataflow_config: Union[dict, providers.google.cloud.operators.dataflow.DataflowConfiguration]
+    """
+
+    template_fields = [
+        "jar",
+        "runner",
+        "job_class",
+        "pipeline_options",
+        "default_pipeline_options",
+        "dataflow_config",
+    ]
+    template_fields_renderers = {'dataflow_config': 'json', 'pipeline_options': 'json'}
+    ui_color = "#0273d4"
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        jar: str,
+        runner: str = "DirectRunner",
+        job_class: Optional[str] = None,
+        default_pipeline_options: Optional[dict] = None,
+        pipeline_options: Optional[dict] = None,
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.jar = jar
+        self.runner = runner
+        self.default_pipeline_options = default_pipeline_options or {}
+        self.pipeline_options = pipeline_options or {}
+        self.job_class = job_class
+        self.dataflow_config = dataflow_config or {}
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.dataflow_job_id = None
+        self.dataflow_hook: Optional[DataflowHook] = None
+        self.beam_hook: Optional[BeamHook] = None
+        self._dataflow_job_name: Optional[str] = None
+
+        if self.dataflow_config and self.runner.lower() != BeamRunnerType.DataflowRunner.lower():
+            self.log.warning(
+                "dataflow_config is defined but runner is different than DataflowRunner (%s)", self.runner
+            )
+
+    def execute(self, context):
+        """Execute the Apache Beam Pipeline."""
+        self.beam_hook = BeamHook(runner=self.runner)
+        pipeline_options = self.default_pipeline_options.copy()
+        process_line_callback: Optional[Callable] = None
+        is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()
+
+        if isinstance(self.dataflow_config, dict):
+            self.dataflow_config = DataflowConfiguration(**self.dataflow_config)
+
+        if is_dataflow:
+            self.dataflow_hook = DataflowHook(
+                gcp_conn_id=self.dataflow_config.gcp_conn_id or self.gcp_conn_id,
+                delegate_to=self.dataflow_config.delegate_to or self.delegate_to,
+                poll_sleep=self.dataflow_config.poll_sleep,
+                impersonation_chain=self.dataflow_config.impersonation_chain,
+                drain_pipeline=self.dataflow_config.drain_pipeline,
+                cancel_timeout=self.dataflow_config.cancel_timeout,
+                wait_until_finished=self.dataflow_config.wait_until_finished,
+            )
+            self.dataflow_config.project_id = self.dataflow_config.project_id or self.dataflow_hook.project_id
+
+            self._dataflow_job_name = DataflowHook.build_dataflow_job_name(
+                self.dataflow_config.job_name, self.dataflow_config.append_job_name
+            )
+            pipeline_options["jobName"] = self.dataflow_config.job_name
+            pipeline_options["project"] = self.dataflow_config.project_id
+            pipeline_options["region"] = self.dataflow_config.location
+            pipeline_options.setdefault("labels", {}).update(
+                {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
+            )
+
+            def set_current_dataflow_job_id(job_id):
+                self.dataflow_job_id = job_id
+
+            process_line_callback = process_line_and_extract_dataflow_job_id_callback(
+                on_new_job_id_callback=set_current_dataflow_job_id
+            )
+
+        pipeline_options.update(self.pipeline_options)
+
+        with ExitStack() as exit_stack:
+            if self.jar.lower().startswith("gs://"):
+                gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
+                tmp_gcs_file = exit_stack.enter_context(  # pylint: disable=no-member
+                    gcs_hook.provide_file(object_url=self.jar)
+                )
+                self.jar = tmp_gcs_file.name
+
+            if is_dataflow:
+                is_running = False
+                if self.dataflow_config.check_if_running != CheckJobRunning.IgnoreJob:
+                    is_running = (
+                        # The reason for disable=no-value-for-parameter is that project_id parameter is
+                        # required but here is not passed, moreover it cannot be passed here.
+                        # This method is wrapped by @_fallback_to_project_id_from_variables decorator which
+                        # fallback project_id value from variables and raise error if project_id is
+                        # defined both in variables and as parameter (here is already defined in variables)
+                        self.dataflow_hook.is_job_dataflow_running(  # pylint: disable=no-value-for-parameter
+                            name=self.dataflow_config.job_name,
+                            variables=pipeline_options,
+                        )
+                    )
+                    while is_running and self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun:
+                        # The reason for disable=no-value-for-parameter is that project_id parameter is
+                        # required but here is not passed, moreover it cannot be passed here.
+                        # This method is wrapped by @_fallback_to_project_id_from_variables decorator which
+                        # fallback project_id value from variables and raise error if project_id is
+                        # defined both in variables and as parameter (here is already defined in variables)
+                        # pylint: disable=no-value-for-parameter
+                        is_running = self.dataflow_hook.is_job_dataflow_running(
+                            name=self.dataflow_config.job_name,
+                            variables=pipeline_options,
+                        )
+                if not is_running:
+                    pipeline_options["jobName"] = self._dataflow_job_name
+                    self.beam_hook.start_java_pipeline(
+                        variables=pipeline_options,
+                        jar=self.jar,
+                        job_class=self.job_class,
+                        process_line_callback=process_line_callback,
+                    )
+                    self.dataflow_hook.wait_for_done(
+                        job_name=self._dataflow_job_name,
+                        location=self.dataflow_config.location,
+                        job_id=self.dataflow_job_id,
+                        multiple_jobs=self.dataflow_config.multiple_jobs,
+                        project_id=self.dataflow_config.project_id,
+                    )
+
+            else:
+                self.beam_hook.start_java_pipeline(
+                    variables=pipeline_options,
+                    jar=self.jar,
+                    job_class=self.job_class,
+                    process_line_callback=process_line_callback,
+                )
+
+        return {"dataflow_job_id": self.dataflow_job_id}
+
+    def on_kill(self) -> None:
+        if self.dataflow_hook and self.dataflow_job_id:
+            self.log.info('Dataflow job with id: `%s` was requested to be cancelled.', self.dataflow_job_id)
+            self.dataflow_hook.cancel_job(
+                job_id=self.dataflow_job_id,
+                project_id=self.dataflow_config.project_id,
+            )
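
For orientation, a minimal DAG sketch of how the two operators added above might be wired together;
the file paths, bucket, and job class below are illustrative placeholders, not values from this patch:

    from airflow import DAG
    from airflow.providers.apache.beam.operators.beam import (
        BeamRunJavaPipelineOperator,
        BeamRunPythonPipelineOperator,
    )
    from airflow.utils.dates import days_ago

    with DAG(dag_id="example_beam_sketch", start_date=days_ago(1), schedule_interval=None) as dag:
        # Runs the Python pipeline locally with the default DirectRunner.
        run_python = BeamRunPythonPipelineOperator(
            task_id="run_python_pipeline",
            py_file="/files/wordcount.py",  # placeholder local path
            pipeline_options={"output": "/tmp/wordcount_output"},
        )

        # Runs a self-executing jar; the jar may live on GCS, in which case it is
        # downloaded through the gcp_conn_id connection before the pipeline starts.
        run_java = BeamRunJavaPipelineOperator(
            task_id="run_java_pipeline",
            jar="gs://my-bucket/wordcount.jar",  # placeholder GCS object
            job_class="org.example.WordCount",  # placeholder pipeline class
            pipeline_options={"output": "gs://my-bucket/output"},
        )

        run_python >> run_java
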
diff --git a/airflow/providers/apache/beam/provider.yaml b/airflow/providers/apache/beam/provider.yaml
new file mode 100644
index 0000000..4325265
--- /dev/null
+++ b/airflow/providers/apache/beam/provider.yaml
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+---
+package-name: apache-airflow-providers-apache-beam
+name: Apache Beam
+description: |
+    `Apache Beam <https://beam.apache.org/>`__.
+
+versions:
+  - 0.0.1
+
+integrations:
+  - integration-name: Apache Beam
+    external-doc-url: https://beam.apache.org/
+    how-to-guide:
+      - /docs/apache-airflow-providers-apache-beam/operators.rst
+    tags: [apache]
+
+operators:
+  - integration-name: Apache Beam
+    python-modules:
+      - airflow.providers.apache.beam.operators.beam
+
+hooks:
+  - integration-name: Apache Beam
+    python-modules:
+      - airflow.providers.apache.beam.hooks.beam
+
+hook-class-names:
+  - airflow.providers.apache.beam.hooks.beam.BeamHook
diff --git a/airflow/providers/dependencies.json b/airflow/providers/dependencies.json
index 748b1a5..836020c 100644
--- a/airflow/providers/dependencies.json
+++ b/airflow/providers/dependencies.json
@@ -8,6 +8,9 @@
     "postgres",
     "ssh"
   ],
+  "apache.beam": [
+    "google"
+  ],
   "apache.druid": [
     "apache.hive"
   ],
@@ -30,6 +33,7 @@
   ],
   "google": [
     "amazon",
+    "apache.beam",
     "apache.cassandra",
     "cncf.kubernetes",
     "facebook",
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index 0a665d4..0ad0262 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -19,23 +19,20 @@
 import functools
 import json
 import re
-import select
 import shlex
 import subprocess
-import textwrap
 import time
 import uuid
 import warnings
 from copy import deepcopy
-from tempfile import TemporaryDirectory
 from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Set, TypeVar, Union, cast
 
 from googleapiclient.discovery import build
 
 from airflow.exceptions import AirflowException
+from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.python_virtualenv import prepare_virtualenv
 from airflow.utils.timeout import timeout
 
 # This is the default location
@@ -50,6 +47,35 @@ JOB_ID_PATTERN = re.compile(
 T = TypeVar("T", bound=Callable)  # pylint: disable=invalid-name
 
 
+def process_line_and_extract_dataflow_job_id_callback(
+    on_new_job_id_callback: Optional[Callable[[str], None]]
+) -> Callable[[str], None]:
+    """
+    Returns a callback which triggers the function passed as `on_new_job_id_callback`
+    when the Dataflow job ID is found. To be used for `process_line_callback` in
+    :py:class:`~airflow.providers.apache.beam.hooks.beam.BeamCommandRunner`
+
+    :param on_new_job_id_callback: Callback called when the job ID is known
+    :type on_new_job_id_callback: callback
+    """
+
+    def _process_line_and_extract_job_id(line: str) -> None:
+        # Job id info: https://goo.gl/SE29y9.
+        matched_job = JOB_ID_PATTERN.search(line)
+        if matched_job:
+            job_id = matched_job.group("job_id_java") or matched_job.group("job_id_python")
+            if on_new_job_id_callback:
+                on_new_job_id_callback(job_id)
+
+    def wrap(line: str):
+        return _process_line_and_extract_job_id(line)
+
+    return wrap
+
+
 def _fallback_variable_parameter(parameter_name: str, variable_key_name: str) -> Callable[[T], T]:
     def _wrapper(func: T) -> T:
         """
@@ -484,98 +510,6 @@ class _DataflowJobsController(LoggingMixin):
             self.log.info("No jobs to cancel")
 
 
-class _DataflowRunner(LoggingMixin):
-    def __init__(
-        self,
-        cmd: List[str],
-        on_new_job_id_callback: Optional[Callable[[str], None]] = None,
-    ) -> None:
-        super().__init__()
-        self.log.info("Running command: %s", " ".join(shlex.quote(c) for c in cmd))
-        self.on_new_job_id_callback = on_new_job_id_callback
-        self.job_id: Optional[str] = None
-        self._proc = subprocess.Popen(
-            cmd,
-            shell=False,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-            close_fds=True,
-        )
-
-    def _process_fd(self, fd):
-        """
-        Prints output to logs and lookup for job ID in each line.
-
-        :param fd: File descriptor.
-        """
-        if fd == self._proc.stderr:
-            while True:
-                line = self._proc.stderr.readline().decode()
-                if not line:
-                    return
-                self._process_line_and_extract_job_id(line)
-                self.log.warning(line.rstrip("\n"))
-
-        if fd == self._proc.stdout:
-            while True:
-                line = self._proc.stdout.readline().decode()
-                if not line:
-                    return
-                self._process_line_and_extract_job_id(line)
-                self.log.info(line.rstrip("\n"))
-
-        raise Exception("No data in stderr or in stdout.")
-
-    def _process_line_and_extract_job_id(self, line: str) -> None:
-        """
-        Extracts job_id.
-
-        :param line: URL from which job_id has to be extracted
-        :type line: str
-        """
-        # Job id info: https://goo.gl/SE29y9.
-        matched_job = JOB_ID_PATTERN.search(line)
-        if matched_job:
-            job_id = matched_job.group("job_id_java") or matched_job.group("job_id_python")
-            self.log.info("Found Job ID: %s", job_id)
-            self.job_id = job_id
-            if self.on_new_job_id_callback:
-                self.on_new_job_id_callback(job_id)
-
-    def wait_for_done(self) -> Optional[str]:
-        """
-        Waits for Dataflow job to complete.
-
-        :return: Job id
-        :rtype: Optional[str]
-        """
-        self.log.info("Start waiting for DataFlow process to complete.")
-        self.job_id = None
-        reads = [self._proc.stderr, self._proc.stdout]
-        while True:
-            # Wait for at least one available fd.
-            readable_fds, _, _ = select.select(reads, [], [], 5)
-            if readable_fds is None:
-                self.log.info("Waiting for DataFlow process to complete.")
-                continue
-
-            for readable_fd in readable_fds:
-                self._process_fd(readable_fd)
-
-            if self._proc.poll() is not None:
-                break
-
-        # Corner case: check if more output was created between the last read and the process termination
-        for readable_fd in reads:
-            self._process_fd(readable_fd)
-
-        self.log.info("Process exited with return code: %s", self._proc.returncode)
-
-        if self._proc.returncode != 0:
-            raise Exception(f"DataFlow failed with return code {self._proc.returncode}")
-        return self.job_id
-
-
 class DataflowHook(GoogleBaseHook):
     """
     Hook for Google Dataflow.
@@ -598,6 +532,8 @@ class DataflowHook(GoogleBaseHook):
         self.drain_pipeline = drain_pipeline
         self.cancel_timeout = cancel_timeout
         self.wait_until_finished = wait_until_finished
+        self.job_id: Optional[str] = None
+        self.beam_hook = BeamHook(BeamRunnerType.DataflowRunner)
         super().__init__(
             gcp_conn_id=gcp_conn_id,
             delegate_to=delegate_to,
@@ -609,40 +545,6 @@ class DataflowHook(GoogleBaseHook):
         http_authorized = self._authorize()
         return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False)
 
-    @GoogleBaseHook.provide_gcp_credential_file
-    def _start_dataflow(
-        self,
-        variables: dict,
-        name: str,
-        command_prefix: List[str],
-        project_id: str,
-        multiple_jobs: bool = False,
-        on_new_job_id_callback: Optional[Callable[[str], None]] = None,
-        location: str = DEFAULT_DATAFLOW_LOCATION,
-    ) -> None:
-        cmd = command_prefix + [
-            "--runner=DataflowRunner",
-            f"--project={project_id}",
-        ]
-        if variables:
-            cmd.extend(self._options_to_args(variables))
-        runner = _DataflowRunner(cmd=cmd, on_new_job_id_callback=on_new_job_id_callback)
-        job_id = runner.wait_for_done()
-        job_controller = _DataflowJobsController(
-            dataflow=self.get_conn(),
-            project_number=project_id,
-            name=name,
-            location=location,
-            poll_sleep=self.poll_sleep,
-            job_id=job_id,
-            num_retries=self.num_retries,
-            multiple_jobs=multiple_jobs,
-            drain_pipeline=self.drain_pipeline,
-            cancel_timeout=self.cancel_timeout,
-            wait_until_finished=self.wait_until_finished,
-        )
-        job_controller.wait_for_done()
-
     @_fallback_to_location_from_variables
     @_fallback_to_project_id_from_variables
     @GoogleBaseHook.fallback_to_default_project_id
@@ -680,22 +582,36 @@ class DataflowHook(GoogleBaseHook):
         :param location: Job location.
         :type location: str
         """
-        name = self._build_dataflow_job_name(job_name, append_job_name)
+        warnings.warn(
+            """"This method is deprecated.
+            Please use `airflow.providers.apache.beam.hooks.beam.start.start_java_pipeline`
+            to start pipeline and `providers.google.cloud.hooks.dataflow.DataflowHook.wait_for_done`
+            to wait for the required pipeline state.
+            """,
+            DeprecationWarning,
+            stacklevel=3,
+        )
+
+        name = self.build_dataflow_job_name(job_name, append_job_name)
+
         variables["jobName"] = name
         variables["region"] = location
+        variables["project"] = project_id
 
         if "labels" in variables:
             variables["labels"] = json.dumps(variables["labels"], separators=(",", ":"))
 
-        command_prefix = ["java", "-cp", jar, job_class] if job_class else ["java", "-jar", jar]
-        self._start_dataflow(
+        self.beam_hook.start_java_pipeline(
             variables=variables,
-            name=name,
-            command_prefix=command_prefix,
-            project_id=project_id,
-            multiple_jobs=multiple_jobs,
-            on_new_job_id_callback=on_new_job_id_callback,
+            jar=jar,
+            job_class=job_class,
+            process_line_callback=process_line_and_extract_dataflow_job_id_callback(on_new_job_id_callback),
+        )
+        self.wait_for_done(  # pylint: disable=no-value-for-parameter
+            job_name=name,
             location=location,
+            job_id=self.job_id,
+            multiple_jobs=multiple_jobs,
         )
 
     @_fallback_to_location_from_variables
@@ -748,7 +664,7 @@ class DataflowHook(GoogleBaseHook):
 
         :type environment: Optional[dict]
         """
-        name = self._build_dataflow_job_name(job_name, append_job_name)
+        name = self.build_dataflow_job_name(job_name, append_job_name)
 
         environment = environment or {}
         # available keys for runtime environment are listed here:
@@ -921,58 +837,40 @@ class DataflowHook(GoogleBaseHook):
         :param location: Job location.
         :type location: str
         """
-        name = self._build_dataflow_job_name(job_name, append_job_name)
+        warnings.warn(
+            """This method is deprecated.
+            Please use `airflow.providers.apache.beam.hooks.beam.BeamHook.start_python_pipeline`
+            to start the pipeline and `providers.google.cloud.hooks.dataflow.DataflowHook.wait_for_done`
+            to wait for the required pipeline state.
+            """,
+            DeprecationWarning,
+            stacklevel=3,
+        )
+
+        name = self.build_dataflow_job_name(job_name, append_job_name)
         variables["job_name"] = name
         variables["region"] = location
+        variables["project"] = project_id
 
-        if "labels" in variables:
-            variables["labels"] = [f"{key}={value}" for key, value in variables["labels"].items()]
-
-        if py_requirements is not None:
-            if not py_requirements and not py_system_site_packages:
-                warning_invalid_environment = textwrap.dedent(
-                    """\
-                    Invalid method invocation. You have disabled inclusion of system packages and empty list
-                    required for installation, so it is not possible to create a valid virtual environment.
-                    In the virtual environment, apache-beam package must be installed for your job to be \
-                    executed. To fix this problem:
-                    * install apache-beam on the system, then set parameter py_system_site_packages to True,
-                    * add apache-beam to the list of required packages in parameter py_requirements.
-                    """
-                )
-                raise AirflowException(warning_invalid_environment)
-
-            with TemporaryDirectory(prefix="dataflow-venv") as tmp_dir:
-                py_interpreter = prepare_virtualenv(
-                    venv_directory=tmp_dir,
-                    python_bin=py_interpreter,
-                    system_site_packages=py_system_site_packages,
-                    requirements=py_requirements,
-                )
-                command_prefix = [py_interpreter] + py_options + [dataflow]
-
-                self._start_dataflow(
-                    variables=variables,
-                    name=name,
-                    command_prefix=command_prefix,
-                    project_id=project_id,
-                    on_new_job_id_callback=on_new_job_id_callback,
-                    location=location,
-                )
-        else:
-            command_prefix = [py_interpreter] + py_options + [dataflow]
-
-            self._start_dataflow(
-                variables=variables,
-                name=name,
-                command_prefix=command_prefix,
-                project_id=project_id,
-                on_new_job_id_callback=on_new_job_id_callback,
-                location=location,
-            )
+        self.beam_hook.start_python_pipeline(
+            variables=variables,
+            py_file=dataflow,
+            py_options=py_options,
+            py_interpreter=py_interpreter,
+            py_requirements=py_requirements,
+            py_system_site_packages=py_system_site_packages,
+            process_line_callback=process_line_and_extract_dataflow_job_id_callback(on_new_job_id_callback),
+        )
+
+        self.wait_for_done(  # pylint: disable=no-value-for-parameter
+            job_name=name,
+            location=location,
+            job_id=self.job_id,
+        )
 
     @staticmethod
-    def _build_dataflow_job_name(job_name: str, append_job_name: bool = True) -> str:
+    def build_dataflow_job_name(job_name: str, append_job_name: bool = True) -> str:
+        """Builds Dataflow job name."""
         base_job_name = str(job_name).replace("_", "-")
 
         if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", base_job_name):
@@ -989,23 +887,6 @@ class DataflowHook(GoogleBaseHook):
 
         return safe_job_name
 
-    @staticmethod
-    def _options_to_args(variables: dict) -> List[str]:
-        if not variables:
-            return []
-        # The logic of this method should be compatible with Apache Beam:
-        # https://github.com/apache/beam/blob/b56740f0e8cd80c2873412847d0b336837429fb9/sdks/python/
-        # apache_beam/options/pipeline_options.py#L230-L251
-        args: List[str] = []
-        for attr, value in variables.items():
-            if value is None or (isinstance(value, bool) and value):
-                args.append(f"--{attr}")
-            elif isinstance(value, list):
-                args.extend([f"--{attr}={v}" for v in value])
-            else:
-                args.append(f"--{attr}={value}")
-        return args
-
     @_fallback_to_location_from_variables
     @_fallback_to_project_id_from_variables
     @GoogleBaseHook.fallback_to_default_project_id
@@ -1125,7 +1006,7 @@ class DataflowHook(GoogleBaseHook):
             "--format=value(job.id)",
             f"--job-name={job_name}",
             f"--region={location}",
-            *(self._options_to_args(options)),
+            *(beam_options_to_args(options)),
         ]
         self.log.info("Executing command: %s", " ".join([shlex.quote(c) for c in cmd]))
         with self.provide_authorized_gcloud():
@@ -1266,3 +1147,44 @@ class DataflowHook(GoogleBaseHook):
             location=location,
         )
         return jobs_controller.fetch_job_autoscaling_events_by_id(job_id)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def wait_for_done(
+        self,
+        job_name: str,
+        location: str,
+        project_id: str,
+        job_id: Optional[str] = None,
+        multiple_jobs: bool = False,
+    ) -> None:
+        """
+        Wait for Dataflow job.
+
+        :param job_name: The 'jobName' to use when executing the DataFlow job
+            (templated). This ends up being set in the pipeline options, so any entry
+            with key ``'jobName'`` in ``options`` will be overwritten.
+        :type job_name: str
+        :param location: location the job is running
+        :type location: str
+        :param project_id: Optional, the Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :type project_id: str
+        :param job_id: a Dataflow job ID
+        :type job_id: str
+        :param multiple_jobs: If the pipeline creates multiple jobs, then monitor all of them
+        :type multiple_jobs: bool
+        """
+        job_controller = _DataflowJobsController(
+            dataflow=self.get_conn(),
+            project_number=project_id,
+            name=job_name,
+            location=location,
+            poll_sleep=self.poll_sleep,
+            job_id=job_id or self.job_id,
+            num_retries=self.num_retries,
+            multiple_jobs=multiple_jobs,
+            drain_pipeline=self.drain_pipeline,
+            cancel_timeout=self.cancel_timeout,
+            wait_until_finished=self.wait_until_finished,
+        )
+        job_controller.wait_for_done()
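
The deprecation warnings above point callers at a two-step flow: start the pipeline through
``BeamHook`` and then block on ``DataflowHook.wait_for_done``. A hedged sketch of that flow follows;
the project, region, job name, and jar path are placeholders:

    from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType
    from airflow.providers.google.cloud.hooks.dataflow import (
        DataflowHook,
        process_line_and_extract_dataflow_job_id_callback,
    )

    beam_hook = BeamHook(runner=BeamRunnerType.DataflowRunner)
    dataflow_hook = DataflowHook(gcp_conn_id="google_cloud_default")
    captured = {}

    def remember_job_id(job_id: str) -> None:
        # Called by the process-line callback once the Dataflow job id shows up in the logs.
        captured["job_id"] = job_id

    beam_hook.start_java_pipeline(
        variables={"project": "my-project", "region": "us-central1", "jobName": "sketch-job"},
        jar="/files/pipeline.jar",  # placeholder path to a self-executing jar
        job_class=None,
        process_line_callback=process_line_and_extract_dataflow_job_id_callback(remember_job_id),
    )
    dataflow_hook.wait_for_done(
        job_name="sketch-job",
        location="us-central1",
        project_id="my-project",
        job_id=captured.get("job_id"),
    )
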
diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index 49863dc..f977704 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -16,15 +16,20 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module contains Google Dataflow operators."""
-
 import copy
 import re
+import warnings
 from contextlib import ExitStack
 from enum import Enum
 from typing import Any, Dict, List, Optional, Sequence, Union
 
 from airflow.models import BaseOperator
-from airflow.providers.google.cloud.hooks.dataflow import DEFAULT_DATAFLOW_LOCATION, DataflowHook
+from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType
+from airflow.providers.google.cloud.hooks.dataflow import (
+    DEFAULT_DATAFLOW_LOCATION,
+    DataflowHook,
+    process_line_and_extract_dataflow_job_id_callback,
+)
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
 from airflow.utils.decorators import apply_defaults
 from airflow.version import version
@@ -43,12 +48,137 @@ class CheckJobRunning(Enum):
     WaitForRun = 3
 
 
+class DataflowConfiguration:
+    """Dataflow configuration that can be passed to
+    :py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator` and
+    :py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator`.
+
+    :param job_name: The 'jobName' to use when executing the DataFlow job
+        (templated). This ends up being set in the pipeline options, so any entry
+        with key ``'jobName'`` or ``'job_name'`` in ``options`` will be overwritten.
+    :type job_name: str
+    :param append_job_name: True if unique suffix has to be appended to job name.
+    :type append_job_name: bool
+    :param project_id: Optional, the Google Cloud project ID in which to start a job.
+        If set to None or missing, the default project_id from the Google Cloud connection is used.
+    :type project_id: str
+    :param location: Job location.
+    :type location: str
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :type gcp_conn_id: str
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: str
+    :param poll_sleep: The time in seconds to sleep between polling Google
+        Cloud Platform for the dataflow job status while the job is in the
+        JOB_STATE_RUNNING state.
+    :type poll_sleep: int
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :type impersonation_chain: Union[str, Sequence[str]]
+    :param drain_pipeline: Optional, set to True if you want to stop a streaming job by draining it
+        instead of canceling it when killing the task instance. See:
+        https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
+    :type drain_pipeline: bool
+    :param cancel_timeout: How long (in seconds) operator should wait for the pipeline to be
+        successfully cancelled when task is being killed.
+    :type cancel_timeout: Optional[int]
+    :param wait_until_finished: (Optional)
+        If True, wait for the end of pipeline execution before exiting.
+        If False, only submits job.
+        If None, default behavior.
+
+        The default behavior depends on the type of pipeline:
+
+        * for the streaming pipeline, wait for jobs to start,
+        * for the batch pipeline, wait for the jobs to complete.
+
+        .. warning::
+
+            You cannot call the ``PipelineResult.wait_until_finish`` method in your pipeline code for the
+            operator to work properly, i.e. you must use asynchronous execution. Otherwise, your pipeline will
+            always wait until it is finished. For more information, look at:
+            `Asynchronous execution
+            <https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#python_10>`__
+
+        The process of starting the Dataflow job in Airflow consists of two steps:
+
+        * running a subprocess and reading the stderr/stdout logs for the job id.
+        * a loop waiting for the end of the job with the ID from the previous step.
+          This loop checks the status of the job.
+
+        Step two is started just after step one has finished, so if you call wait_until_finish in your
+        pipeline code, step two will not start until the process stops. When this process stops,
+        step two will run, but it will only execute one iteration as the job will be in a terminal state.
+
+        If you do not call the wait_until_finish method in your pipeline but pass wait_until_finished=True
+        to the operator, the second loop will wait for the job's terminal state.
+
+        If you do not call the wait_until_finish method in your pipeline and pass wait_until_finished=False
+        to the operator, the second loop will check once whether the job is in a terminal state and exit the loop.
+    :type wait_until_finished: Optional[bool]
+    :param multiple_jobs: If the pipeline creates multiple jobs, then monitor all of them. Supported only by
+        :py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
+    :type multiple_jobs: bool
+    :param check_if_running: Before running the job, validate that a previous run is not in progress.
+        IgnoreJob = do not check if running.
+        FinishIfRunning = if the job is running, finish without starting a new one.
+        WaitForRun = wait until the job finishes, and then run the job.
+        Supported only by:
+        :py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
+    :type check_if_running: CheckJobRunning
+    """
+
+    template_fields = ["job_name", "location"]
+
+    def __init__(
+        self,
+        *,
+        job_name: Optional[str] = "{{task.task_id}}",
+        append_job_name: bool = True,
+        project_id: Optional[str] = None,
+        location: Optional[str] = DEFAULT_DATAFLOW_LOCATION,
+        gcp_conn_id: str = "google_cloud_default",
+        delegate_to: Optional[str] = None,
+        poll_sleep: int = 10,
+        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        drain_pipeline: bool = False,
+        cancel_timeout: Optional[int] = 5 * 60,
+        wait_until_finished: Optional[bool] = None,
+        multiple_jobs: Optional[bool] = None,
+        check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun,
+    ) -> None:
+        self.job_name = job_name
+        self.append_job_name = append_job_name
+        self.project_id = project_id
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.poll_sleep = poll_sleep
+        self.impersonation_chain = impersonation_chain
+        self.drain_pipeline = drain_pipeline
+        self.cancel_timeout = cancel_timeout
+        self.wait_until_finished = wait_until_finished
+        self.multiple_jobs = multiple_jobs
+        self.check_if_running = check_if_running
+
+
 # pylint: disable=too-many-instance-attributes
 class DataflowCreateJavaJobOperator(BaseOperator):
     """
     Start a Java Cloud DataFlow batch job. The parameters of the operation
     will be passed to the job.
 
+    This class is deprecated.
+    Please use `providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`.
+
     **Example**: ::
 
         default_args = {
@@ -235,6 +365,14 @@ class DataflowCreateJavaJobOperator(BaseOperator):
         wait_until_finished: Optional[bool] = None,
         **kwargs,
     ) -> None:
+        # TODO: Remove one day
+        warnings.warn(
+            "The `{cls}` operator is deprecated, please use "
+            "`providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator` instead."
+            "".format(cls=self.__class__.__name__),
+            DeprecationWarning,
+            stacklevel=2,
+        )
         super().__init__(**kwargs)
 
         dataflow_default_options = dataflow_default_options or {}
@@ -257,62 +395,83 @@ class DataflowCreateJavaJobOperator(BaseOperator):
         self.cancel_timeout = cancel_timeout
         self.wait_until_finished = wait_until_finished
         self.job_id = None
-        self.hook = None
+        self.beam_hook: Optional[BeamHook] = None
+        self.dataflow_hook: Optional[DataflowHook] = None
 
     def execute(self, context):
-        self.hook = DataflowHook(
+        """Execute the Apache Beam Pipeline."""
+        self.beam_hook = BeamHook(runner=BeamRunnerType.DataflowRunner)
+        self.dataflow_hook = DataflowHook(
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
             poll_sleep=self.poll_sleep,
             cancel_timeout=self.cancel_timeout,
             wait_until_finished=self.wait_until_finished,
         )
-        dataflow_options = copy.copy(self.dataflow_default_options)
-        dataflow_options.update(self.options)
-        is_running = False
-        if self.check_if_running != CheckJobRunning.IgnoreJob:
-            is_running = self.hook.is_job_dataflow_running(  # type: ignore[attr-defined]
-                name=self.job_name,
-                variables=dataflow_options,
-                project_id=self.project_id,
-                location=self.location,
-            )
-            while is_running and self.check_if_running == CheckJobRunning.WaitForRun:
-                is_running = self.hook.is_job_dataflow_running(  # type: ignore[attr-defined]
-                    name=self.job_name,
-                    variables=dataflow_options,
-                    project_id=self.project_id,
-                    location=self.location,
-                )
+        job_name = self.dataflow_hook.build_dataflow_job_name(job_name=self.job_name)
+        pipeline_options = copy.deepcopy(self.dataflow_default_options)
+
+        pipeline_options["jobName"] = self.job_name
+        pipeline_options["project"] = self.project_id or self.dataflow_hook.project_id
+        pipeline_options["region"] = self.location
+        pipeline_options.update(self.options)
+        pipeline_options.setdefault("labels", {}).update(
+            {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
+        )
 
-        if not is_running:
-            with ExitStack() as exit_stack:
-                if self.jar.lower().startswith("gs://"):
-                    gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
-                    tmp_gcs_file = exit_stack.enter_context(  # pylint: disable=no-member
-                        gcs_hook.provide_file(object_url=self.jar)
-                    )
-                    self.jar = tmp_gcs_file.name
-
-                def set_current_job_id(job_id):
-                    self.job_id = job_id
-
-                self.hook.start_java_dataflow(  # type: ignore[attr-defined]
-                    job_name=self.job_name,
-                    variables=dataflow_options,
-                    jar=self.jar,
-                    job_class=self.job_class,
-                    append_job_name=True,
-                    multiple_jobs=self.multiple_jobs,
-                    on_new_job_id_callback=set_current_job_id,
-                    project_id=self.project_id,
-                    location=self.location,
+        def set_current_job_id(job_id):
+            self.job_id = job_id
+
+        process_line_callback = process_line_and_extract_dataflow_job_id_callback(
+            on_new_job_id_callback=set_current_job_id
+        )
+
+        with ExitStack() as exit_stack:
+            if self.jar.lower().startswith("gs://"):
+                gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
+                tmp_gcs_file = exit_stack.enter_context(  # pylint: disable=no-member
+                    gcs_hook.provide_file(object_url=self.jar)
                 )
+                self.jar = tmp_gcs_file.name
+
+            is_running = False
+            if self.check_if_running != CheckJobRunning.IgnoreJob:
+                is_running = (
+                    self.dataflow_hook.is_job_dataflow_running(  # pylint: disable=no-value-for-parameter
+                        name=self.job_name,
+                        variables=pipeline_options,
+                    )
+                )
+                while is_running and self.check_if_running == CheckJobRunning.WaitForRun:
+                    # pylint: disable=no-value-for-parameter
+                    is_running = self.dataflow_hook.is_job_dataflow_running(
+                        name=self.job_name,
+                        variables=pipeline_options,
+                    )
+            if not is_running:
+                pipeline_options["jobName"] = job_name
+                self.beam_hook.start_java_pipeline(
+                    variables=pipeline_options,
+                    jar=self.jar,
+                    job_class=self.job_class,
+                    process_line_callback=process_line_callback,
+                )
+                self.dataflow_hook.wait_for_done(  # pylint: disable=no-value-for-parameter
+                    job_name=job_name,
+                    location=self.location,
+                    job_id=self.job_id,
+                    multiple_jobs=self.multiple_jobs,
+                )
+
+        return {"job_id": self.job_id}
 
     def on_kill(self) -> None:
         self.log.info("On kill.")
         if self.job_id:
-            self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
+            self.dataflow_hook.cancel_job(
+                job_id=self.job_id, project_id=self.project_id or self.dataflow_hook.project_id
+            )
 
 
 # pylint: disable=too-many-instance-attributes
@@ -760,6 +919,9 @@ class DataflowCreatePythonJobOperator(BaseOperator):
     high-level options, for instances, project and zone information, which
     apply to all dataflow operators in the DAG.
 
+    This class is deprecated.
+    Please use `providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator`.
+
     .. seealso::
         For more detail on job submission have a look at the reference:
         https://cloud.google.com/dataflow/pipelines/specifying-exec-params
@@ -886,7 +1048,14 @@ class DataflowCreatePythonJobOperator(BaseOperator):
         wait_until_finished: Optional[bool] = None,
         **kwargs,
     ) -> None:
-
+        # TODO: Remove one day
+        warnings.warn(
+            "The `{cls}` operator is deprecated, please use "
+            "`providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator` instead."
+            "".format(cls=self.__class__.__name__),
+            DeprecationWarning,
+            stacklevel=2,
+        )
         super().__init__(**kwargs)
 
         self.py_file = py_file
@@ -909,10 +1078,40 @@ class DataflowCreatePythonJobOperator(BaseOperator):
         self.cancel_timeout = cancel_timeout
         self.wait_until_finished = wait_until_finished
         self.job_id = None
-        self.hook: Optional[DataflowHook] = None
+        self.beam_hook: Optional[BeamHook] = None
+        self.dataflow_hook: Optional[DataflowHook] = None
 
     def execute(self, context):
         """Execute the python dataflow job."""
+        self.beam_hook = BeamHook(runner=BeamRunnerType.DataflowRunner)
+        self.dataflow_hook = DataflowHook(
+            gcp_conn_id=self.gcp_conn_id,
+            delegate_to=self.delegate_to,
+            poll_sleep=self.poll_sleep,
+            impersonation_chain=None,
+            drain_pipeline=self.drain_pipeline,
+            cancel_timeout=self.cancel_timeout,
+            wait_until_finished=self.wait_until_finished,
+        )
+
+        job_name = self.dataflow_hook.build_dataflow_job_name(job_name=self.job_name)
+        pipeline_options = self.dataflow_default_options.copy()
+        pipeline_options["job_name"] = job_name
+        pipeline_options["project"] = self.project_id or self.dataflow_hook.project_id
+        pipeline_options["region"] = self.location
+        pipeline_options.update(self.options)
+
+        # Convert argument names from lowerCamelCase to snake case.
+        camel_to_snake = lambda name: re.sub(r"[A-Z]", lambda x: "_" + x.group(0).lower(), name)
+        formatted_pipeline_options = {camel_to_snake(key): pipeline_options[key] for key in pipeline_options}
+
+        def set_current_job_id(job_id):
+            self.job_id = job_id
+
+        process_line_callback = process_line_and_extract_dataflow_job_id_callback(
+            on_new_job_id_callback=set_current_job_id
+        )
+
         with ExitStack() as exit_stack:
             if self.py_file.lower().startswith("gs://"):
                 gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
@@ -921,38 +1120,28 @@ class DataflowCreatePythonJobOperator(BaseOperator):
                 )
                 self.py_file = tmp_gcs_file.name
 
-            self.hook = DataflowHook(
-                gcp_conn_id=self.gcp_conn_id,
-                delegate_to=self.delegate_to,
-                poll_sleep=self.poll_sleep,
-                drain_pipeline=self.drain_pipeline,
-                cancel_timeout=self.cancel_timeout,
-                wait_until_finished=self.wait_until_finished,
-            )
-            dataflow_options = self.dataflow_default_options.copy()
-            dataflow_options.update(self.options)
-            # Convert argument names from lowerCamelCase to snake case.
-            camel_to_snake = lambda name: re.sub(r"[A-Z]", lambda x: "_" + x.group(0).lower(), name)
-            formatted_options = {camel_to_snake(key): dataflow_options[key] for key in dataflow_options}
-
-            def set_current_job_id(job_id):
-                self.job_id = job_id
-
-            self.hook.start_python_dataflow(  # type: ignore[attr-defined]
-                job_name=self.job_name,
-                variables=formatted_options,
-                dataflow=self.py_file,
+            self.beam_hook.start_python_pipeline(
+                variables=formatted_pipeline_options,
+                py_file=self.py_file,
                 py_options=self.py_options,
                 py_interpreter=self.py_interpreter,
                 py_requirements=self.py_requirements,
                 py_system_site_packages=self.py_system_site_packages,
-                on_new_job_id_callback=set_current_job_id,
-                project_id=self.project_id,
+                process_line_callback=process_line_callback,
+            )
+
+            self.dataflow_hook.wait_for_done(  # pylint: disable=no-value-for-parameter
+                job_name=job_name,
                 location=self.location,
+                job_id=self.job_id,
+                multiple_jobs=False,
             )
-            return {"job_id": self.job_id}
+
+        return {"job_id": self.job_id}
 
     def on_kill(self) -> None:
         self.log.info("On kill.")
         if self.job_id:
-            self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
+            self.dataflow_hook.cancel_job(
+                job_id=self.job_id, project_id=self.project_id or self.dataflow_hook.project_id
+            )
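
As a hedged migration sketch (project, bucket, and region values are placeholders), the deprecated
DataflowCreatePythonJobOperator above maps onto BeamRunPythonPipelineOperator with the runner set to
"DataflowRunner" plus a DataflowConfiguration:

    from airflow import DAG
    from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator
    from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration
    from airflow.utils.dates import days_ago

    with DAG(dag_id="example_beam_dataflow_sketch", start_date=days_ago(1), schedule_interval=None) as dag:
        start_python_on_dataflow = BeamRunPythonPipelineOperator(
            task_id="start_python_on_dataflow",
            runner="DataflowRunner",
            py_file="gs://my-bucket/wordcount.py",  # placeholder GCS object
            pipeline_options={"tempLocation": "gs://my-bucket/tmp/"},
            dataflow_config=DataflowConfiguration(
                job_name="wordcount-sketch",
                project_id="my-project",
                location="us-central1",
                wait_until_finished=True,
            ),
        )
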
diff --git a/dev/provider_packages/copy_provider_package_sources.py b/dev/provider_packages/copy_provider_package_sources.py
index 1d10747..c7f75f5 100755
--- a/dev/provider_packages/copy_provider_package_sources.py
+++ b/dev/provider_packages/copy_provider_package_sources.py
@@ -703,6 +703,67 @@ class RefactorBackportPackages:
             .rename("airflow.models.baseoperator")
         )
 
+    def refactor_apache_beam_package(self):
+        r"""
+        Fixes to "apache_beam" providers package.
+
+        Copies some of the classes used from core Airflow to the "common.utils" package of
+        the provider and renames imports to use them from there. Note that in this case we also rename
+        the imports in the copied files.
+
+        For example, we copy python_virtualenv.py and process_utils.py and change imports as in the example diff:
+
+        .. code-block:: diff
+
+            --- ./airflow/providers/apache/beam/common/utils/python_virtualenv.py
+            +++ ./airflow/providers/apache/beam/common/utils/python_virtualenv.py
+            @@ -21,7 +21,7 @@
+             \"\"\"
+            from typing import List, Optional
+
+            -from airflow.utils.process_utils import execute_in_subprocess
+            +from airflow.providers.apache.beam.common.utils.process_utils import execute_in_subprocess
+
+
+            def _generate_virtualenv_cmd(tmp_dir: str, python_bin: str, system_site_packages: bool)
+
+        """
+
+        def apache_beam_package_filter(node: LN, capture: Capture, filename: Filename) -> bool:
+            return filename.startswith("./airflow/providers/apache/beam")
+
+        os.makedirs(
+            os.path.join(get_target_providers_package_folder("apache.beam"), "common", "utils"), exist_ok=True
+        )
+        copyfile(
+            os.path.join(get_source_airflow_folder(), "airflow", "utils", "__init__.py"),
+            os.path.join(
+                get_target_providers_package_folder("apache.beam"), "common", "utils", "__init__.py"
+            ),
+        )
+        copyfile(
+            os.path.join(get_source_airflow_folder(), "airflow", "utils", "python_virtualenv.py"),
+            os.path.join(
+                get_target_providers_package_folder("apache.beam"), "common", "utils", "python_virtualenv.py"
+            ),
+        )
+        copyfile(
+            os.path.join(get_source_airflow_folder(), "airflow", "utils", "process_utils.py"),
+            os.path.join(
+                get_target_providers_package_folder("apache.beam"), "common", "utils", "process_utils.py"
+            ),
+        )
+        (
+            self.qry.select_module("airflow.utils.python_virtualenv")
+            .filter(callback=apache_beam_package_filter)
+            .rename("airflow.providers.apache.beam.common.utils.python_virtualenv")
+        )
+        (
+            self.qry.select_module("airflow.utils.process_utils")
+            .filter(callback=apache_beam_package_filter)
+            .rename("airflow.providers.apache.beam.common.utils.process_utils")
+        )
+
     def refactor_odbc_package(self):
         """
         Fixes to "odbc" providers package.
@@ -760,6 +821,7 @@ class RefactorBackportPackages:
         self.rename_deprecated_modules()
         self.refactor_amazon_package()
         self.refactor_google_package()
+        self.refactor_apache_beam_package()
         self.refactor_elasticsearch_package()
         self.refactor_odbc_package()
         self.remove_tags()
diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py
index 322a57f..3cfc39f 100755
--- a/dev/provider_packages/prepare_provider_packages.py
+++ b/dev/provider_packages/prepare_provider_packages.py
@@ -790,8 +790,10 @@ def convert_git_changes_to_table(
                 f"`{message_without_backticks}`" if markdown else f"``{message_without_backticks}``",
             )
         )
-    table = tabulate(table_data, headers=headers, tablefmt="pipe" if markdown else "rst")
     header = ""
+    if not table_data:
+        return header
+    table = tabulate(table_data, headers=headers, tablefmt="pipe" if markdown else "rst")
     if not markdown:
         header += f"\n\n{print_version}\n" + "." * len(print_version) + "\n\n"
         release_date = table_data[0][1]
diff --git a/docs/apache-airflow-providers-apache-beam/index.rst b/docs/apache-airflow-providers-apache-beam/index.rst
new file mode 100644
index 0000000..30718f9
--- /dev/null
+++ b/docs/apache-airflow-providers-apache-beam/index.rst
@@ -0,0 +1,36 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+``apache-airflow-providers-apache-beam``
+========================================
+
+Content
+-------
+
+.. toctree::
+    :maxdepth: 1
+    :caption: References
+
+    Python API <_api/airflow/providers/apache/beam/index>
+    PyPI Repository <https://pypi.org/project/apache-airflow-providers-apache-beam/>
+    Example DAGs <https://github.com/apache/airflow/tree/master/airflow/providers/apache/beam/example_dags>
+
+.. toctree::
+    :maxdepth: 1
+    :caption: Guides
+
+    Operators <operators>
diff --git a/docs/apache-airflow-providers-apache-beam/operators.rst b/docs/apache-airflow-providers-apache-beam/operators.rst
new file mode 100644
index 0000000..3c1b2bd
--- /dev/null
+++ b/docs/apache-airflow-providers-apache-beam/operators.rst
@@ -0,0 +1,116 @@
+
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+Apache Beam Operators
+=====================
+
+`Apache Beam <https://beam.apache.org/>`__ is an open source, unified model for defining both batch and
+streaming data-parallel processing pipelines. Using one of the open source Beam SDKs, you build a program
+that defines the pipeline. The pipeline is then executed by one of Beam’s supported distributed processing
+back-ends, which include Apache Flink, Apache Spark, and Google Cloud Dataflow.
+
+
+.. _howto/operator:BeamRunPythonPipelineOperator:
+
+Run Python Pipelines in Apache Beam
+===================================
+
+The ``py_file`` argument must be specified for
+:class:`~airflow.providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator`
+as it contains the pipeline to be executed by Beam. The Python file can either be located on GCS, from which
+Airflow will download it, or on the local filesystem (provide the absolute path to it).
+
+The ``py_interpreter`` argument specifies the Python version to be used when executing the pipeline; the default
+is ``python3``. If your Airflow instance is running on Python 2, specify ``python2`` and ensure your ``py_file`` is
+written in Python 2. For best results, use Python 3.
+
+If the ``py_requirements`` argument is specified, a temporary Python virtual environment with the specified
+requirements will be created and the pipeline will run within it.
+
+The ``py_system_site_packages`` argument specifies whether all the Python packages from your Airflow instance
+will be accessible within the virtual environment (if the ``py_requirements`` argument is specified).
+Avoid enabling it unless the Dataflow job requires it.
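+
+Putting these arguments together, the snippet below is a minimal sketch of a task definition; the bucket paths,
+requirements list, and pipeline options are illustrative placeholders, not values provided by the operator itself:
+
+.. code-block:: python
+
+    from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator
+
+    start_python_pipeline = BeamRunPythonPipelineOperator(
+        task_id="start_python_pipeline",
+        py_file="gs://my-bucket/wordcount.py",  # or an absolute local path
+        py_interpreter="python3",
+        py_requirements=["apache-beam[gcp]"],  # installed into a temporary virtualenv
+        py_system_site_packages=False,
+        pipeline_options={"output": "gs://my-bucket/output"},
+    )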
+
+Python Pipelines with DirectRunner
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. exampleinclude:: /../../airflow/providers/apache/beam/example_dags/example_beam.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_start_python_direct_runner_pipeline_local_file]
+    :end-before: [END howto_operator_start_python_direct_runner_pipeline_local_file]
+
+.. exampleinclude:: /../../airflow/providers/apache/beam/example_dags/example_beam.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_start_python_direct_runner_pipeline_gcs_file]
+    :end-before: [END howto_operator_start_python_direct_runner_pipeline_gcs_file]
+
+Python Pipelines with DataflowRunner
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. exampleinclude:: /../../airflow/providers/apache/beam/example_dags/example_beam.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_start_python_dataflow_runner_pipeline_gcs_file]
+    :end-before: [END howto_operator_start_python_dataflow_runner_pipeline_gcs_file]
+
+.. exampleinclude:: /../../airflow/providers/apache/beam/example_dags/example_beam.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file]
+    :end-before: [END howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file]
+
+.. _howto/operator:BeamRunJavaPipelineOperator:
+
+Run Java Pipelines in Apache Beam
+=================================
+
+For Java pipelines, the ``jar`` argument must be specified for
+:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
+as it contains the pipeline to be executed by Apache Beam. The JAR can either be located on GCS, from which
+Airflow will download it, or on the local filesystem (provide the absolute path to it).
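+
+The snippet below is a minimal sketch of a Java pipeline task; the JAR path, job class, and pipeline options are
+illustrative placeholders only:
+
+.. code-block:: python
+
+    from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator
+
+    start_java_pipeline = BeamRunJavaPipelineOperator(
+        task_id="start_java_pipeline",
+        jar="gs://my-bucket/example/pipeline.jar",  # or an absolute local path
+        job_class="org.example.MyPipeline",  # optional main class inside the JAR
+        pipeline_options={"output": "gs://my-bucket/output"},
+    )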
+
+Java Pipelines with DirectRunner
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. exampleinclude:: /../../airflow/providers/apache/beam/example_dags/example_beam.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_start_java_direct_runner_pipeline]
+    :end-before: [END howto_operator_start_java_direct_runner_pipeline]
+
+Java Pipelines with DataflowRunner
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. exampleinclude:: /../../airflow/providers/apache/beam/example_dags/example_beam.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_start_java_dataflow_runner_pipeline]
+    :end-before: [END howto_operator_start_java_dataflow_runner_pipeline]
+
+Reference
+^^^^^^^^^
+
+For further information, look at:
+
+* `Apache Beam Documentation <https://beam.apache.org/documentation/>`__
+* `Google Cloud API Documentation <https://cloud.google.com/dataflow/docs/apis>`__
+* `Product Documentation <https://cloud.google.com/dataflow/docs/>`__
+* `Dataflow Monitoring Interface <https://cloud.google.com/dataflow/docs/guides/using-monitoring-intf/>`__
+* `Dataflow Command-line Interface <https://cloud.google.com/dataflow/docs/guides/using-command-line-intf/>`__
diff --git a/docs/apache-airflow/extra-packages-ref.rst b/docs/apache-airflow/extra-packages-ref.rst
index b2549ae..5221beb 100644
--- a/docs/apache-airflow/extra-packages-ref.rst
+++ b/docs/apache-airflow/extra-packages-ref.rst
@@ -107,6 +107,8 @@ custom bash/python providers).
 +=====================+=====================================================+================================================+
 | apache.atlas        | ``pip install 'apache-airflow[apache.atlas]'``      | Apache Atlas                                   |
 +---------------------+-----------------------------------------------------+------------------------------------------------+
+| apache.beam         | ``pip install 'apache-airflow[apache.beam]'``       | Apache Beam operators & hooks                  |
++---------------------+-----------------------------------------------------+------------------------------------------------+
 | apache.cassandra    | ``pip install 'apache-airflow[apache.cassandra]'``  | Cassandra related operators & hooks            |
 +---------------------+-----------------------------------------------------+------------------------------------------------+
 | apache.druid        | ``pip install 'apache-airflow[apache.druid]'``      | Druid related operators & hooks                |
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index db4342a..f8f8f83 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -141,6 +141,7 @@ Fileshares
 Filesystem
 Firehose
 Firestore
+Flink
 FluentD
 Fokko
 Formaturas
@@ -325,6 +326,7 @@ Seki
 Sendgrid
 Siddharth
 SlackHook
+Spark
 SparkPi
 SparkR
 SparkSQL
diff --git a/scripts/in_container/run_install_and_test_provider_packages.sh b/scripts/in_container/run_install_and_test_provider_packages.sh
index 969fa29..9b951c7 100755
--- a/scripts/in_container/run_install_and_test_provider_packages.sh
+++ b/scripts/in_container/run_install_and_test_provider_packages.sh
@@ -95,7 +95,7 @@ function discover_all_provider_packages() {
     # Columns is to force it wider, so it doesn't wrap at 80 characters
     COLUMNS=180 airflow providers list
 
-    local expected_number_of_providers=62
+    local expected_number_of_providers=63
     local actual_number_of_providers
     actual_providers=$(airflow providers list --output yaml | grep package_name)
     actual_number_of_providers=$(wc -l <<<"$actual_providers")
diff --git a/setup.py b/setup.py
index 210b12f..50f6a2f 100644
--- a/setup.py
+++ b/setup.py
@@ -523,6 +523,7 @@ devel_hadoop = devel_minreq + hdfs + hive + kerberos + presto + webhdfs
 # Dict of all providers which are part of the Apache Airflow repository together with their requirements
 PROVIDERS_REQUIREMENTS: Dict[str, List[str]] = {
     'amazon': amazon,
+    'apache.beam': apache_beam,
     'apache.cassandra': cassandra,
     'apache.druid': druid,
     'apache.hdfs': hdfs,
diff --git a/tests/core/test_providers_manager.py b/tests/core/test_providers_manager.py
index 7d80c58..39ee588 100644
--- a/tests/core/test_providers_manager.py
+++ b/tests/core/test_providers_manager.py
@@ -22,6 +22,7 @@ from airflow.providers_manager import ProvidersManager
 
 ALL_PROVIDERS = [
     'apache-airflow-providers-amazon',
+    'apache-airflow-providers-apache-beam',
     'apache-airflow-providers-apache-cassandra',
     'apache-airflow-providers-apache-druid',
     'apache-airflow-providers-apache-hdfs',
diff --git a/tests/providers/apache/beam/__init__.py b/tests/providers/apache/beam/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/tests/providers/apache/beam/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/apache/beam/hooks/__init__.py b/tests/providers/apache/beam/hooks/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/tests/providers/apache/beam/hooks/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/apache/beam/hooks/test_beam.py b/tests/providers/apache/beam/hooks/test_beam.py
new file mode 100644
index 0000000..d0d713e
--- /dev/null
+++ b/tests/providers/apache/beam/hooks/test_beam.py
@@ -0,0 +1,271 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import copy
+import subprocess
+import unittest
+from unittest import mock
+from unittest.mock import MagicMock
+
+from parameterized import parameterized
+
+from airflow.exceptions import AirflowException
+from airflow.providers.apache.beam.hooks.beam import BeamCommandRunner, BeamHook, beam_options_to_args
+
+PY_FILE = 'apache_beam.examples.wordcount'
+JAR_FILE = 'unittest.jar'
+JOB_CLASS = 'com.example.UnitTest'
+PY_OPTIONS = ['-m']
+TEST_JOB_ID = 'test-job-id'
+
+DEFAULT_RUNNER = "DirectRunner"
+BEAM_STRING = 'airflow.providers.apache.beam.hooks.beam.{}'
+BEAM_VARIABLES_PY = {'output': 'gs://test/output', 'labels': {'foo': 'bar'}}
+BEAM_VARIABLES_JAVA = {
+    'output': 'gs://test/output',
+    'labels': {'foo': 'bar'},
+}
+
+APACHE_BEAM_V_2_14_0_JAVA_SDK_LOG = f""""\
+Dataflow SDK version: 2.14.0
+Jun 15, 2020 2:57:28 PM org.apache.beam.runners.dataflow.DataflowRunner run
+INFO: To access the Dataflow monitoring console, please navigate to https://console.cloud.google.com/dataflow\
+/jobsDetail/locations/europe-west3/jobs/{TEST_JOB_ID}?project=XXX
+Submitted job: {TEST_JOB_ID}
+Jun 15, 2020 2:57:28 PM org.apache.beam.runners.dataflow.DataflowRunner run
+INFO: To cancel the job using the 'gcloud' tool, run:
+> gcloud dataflow jobs --project=XXX cancel --region=europe-west3 {TEST_JOB_ID}
+"""
+
+
+class TestBeamHook(unittest.TestCase):
+    @mock.patch(BEAM_STRING.format('BeamCommandRunner'))
+    def test_start_python_pipeline(self, mock_runner):
+        hook = BeamHook(runner=DEFAULT_RUNNER)
+        wait_for_done = mock_runner.return_value.wait_for_done
+        process_line_callback = MagicMock()
+
+        hook.start_python_pipeline(  # pylint: disable=no-value-for-parameter
+            variables=copy.deepcopy(BEAM_VARIABLES_PY),
+            py_file=PY_FILE,
+            py_options=PY_OPTIONS,
+            process_line_callback=process_line_callback,
+        )
+
+        expected_cmd = [
+            "python3",
+            '-m',
+            PY_FILE,
+            f'--runner={DEFAULT_RUNNER}',
+            '--output=gs://test/output',
+            '--labels=foo=bar',
+        ]
+        mock_runner.assert_called_once_with(cmd=expected_cmd, process_line_callback=process_line_callback)
+        wait_for_done.assert_called_once_with()
+
+    @parameterized.expand(
+        [
+            ('default_to_python3', 'python3'),
+            ('major_version_2', 'python2'),
+            ('major_version_3', 'python3'),
+            ('minor_version', 'python3.6'),
+        ]
+    )
+    @mock.patch(BEAM_STRING.format('BeamCommandRunner'))
+    def test_start_python_pipeline_with_custom_interpreter(self, _, py_interpreter, mock_runner):
+        hook = BeamHook(runner=DEFAULT_RUNNER)
+        wait_for_done = mock_runner.return_value.wait_for_done
+        process_line_callback = MagicMock()
+
+        hook.start_python_pipeline(  # pylint: disable=no-value-for-parameter
+            variables=copy.deepcopy(BEAM_VARIABLES_PY),
+            py_file=PY_FILE,
+            py_options=PY_OPTIONS,
+            py_interpreter=py_interpreter,
+            process_line_callback=process_line_callback,
+        )
+
+        expected_cmd = [
+            py_interpreter,
+            '-m',
+            PY_FILE,
+            f'--runner={DEFAULT_RUNNER}',
+            '--output=gs://test/output',
+            '--labels=foo=bar',
+        ]
+        mock_runner.assert_called_once_with(cmd=expected_cmd, process_line_callback=process_line_callback)
+        wait_for_done.assert_called_once_with()
+
+    @parameterized.expand(
+        [
+            (['foo-bar'], False),
+            (['foo-bar'], True),
+            ([], True),
+        ]
+    )
+    @mock.patch(BEAM_STRING.format('prepare_virtualenv'))
+    @mock.patch(BEAM_STRING.format('BeamCommandRunner'))
+    def test_start_python_pipeline_with_non_empty_py_requirements_and_without_system_packages(
+        self, current_py_requirements, current_py_system_site_packages, mock_runner, mock_virtualenv
+    ):
+        hook = BeamHook(runner=DEFAULT_RUNNER)
+        wait_for_done = mock_runner.return_value.wait_for_done
+        mock_virtualenv.return_value = '/dummy_dir/bin/python'
+        process_line_callback = MagicMock()
+
+        hook.start_python_pipeline(  # pylint: disable=no-value-for-parameter
+            variables=copy.deepcopy(BEAM_VARIABLES_PY),
+            py_file=PY_FILE,
+            py_options=PY_OPTIONS,
+            py_requirements=current_py_requirements,
+            py_system_site_packages=current_py_system_site_packages,
+            process_line_callback=process_line_callback,
+        )
+
+        expected_cmd = [
+            '/dummy_dir/bin/python',
+            '-m',
+            PY_FILE,
+            f'--runner={DEFAULT_RUNNER}',
+            '--output=gs://test/output',
+            '--labels=foo=bar',
+        ]
+        mock_runner.assert_called_once_with(cmd=expected_cmd, process_line_callback=process_line_callback)
+        wait_for_done.assert_called_once_with()
+        mock_virtualenv.assert_called_once_with(
+            venv_directory=mock.ANY,
+            python_bin="python3",
+            system_site_packages=current_py_system_site_packages,
+            requirements=current_py_requirements,
+        )
+
+    @mock.patch(BEAM_STRING.format('BeamCommandRunner'))
+    def test_start_python_pipeline_with_empty_py_requirements_and_without_system_packages(self, mock_runner):
+        hook = BeamHook(runner=DEFAULT_RUNNER)
+        wait_for_done = mock_runner.return_value.wait_for_done
+        process_line_callback = MagicMock()
+
+        with self.assertRaisesRegex(AirflowException, "Invalid method invocation."):
+            hook.start_python_pipeline(  # pylint: disable=no-value-for-parameter
+                variables=copy.deepcopy(BEAM_VARIABLES_PY),
+                py_file=PY_FILE,
+                py_options=PY_OPTIONS,
+                py_requirements=[],
+                process_line_callback=process_line_callback,
+            )
+
+        mock_runner.assert_not_called()
+        wait_for_done.assert_not_called()
+
+    @mock.patch(BEAM_STRING.format('BeamCommandRunner'))
+    def test_start_java_pipeline(self, mock_runner):
+        hook = BeamHook(runner=DEFAULT_RUNNER)
+        wait_for_done = mock_runner.return_value.wait_for_done
+        process_line_callback = MagicMock()
+
+        hook.start_java_pipeline(  # pylint: disable=no-value-for-parameter
+            jar=JAR_FILE,
+            variables=copy.deepcopy(BEAM_VARIABLES_JAVA),
+            process_line_callback=process_line_callback,
+        )
+
+        expected_cmd = [
+            'java',
+            '-jar',
+            JAR_FILE,
+            f'--runner={DEFAULT_RUNNER}',
+            '--output=gs://test/output',
+            '--labels={"foo":"bar"}',
+        ]
+        mock_runner.assert_called_once_with(cmd=expected_cmd, process_line_callback=process_line_callback)
+        wait_for_done.assert_called_once_with()
+
+    @mock.patch(BEAM_STRING.format('BeamCommandRunner'))
+    def test_start_java_pipeline_with_job_class(self, mock_runner):
+        hook = BeamHook(runner=DEFAULT_RUNNER)
+        wait_for_done = mock_runner.return_value.wait_for_done
+        process_line_callback = MagicMock()
+
+        hook.start_java_pipeline(  # pylint: disable=no-value-for-parameter
+            jar=JAR_FILE,
+            variables=copy.deepcopy(BEAM_VARIABLES_JAVA),
+            job_class=JOB_CLASS,
+            process_line_callback=process_line_callback,
+        )
+
+        expected_cmd = [
+            'java',
+            '-cp',
+            JAR_FILE,
+            JOB_CLASS,
+            f'--runner={DEFAULT_RUNNER}',
+            '--output=gs://test/output',
+            '--labels={"foo":"bar"}',
+        ]
+        mock_runner.assert_called_once_with(cmd=expected_cmd, process_line_callback=process_line_callback)
+        wait_for_done.assert_called_once_with()
+
+
+class TestBeamRunner(unittest.TestCase):
+    @mock.patch('airflow.providers.apache.beam.hooks.beam.BeamCommandRunner.log')
+    @mock.patch('subprocess.Popen')
+    @mock.patch('select.select')
+    def test_beam_wait_for_done_logging(self, mock_select, mock_popen, mock_logging):
+        cmd = ['test', 'cmd']
+        mock_logging.info = MagicMock()
+        mock_logging.warning = MagicMock()
+        mock_proc = MagicMock()
+        mock_proc.stderr = MagicMock()
+        mock_proc.stderr.readlines = MagicMock(return_value=['test\n', 'error\n'])
+        mock_stderr_fd = MagicMock()
+        mock_proc.stderr.fileno = MagicMock(return_value=mock_stderr_fd)
+        mock_proc_poll = MagicMock()
+        mock_select.return_value = [[mock_stderr_fd]]
+
+        def poll_resp_error():
+            mock_proc.return_code = 1
+            return True
+
+        mock_proc_poll.side_effect = [None, poll_resp_error]
+        mock_proc.poll = mock_proc_poll
+        mock_popen.return_value = mock_proc
+        beam = BeamCommandRunner(cmd)
+        mock_logging.info.assert_called_once_with('Running command: %s', " ".join(cmd))
+        mock_popen.assert_called_once_with(
+            cmd,
+            shell=False,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            close_fds=True,
+        )
+        self.assertRaises(Exception, beam.wait_for_done)
+
+
+class TestBeamOptionsToArgs(unittest.TestCase):
+    @parameterized.expand(
+        [
+            ({"key": "val"}, ["--key=val"]),
+            ({"key": None}, ["--key"]),
+            ({"key": True}, ["--key"]),
+            ({"key": False}, ["--key=False"]),
+            ({"key": ["a", "b", "c"]}, ["--key=a", "--key=b", "--key=c"]),
+        ]
+    )
+    def test_beam_options_to_args(self, options, expected_args):
+        args = beam_options_to_args(options)
+        assert args == expected_args
diff --git a/tests/providers/apache/beam/operators/__init__.py b/tests/providers/apache/beam/operators/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/tests/providers/apache/beam/operators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py
new file mode 100644
index 0000000..c31ff33
--- /dev/null
+++ b/tests/providers/apache/beam/operators/test_beam.py
@@ -0,0 +1,274 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import unittest
+from unittest import mock
+
+from airflow.providers.apache.beam.operators.beam import (
+    BeamRunJavaPipelineOperator,
+    BeamRunPythonPipelineOperator,
+)
+from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration
+from airflow.version import version
+
+TASK_ID = 'test-beam-operator'
+DEFAULT_RUNNER = "DirectRunner"
+JOB_NAME = 'test-dataflow-pipeline-name'
+JOB_ID = 'test-dataflow-pipeline-id'
+JAR_FILE = 'gs://my-bucket/example/test.jar'
+JOB_CLASS = 'com.test.NotMain'
+PY_FILE = 'gs://my-bucket/my-object.py'
+PY_INTERPRETER = 'python3'
+PY_OPTIONS = ['-m']
+DEFAULT_OPTIONS_PYTHON = DEFAULT_OPTIONS_JAVA = {
+    'project': 'test',
+    'stagingLocation': 'gs://test/staging',
+}
+ADDITIONAL_OPTIONS = {'output': 'gs://test/output', 'labels': {'foo': 'bar'}}
+TEST_VERSION = f"v{version.replace('.', '-').replace('+', '-')}"
+EXPECTED_ADDITIONAL_OPTIONS = {
+    'output': 'gs://test/output',
+    'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
+}
+
+
+class TestBeamRunPythonPipelineOperator(unittest.TestCase):
+    def setUp(self):
+        self.operator = BeamRunPythonPipelineOperator(
+            task_id=TASK_ID,
+            py_file=PY_FILE,
+            py_options=PY_OPTIONS,
+            default_pipeline_options=DEFAULT_OPTIONS_PYTHON,
+            pipeline_options=ADDITIONAL_OPTIONS,
+        )
+
+    def test_init(self):
+        """Test BeamRunPythonPipelineOperator instance is properly initialized."""
+        self.assertEqual(self.operator.task_id, TASK_ID)
+        self.assertEqual(self.operator.py_file, PY_FILE)
+        self.assertEqual(self.operator.runner, DEFAULT_RUNNER)
+        self.assertEqual(self.operator.py_options, PY_OPTIONS)
+        self.assertEqual(self.operator.py_interpreter, PY_INTERPRETER)
+        self.assertEqual(self.operator.default_pipeline_options, DEFAULT_OPTIONS_PYTHON)
+        self.assertEqual(self.operator.pipeline_options, EXPECTED_ADDITIONAL_OPTIONS)
+
+    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+    def test_exec_direct_runner(self, gcs_hook, beam_hook_mock):
+        """Test BeamHook is created and the right args are passed to
+        start_python_workflow.
+        """
+        start_python_hook = beam_hook_mock.return_value.start_python_pipeline
+        gcs_provide_file = gcs_hook.return_value.provide_file
+        self.operator.execute(None)
+        beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
+        expected_options = {
+            'project': 'test',
+            'staging_location': 'gs://test/staging',
+            'output': 'gs://test/output',
+            'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
+        }
+        gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
+        start_python_hook.assert_called_once_with(
+            variables=expected_options,
+            py_file=gcs_provide_file.return_value.__enter__.return_value.name,
+            py_options=PY_OPTIONS,
+            py_interpreter=PY_INTERPRETER,
+            py_requirements=None,
+            py_system_site_packages=False,
+            process_line_callback=None,
+        )
+
+    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock):
+        """Test DataflowHook is created and the right args are passed to
+        start_python_dataflow.
+        """
+        dataflow_config = DataflowConfiguration()
+        self.operator.runner = "DataflowRunner"
+        self.operator.dataflow_config = dataflow_config
+        gcs_provide_file = gcs_hook.return_value.provide_file
+        self.operator.execute(None)
+        job_name = dataflow_hook_mock.build_dataflow_job_name.return_value
+        dataflow_hook_mock.assert_called_once_with(
+            gcp_conn_id=dataflow_config.gcp_conn_id,
+            delegate_to=dataflow_config.delegate_to,
+            poll_sleep=dataflow_config.poll_sleep,
+            impersonation_chain=dataflow_config.impersonation_chain,
+            drain_pipeline=dataflow_config.drain_pipeline,
+            cancel_timeout=dataflow_config.cancel_timeout,
+            wait_until_finished=dataflow_config.wait_until_finished,
+        )
+        expected_options = {
+            'project': dataflow_hook_mock.return_value.project_id,
+            'job_name': job_name,
+            'staging_location': 'gs://test/staging',
+            'output': 'gs://test/output',
+            'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
+            'region': 'us-central1',
+        }
+        gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
+        beam_hook_mock.return_value.start_python_pipeline.assert_called_once_with(
+            variables=expected_options,
+            py_file=gcs_provide_file.return_value.__enter__.return_value.name,
+            py_options=PY_OPTIONS,
+            py_interpreter=PY_INTERPRETER,
+            py_requirements=None,
+            py_system_site_packages=False,
+            process_line_callback=mock.ANY,
+        )
+        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+            job_id=self.operator.dataflow_job_id,
+            job_name=job_name,
+            location='us-central1',
+            multiple_jobs=False,
+        )
+
+    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
+    def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __):
+        self.operator.runner = "DataflowRunner"
+        dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job
+        self.operator.execute(None)
+        self.operator.dataflow_job_id = JOB_ID
+        self.operator.on_kill()
+        dataflow_cancel_job.assert_called_once_with(
+            job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id
+        )
+
+    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+    def test_on_kill_direct_runner(self, _, dataflow_mock, __):
+        dataflow_cancel_job = dataflow_mock.return_value.cancel_job
+        self.operator.execute(None)
+        self.operator.on_kill()
+        dataflow_cancel_job.assert_not_called()
+
+
+class TestBeamRunJavaPipelineOperator(unittest.TestCase):
+    def setUp(self):
+        self.operator = BeamRunJavaPipelineOperator(
+            task_id=TASK_ID,
+            jar=JAR_FILE,
+            job_class=JOB_CLASS,
+            default_pipeline_options=DEFAULT_OPTIONS_JAVA,
+            pipeline_options=ADDITIONAL_OPTIONS,
+        )
+
+    def test_init(self):
+        """Test BeamRunJavaPipelineOperator instance is properly initialized."""
+        self.assertEqual(self.operator.task_id, TASK_ID)
+        self.assertEqual(self.operator.runner, DEFAULT_RUNNER)
+        self.assertEqual(self.operator.default_pipeline_options, DEFAULT_OPTIONS_JAVA)
+        self.assertEqual(self.operator.job_class, JOB_CLASS)
+        self.assertEqual(self.operator.jar, JAR_FILE)
+        self.assertEqual(self.operator.pipeline_options, ADDITIONAL_OPTIONS)
+
+    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+    def test_exec_direct_runner(self, gcs_hook, beam_hook_mock):
+        """Test BeamHook is created and the right args are passed to
+        start_java_workflow.
+        """
+        start_java_hook = beam_hook_mock.return_value.start_java_pipeline
+        gcs_provide_file = gcs_hook.return_value.provide_file
+        self.operator.execute(None)
+
+        beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
+        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
+        start_java_hook.assert_called_once_with(
+            variables={**DEFAULT_OPTIONS_JAVA, **ADDITIONAL_OPTIONS},
+            jar=gcs_provide_file.return_value.__enter__.return_value.name,
+            job_class=JOB_CLASS,
+            process_line_callback=None,
+        )
+
+    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock):
+        """Test DataflowHook is created and the right args are passed to
+        start_java_dataflow.
+        """
+        dataflow_config = DataflowConfiguration()
+        self.operator.runner = "DataflowRunner"
+        self.operator.dataflow_config = dataflow_config
+        gcs_provide_file = gcs_hook.return_value.provide_file
+        dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False
+        self.operator.execute(None)
+        job_name = dataflow_hook_mock.build_dataflow_job_name.return_value
+        self.assertEqual(job_name, self.operator._dataflow_job_name)
+        dataflow_hook_mock.assert_called_once_with(
+            gcp_conn_id=dataflow_config.gcp_conn_id,
+            delegate_to=dataflow_config.delegate_to,
+            poll_sleep=dataflow_config.poll_sleep,
+            impersonation_chain=dataflow_config.impersonation_chain,
+            drain_pipeline=dataflow_config.drain_pipeline,
+            cancel_timeout=dataflow_config.cancel_timeout,
+            wait_until_finished=dataflow_config.wait_until_finished,
+        )
+        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
+
+        expected_options = {
+            'project': dataflow_hook_mock.return_value.project_id,
+            'jobName': job_name,
+            'stagingLocation': 'gs://test/staging',
+            'region': 'us-central1',
+            'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
+            'output': 'gs://test/output',
+        }
+
+        beam_hook_mock.return_value.start_java_pipeline.assert_called_once_with(
+            variables=expected_options,
+            jar=gcs_provide_file.return_value.__enter__.return_value.name,
+            job_class=JOB_CLASS,
+            process_line_callback=mock.ANY,
+        )
+        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+            job_id=self.operator.dataflow_job_id,
+            job_name=job_name,
+            location='us-central1',
+            multiple_jobs=dataflow_config.multiple_jobs,
+            project_id=dataflow_hook_mock.return_value.project_id,
+        )
+
+    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
+    def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __):
+        self.operator.runner = "DataflowRunner"
+        dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False
+        dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job
+        self.operator.execute(None)
+        self.operator.dataflow_job_id = JOB_ID
+        self.operator.on_kill()
+        dataflow_cancel_job.assert_called_once_with(
+            job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id
+        )
+
+    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
+    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
+    def test_on_kill_direct_runner(self, _, dataflow_mock, __):
+        dataflow_cancel_job = dataflow_mock.return_value.cancel_job
+        self.operator.execute(None)
+        self.operator.on_kill()
+        dataflow_cancel_job.assert_not_called()
diff --git a/tests/providers/apache/beam/operators/test_beam_system.py b/tests/providers/apache/beam/operators/test_beam_system.py
new file mode 100644
index 0000000..0798f35
--- /dev/null
+++ b/tests/providers/apache/beam/operators/test_beam_system.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+
+import pytest
+
+from tests.test_utils import AIRFLOW_MAIN_FOLDER
+from tests.test_utils.system_tests_class import SystemTest
+
+BEAM_DAG_FOLDER = os.path.join(AIRFLOW_MAIN_FOLDER, "airflow", "providers", "apache", "beam", "example_dags")
+
+
+@pytest.mark.system("apache.beam")
+class BeamExampleDagsSystemTest(SystemTest):
+    def test_run_example_dag_beam_python(self):
+        self.run_dag('example_beam_native_python', BEAM_DAG_FOLDER)
+
+    def test_run_example_dag_beam_python_dataflow_async(self):
+        self.run_dag('example_beam_native_python_dataflow_async', BEAM_DAG_FOLDER)
+
+    def test_run_example_dag_beam_java_direct_runner(self):
+        self.run_dag('example_beam_native_java_direct_runner', BEAM_DAG_FOLDER)
+
+    def test_run_example_dag_beam_java_dataflow_runner(self):
+        self.run_dag('example_beam_native_java_dataflow_runner', BEAM_DAG_FOLDER)
+
+    def test_run_example_dag_beam_java_spark_runner(self):
+        self.run_dag('example_beam_native_java_spark_runner', BEAM_DAG_FOLDER)
+
+    def test_run_example_dag_beam_java_flink_runner(self):
+        self.run_dag('example_beam_native_java_flink_runner', BEAM_DAG_FOLDER)
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index 5297b30..c0da030 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -30,16 +30,20 @@ import pytest
 from parameterized import parameterized
 
 from airflow.exceptions import AirflowException
+from airflow.providers.apache.beam.hooks.beam import BeamCommandRunner, BeamHook
 from airflow.providers.google.cloud.hooks.dataflow import (
     DEFAULT_DATAFLOW_LOCATION,
     DataflowHook,
     DataflowJobStatus,
     DataflowJobType,
     _DataflowJobsController,
-    _DataflowRunner,
     _fallback_to_project_id_from_variables,
+    process_line_and_extract_dataflow_job_id_callback,
 )
 
+DEFAULT_RUNNER = "DirectRunner"
+BEAM_STRING = 'airflow.providers.apache.beam.hooks.beam.{}'
+
 TASK_ID = 'test-dataflow-operator'
 JOB_NAME = 'test-dataflow-pipeline'
 MOCK_UUID = UUID('cf4a56d2-8101-4217-b027-2af6216feb48')
@@ -183,6 +187,7 @@ class TestDataflowHook(unittest.TestCase):
     def setUp(self):
         with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init):
             self.dataflow_hook = DataflowHook(gcp_conn_id='test')
+            self.dataflow_hook.beam_hook = MagicMock()
 
     @mock.patch("airflow.providers.google.cloud.hooks.dataflow.DataflowHook._authorize")
     @mock.patch("airflow.providers.google.cloud.hooks.dataflow.build")
@@ -194,186 +199,229 @@ class TestDataflowHook(unittest.TestCase):
         assert mock_build.return_value == result
 
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
-    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
-    def test_start_python_dataflow(self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+    @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
+    def test_start_python_dataflow(self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid):
+        mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
         mock_uuid.return_value = MOCK_UUID
-        mock_conn.return_value = None
-        dataflow_instance = mock_dataflow.return_value
-        dataflow_instance.wait_for_done.return_value = None
-        dataflowjob_instance = mock_dataflowjob.return_value
-        dataflowjob_instance.wait_for_done.return_value = None
-        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME,
-            variables=DATAFLOW_VARIABLES_PY,
-            dataflow=PY_FILE,
+        on_new_job_id_callback = MagicMock()
+        py_requirements = ["pands", "numpy"]
+        job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+        with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+            self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
+                job_name=JOB_NAME,
+                variables=DATAFLOW_VARIABLES_PY,
+                dataflow=PY_FILE,
+                py_options=PY_OPTIONS,
+                py_interpreter=DEFAULT_PY_INTERPRETER,
+                py_requirements=py_requirements,
+                on_new_job_id_callback=on_new_job_id_callback,
+            )
+
+        expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+        expected_variables["job_name"] = job_name
+        expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+        mock_beam_start_python_pipeline.assert_called_once_with(
+            variables=expected_variables,
+            py_file=PY_FILE,
+            py_interpreter=DEFAULT_PY_INTERPRETER,
             py_options=PY_OPTIONS,
+            py_requirements=py_requirements,
+            py_system_site_packages=False,
+            process_line_callback=mock_callback_on_job_id.return_value,
+        )
+
+        mock_dataflow_wait_for_done.assert_called_once_with(
+            job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION
         )
-        expected_cmd = [
-            "python3",
-            '-m',
-            PY_FILE,
-            '--region=us-central1',
-            '--runner=DataflowRunner',
-            '--project=test',
-            '--labels=foo=bar',
-            '--staging_location=gs://test/staging',
-            f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
-        ]
-        assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
 
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
-    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+    @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
     def test_start_python_dataflow_with_custom_region_as_variable(
-        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+        self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
     ):
+        mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
         mock_uuid.return_value = MOCK_UUID
-        mock_conn.return_value = None
-        dataflow_instance = mock_dataflow.return_value
-        dataflow_instance.wait_for_done.return_value = None
-        dataflowjob_instance = mock_dataflowjob.return_value
-        dataflowjob_instance.wait_for_done.return_value = None
-        variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
-        variables['region'] = TEST_LOCATION
-        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME,
-            variables=variables,
-            dataflow=PY_FILE,
+        on_new_job_id_callback = MagicMock()
+        py_requirements = ["pands", "numpy"]
+        job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+        passed_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+        passed_variables["region"] = TEST_LOCATION
+
+        with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+            self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
+                job_name=JOB_NAME,
+                variables=passed_variables,
+                dataflow=PY_FILE,
+                py_options=PY_OPTIONS,
+                py_interpreter=DEFAULT_PY_INTERPRETER,
+                py_requirements=py_requirements,
+                on_new_job_id_callback=on_new_job_id_callback,
+            )
+
+        expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+        expected_variables["job_name"] = job_name
+        expected_variables["region"] = TEST_LOCATION
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+        mock_beam_start_python_pipeline.assert_called_once_with(
+            variables=expected_variables,
+            py_file=PY_FILE,
+            py_interpreter=DEFAULT_PY_INTERPRETER,
             py_options=PY_OPTIONS,
+            py_requirements=py_requirements,
+            py_system_site_packages=False,
+            process_line_callback=mock_callback_on_job_id.return_value,
+        )
+
+        mock_dataflow_wait_for_done.assert_called_once_with(
+            job_id=mock.ANY, job_name=job_name, location=TEST_LOCATION
         )
-        expected_cmd = [
-            "python3",
-            '-m',
-            PY_FILE,
-            f'--region={TEST_LOCATION}',
-            '--runner=DataflowRunner',
-            '--project=test',
-            '--labels=foo=bar',
-            '--staging_location=gs://test/staging',
-            f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
-        ]
-        assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
 
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
-    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+    @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
     def test_start_python_dataflow_with_custom_region_as_parameter(
-        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+        self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
     ):
+        mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
         mock_uuid.return_value = MOCK_UUID
-        mock_conn.return_value = None
-        dataflow_instance = mock_dataflow.return_value
-        dataflow_instance.wait_for_done.return_value = None
-        dataflowjob_instance = mock_dataflowjob.return_value
-        dataflowjob_instance.wait_for_done.return_value = None
-        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME,
-            variables=DATAFLOW_VARIABLES_PY,
-            dataflow=PY_FILE,
+        on_new_job_id_callback = MagicMock()
+        py_requirements = ["pands", "numpy"]
+        job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+        passed_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+
+        with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+            self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
+                job_name=JOB_NAME,
+                variables=passed_variables,
+                dataflow=PY_FILE,
+                py_options=PY_OPTIONS,
+                py_interpreter=DEFAULT_PY_INTERPRETER,
+                py_requirements=py_requirements,
+                on_new_job_id_callback=on_new_job_id_callback,
+                location=TEST_LOCATION,
+            )
+
+        expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+        expected_variables["job_name"] = job_name
+        expected_variables["region"] = TEST_LOCATION
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+        mock_beam_start_python_pipeline.assert_called_once_with(
+            variables=expected_variables,
+            py_file=PY_FILE,
+            py_interpreter=DEFAULT_PY_INTERPRETER,
             py_options=PY_OPTIONS,
-            location=TEST_LOCATION,
+            py_requirements=py_requirements,
+            py_system_site_packages=False,
+            process_line_callback=mock_callback_on_job_id.return_value,
+        )
+
+        mock_dataflow_wait_for_done.assert_called_once_with(
+            job_id=mock.ANY, job_name=job_name, location=TEST_LOCATION
         )
-        expected_cmd = [
-            "python3",
-            '-m',
-            PY_FILE,
-            f'--region={TEST_LOCATION}',
-            '--runner=DataflowRunner',
-            '--project=test',
-            '--labels=foo=bar',
-            '--staging_location=gs://test/staging',
-            f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
-        ]
-        assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
 
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
-    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+    @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
     def test_start_python_dataflow_with_multiple_extra_packages(
-        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+        self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
     ):
+        mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
         mock_uuid.return_value = MOCK_UUID
-        mock_conn.return_value = None
-        dataflow_instance = mock_dataflow.return_value
-        dataflow_instance.wait_for_done.return_value = None
-        dataflowjob_instance = mock_dataflowjob.return_value
-        dataflowjob_instance.wait_for_done.return_value = None
-        variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_PY)
-        variables['extra-package'] = ['a.whl', 'b.whl']
+        on_new_job_id_callback = MagicMock()
+        py_requirements = ["pands", "numpy"]
+        job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
 
-        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME,
-            variables=variables,
-            dataflow=PY_FILE,
+        passed_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+        passed_variables['extra-package'] = ['a.whl', 'b.whl']
+
+        with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+            self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
+                job_name=JOB_NAME,
+                variables=passed_variables,
+                dataflow=PY_FILE,
+                py_options=PY_OPTIONS,
+                py_interpreter=DEFAULT_PY_INTERPRETER,
+                py_requirements=py_requirements,
+                on_new_job_id_callback=on_new_job_id_callback,
+            )
+
+        expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+        expected_variables["job_name"] = job_name
+        expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
+        expected_variables['extra-package'] = ['a.whl', 'b.whl']
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+        mock_beam_start_python_pipeline.assert_called_once_with(
+            variables=expected_variables,
+            py_file=PY_FILE,
+            py_interpreter=DEFAULT_PY_INTERPRETER,
             py_options=PY_OPTIONS,
+            py_requirements=py_requirements,
+            py_system_site_packages=False,
+            process_line_callback=mock_callback_on_job_id.return_value,
+        )
+
+        mock_dataflow_wait_for_done.assert_called_once_with(
+            job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION
         )
-        expected_cmd = [
-            "python3",
-            '-m',
-            PY_FILE,
-            '--extra-package=a.whl',
-            '--extra-package=b.whl',
-            '--region=us-central1',
-            '--runner=DataflowRunner',
-            '--project=test',
-            '--labels=foo=bar',
-            '--staging_location=gs://test/staging',
-            f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
-        ]
-        assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
 
     @parameterized.expand(
         [
-            ('default_to_python3', 'python3'),
-            ('major_version_2', 'python2'),
-            ('major_version_3', 'python3'),
-            ('minor_version', 'python3.6'),
+            ('python3',),
+            ('python2',),
+            ('python3',),
+            ('python3.6',),
         ]
     )
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
-    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+    @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
     def test_start_python_dataflow_with_custom_interpreter(
-        self,
-        name,
-        py_interpreter,
-        mock_conn,
-        mock_dataflow,
-        mock_dataflowjob,
-        mock_uuid,
+        self, py_interpreter, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
     ):
-        del name  # unused variable
+        mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
         mock_uuid.return_value = MOCK_UUID
-        mock_conn.return_value = None
-        dataflow_instance = mock_dataflow.return_value
-        dataflow_instance.wait_for_done.return_value = None
-        dataflowjob_instance = mock_dataflowjob.return_value
-        dataflowjob_instance.wait_for_done.return_value = None
-        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME,
-            variables=DATAFLOW_VARIABLES_PY,
-            dataflow=PY_FILE,
-            py_options=PY_OPTIONS,
+        on_new_job_id_callback = MagicMock()
+        job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+        with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+            self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
+                job_name=JOB_NAME,
+                variables=DATAFLOW_VARIABLES_PY,
+                dataflow=PY_FILE,
+                py_options=PY_OPTIONS,
+                py_interpreter=py_interpreter,
+                py_requirements=None,
+                on_new_job_id_callback=on_new_job_id_callback,
+            )
+
+        expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+        expected_variables["job_name"] = job_name
+        expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+        mock_beam_start_python_pipeline.assert_called_once_with(
+            variables=expected_variables,
+            py_file=PY_FILE,
             py_interpreter=py_interpreter,
+            py_options=PY_OPTIONS,
+            py_requirements=None,
+            py_system_site_packages=False,
+            process_line_callback=mock_callback_on_job_id.return_value,
+        )
+
+        mock_dataflow_wait_for_done.assert_called_once_with(
+            job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION
         )
-        expected_cmd = [
-            py_interpreter,
-            '-m',
-            PY_FILE,
-            '--region=us-central1',
-            '--runner=DataflowRunner',
-            '--project=test',
-            '--labels=foo=bar',
-            '--staging_location=gs://test/staging',
-            f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
-        ]
-        assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
 
     @parameterized.expand(
         [
@@ -382,225 +430,229 @@ class TestDataflowHook(unittest.TestCase):
             ([], True),
         ]
     )
-    @mock.patch(DATAFLOW_STRING.format('prepare_virtualenv'))
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
-    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+    @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
     def test_start_python_dataflow_with_non_empty_py_requirements_and_without_system_packages(
         self,
         current_py_requirements,
         current_py_system_site_packages,
-        mock_conn,
-        mock_dataflow,
-        mock_dataflowjob,
+        mock_callback_on_job_id,
+        mock_dataflow_wait_for_done,
         mock_uuid,
-        mock_virtualenv,
     ):
+        mock_beam_start_python_pipeline = self.dataflow_hook.beam_hook.start_python_pipeline
         mock_uuid.return_value = MOCK_UUID
-        mock_conn.return_value = None
-        dataflow_instance = mock_dataflow.return_value
-        dataflow_instance.wait_for_done.return_value = None
-        dataflowjob_instance = mock_dataflowjob.return_value
-        dataflowjob_instance.wait_for_done.return_value = None
-        mock_virtualenv.return_value = '/dummy_dir/bin/python'
-        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME,
-            variables=DATAFLOW_VARIABLES_PY,
-            dataflow=PY_FILE,
+        on_new_job_id_callback = MagicMock()
+        job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+        with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+            self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
+                job_name=JOB_NAME,
+                variables=DATAFLOW_VARIABLES_PY,
+                dataflow=PY_FILE,
+                py_options=PY_OPTIONS,
+                py_interpreter=DEFAULT_PY_INTERPRETER,
+                py_requirements=current_py_requirements,
+                py_system_site_packages=current_py_system_site_packages,
+                on_new_job_id_callback=on_new_job_id_callback,
+            )
+
+        expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+        expected_variables["job_name"] = job_name
+        expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+        mock_beam_start_python_pipeline.assert_called_once_with(
+            variables=expected_variables,
+            py_file=PY_FILE,
+            py_interpreter=DEFAULT_PY_INTERPRETER,
             py_options=PY_OPTIONS,
             py_requirements=current_py_requirements,
             py_system_site_packages=current_py_system_site_packages,
+            process_line_callback=mock_callback_on_job_id.return_value,
+        )
+
+        mock_dataflow_wait_for_done.assert_called_once_with(
+            job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION
         )
-        expected_cmd = [
-            '/dummy_dir/bin/python',
-            '-m',
-            PY_FILE,
-            '--region=us-central1',
-            '--runner=DataflowRunner',
-            '--project=test',
-            '--labels=foo=bar',
-            '--staging_location=gs://test/staging',
-            f'--job_name={JOB_NAME}-{MOCK_UUID_PREFIX}',
-        ]
-        assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
 
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
-    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
     def test_start_python_dataflow_with_empty_py_requirements_and_without_system_packages(
-        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+        self, mock_dataflow_wait_for_done, mock_uuid
     ):
+        self.dataflow_hook.beam_hook = BeamHook(runner="DataflowRunner")
         mock_uuid.return_value = MOCK_UUID
-        mock_conn.return_value = None
-        dataflow_instance = mock_dataflow.return_value
-        dataflow_instance.wait_for_done.return_value = None
-        dataflowjob_instance = mock_dataflowjob.return_value
-        dataflowjob_instance.wait_for_done.return_value = None
-        with pytest.raises(AirflowException, match="Invalid method invocation."):
+        on_new_job_id_callback = MagicMock()
+
+        with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"), self.assertRaisesRegex(
+            AirflowException, "Invalid method invocation."
+        ):
             self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
                 job_name=JOB_NAME,
                 variables=DATAFLOW_VARIABLES_PY,
                 dataflow=PY_FILE,
                 py_options=PY_OPTIONS,
+                py_interpreter=DEFAULT_PY_INTERPRETER,
                 py_requirements=[],
+                on_new_job_id_callback=on_new_job_id_callback,
             )
 
+        mock_dataflow_wait_for_done.assert_not_called()
+
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
-    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
-    def test_start_java_dataflow(self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+    @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
+    def test_start_java_dataflow(self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid):
+        mock_beam_start_java_pipeline = self.dataflow_hook.beam_hook.start_java_pipeline
         mock_uuid.return_value = MOCK_UUID
-        mock_conn.return_value = None
-        dataflow_instance = mock_dataflow.return_value
-        dataflow_instance.wait_for_done.return_value = None
-        dataflowjob_instance = mock_dataflowjob.return_value
-        dataflowjob_instance.wait_for_done.return_value = None
-        self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA, jar=JAR_FILE
-        )
-        expected_cmd = [
-            'java',
-            '-jar',
-            JAR_FILE,
-            '--region=us-central1',
-            '--runner=DataflowRunner',
-            '--project=test',
-            '--stagingLocation=gs://test/staging',
-            '--labels={"foo":"bar"}',
-            f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
-        ]
-        assert sorted(expected_cmd) == sorted(mock_dataflow.call_args[1]["cmd"])
+        on_new_job_id_callback = MagicMock()
+        job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+        with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+            self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
+                job_name=JOB_NAME,
+                variables=DATAFLOW_VARIABLES_JAVA,
+                jar=JAR_FILE,
+                job_class=JOB_CLASS,
+                on_new_job_id_callback=on_new_job_id_callback,
+            )
+
+        expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
+        expected_variables["jobName"] = job_name
+        expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
+        expected_variables["labels"] = '{"foo":"bar"}'
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+        mock_beam_start_java_pipeline.assert_called_once_with(
+            variables=expected_variables,
+            jar=JAR_FILE,
+            job_class=JOB_CLASS,
+            process_line_callback=mock_callback_on_job_id.return_value,
+        )
+
+        mock_dataflow_wait_for_done.assert_called_once_with(
+            job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION, multiple_jobs=False
+        )
 
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
-    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+    @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
     def test_start_java_dataflow_with_multiple_values_in_variables(
-        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+        self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
     ):
+        mock_beam_start_java_pipeline = self.dataflow_hook.beam_hook.start_java_pipeline
         mock_uuid.return_value = MOCK_UUID
-        mock_conn.return_value = None
-        dataflow_instance = mock_dataflow.return_value
-        dataflow_instance.wait_for_done.return_value = None
-        dataflowjob_instance = mock_dataflowjob.return_value
-        dataflowjob_instance.wait_for_done.return_value = None
-        variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
-        variables['mock-option'] = ['a.whl', 'b.whl']
-
-        self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=variables, jar=JAR_FILE
-        )
-        expected_cmd = [
-            'java',
-            '-jar',
-            JAR_FILE,
-            '--mock-option=a.whl',
-            '--mock-option=b.whl',
-            '--region=us-central1',
-            '--runner=DataflowRunner',
-            '--project=test',
-            '--stagingLocation=gs://test/staging',
-            '--labels={"foo":"bar"}',
-            f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
-        ]
-        assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
+        on_new_job_id_callback = MagicMock()
+        job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
+
+        passed_variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
+        passed_variables['mock-option'] = ['a.whl', 'b.whl']
+
+        with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+            self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
+                job_name=JOB_NAME,
+                variables=passed_variables,
+                jar=JAR_FILE,
+                job_class=JOB_CLASS,
+                on_new_job_id_callback=on_new_job_id_callback,
+            )
+
+        expected_variables = copy.deepcopy(passed_variables)
+        expected_variables["jobName"] = job_name
+        expected_variables["region"] = DEFAULT_DATAFLOW_LOCATION
+        expected_variables["labels"] = '{"foo":"bar"}'
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+        mock_beam_start_java_pipeline.assert_called_once_with(
+            variables=expected_variables,
+            jar=JAR_FILE,
+            job_class=JOB_CLASS,
+            process_line_callback=mock_callback_on_job_id.return_value,
+        )
+
+        mock_dataflow_wait_for_done.assert_called_once_with(
+            job_id=mock.ANY, job_name=job_name, location=DEFAULT_DATAFLOW_LOCATION, multiple_jobs=False
+        )
 
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
-    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+    @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
     def test_start_java_dataflow_with_custom_region_as_variable(
-        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+        self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
     ):
+        mock_beam_start_java_pipeline = self.dataflow_hook.beam_hook.start_java_pipeline
         mock_uuid.return_value = MOCK_UUID
-        mock_conn.return_value = None
-        dataflow_instance = mock_dataflow.return_value
-        dataflow_instance.wait_for_done.return_value = None
-        dataflowjob_instance = mock_dataflowjob.return_value
-        dataflowjob_instance.wait_for_done.return_value = None
+        on_new_job_id_callback = MagicMock()
+        job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
 
-        variables = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
-        variables['region'] = TEST_LOCATION
-
-        self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=variables, jar=JAR_FILE
-        )
-        expected_cmd = [
-            'java',
-            '-jar',
-            JAR_FILE,
-            f'--region={TEST_LOCATION}',
-            '--runner=DataflowRunner',
-            '--project=test',
-            '--stagingLocation=gs://test/staging',
-            '--labels={"foo":"bar"}',
-            f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
-        ]
-        assert sorted(expected_cmd) == sorted(mock_dataflow.call_args[1]["cmd"])
+        passed_variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
+        passed_variables['region'] = TEST_LOCATION
+
+        with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+            self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
+                job_name=JOB_NAME,
+                variables=passed_variables,
+                jar=JAR_FILE,
+                job_class=JOB_CLASS,
+                on_new_job_id_callback=on_new_job_id_callback,
+            )
+
+        expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
+        expected_variables["jobName"] = job_name
+        expected_variables["region"] = TEST_LOCATION
+        expected_variables["labels"] = '{"foo":"bar"}'
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+        mock_beam_start_java_pipeline.assert_called_once_with(
+            variables=expected_variables,
+            jar=JAR_FILE,
+            job_class=JOB_CLASS,
+            process_line_callback=mock_callback_on_job_id.return_value,
+        )
+
+        mock_dataflow_wait_for_done.assert_called_once_with(
+            job_id=mock.ANY, job_name=job_name, location=TEST_LOCATION, multiple_jobs=False
+        )
 
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
-    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.wait_for_done'))
+    @mock.patch(DATAFLOW_STRING.format('process_line_and_extract_dataflow_job_id_callback'))
     def test_start_java_dataflow_with_custom_region_as_parameter(
-        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+        self, mock_callback_on_job_id, mock_dataflow_wait_for_done, mock_uuid
     ):
+        mock_beam_start_java_pipeline = self.dataflow_hook.beam_hook.start_java_pipeline
         mock_uuid.return_value = MOCK_UUID
-        mock_conn.return_value = None
-        dataflow_instance = mock_dataflow.return_value
-        dataflow_instance.wait_for_done.return_value = None
-        dataflowjob_instance = mock_dataflowjob.return_value
-        dataflowjob_instance.wait_for_done.return_value = None
+        on_new_job_id_callback = MagicMock()
+        job_name = f"{JOB_NAME}-{MOCK_UUID_PREFIX}"
 
-        variables = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
-        variables['region'] = TEST_LOCATION
-
-        self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=variables, jar=JAR_FILE
-        )
-        expected_cmd = [
-            'java',
-            '-jar',
-            JAR_FILE,
-            f'--region={TEST_LOCATION}',
-            '--runner=DataflowRunner',
-            '--project=test',
-            '--stagingLocation=gs://test/staging',
-            '--labels={"foo":"bar"}',
-            f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
-        ]
-        assert sorted(expected_cmd) == sorted(mock_dataflow.call_args[1]["cmd"])
+        with self.assertWarnsRegex(DeprecationWarning, "This method is deprecated"):
+            self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
+                job_name=JOB_NAME,
+                variables=DATAFLOW_VARIABLES_JAVA,
+                jar=JAR_FILE,
+                job_class=JOB_CLASS,
+                on_new_job_id_callback=on_new_job_id_callback,
+                location=TEST_LOCATION,
+            )
 
-    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
-    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
-    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
-    def test_start_java_dataflow_with_job_class(self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
-        mock_uuid.return_value = MOCK_UUID
-        mock_conn.return_value = None
-        dataflow_instance = mock_dataflow.return_value
-        dataflow_instance.wait_for_done.return_value = None
-        dataflowjob_instance = mock_dataflowjob.return_value
-        dataflowjob_instance.wait_for_done.return_value = None
-        self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA, jar=JAR_FILE, job_class=JOB_CLASS
-        )
-        expected_cmd = [
-            'java',
-            '-cp',
-            JAR_FILE,
-            JOB_CLASS,
-            '--region=us-central1',
-            '--runner=DataflowRunner',
-            '--project=test',
-            '--stagingLocation=gs://test/staging',
-            '--labels={"foo":"bar"}',
-            f'--jobName={JOB_NAME}-{MOCK_UUID_PREFIX}',
-        ]
-        assert sorted(mock_dataflow.call_args[1]["cmd"]) == sorted(expected_cmd)
+        expected_variables = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
+        expected_variables["jobName"] = job_name
+        expected_variables["region"] = TEST_LOCATION
+        expected_variables["labels"] = '{"foo":"bar"}'
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback)
+        mock_beam_start_java_pipeline.assert_called_once_with(
+            variables=expected_variables,
+            jar=JAR_FILE,
+            job_class=JOB_CLASS,
+            process_line_callback=mock_callback_on_job_id.return_value,
+        )
+
+        mock_dataflow_wait_for_done.assert_called_once_with(
+            job_id=mock.ANY, job_name=job_name, location=TEST_LOCATION, multiple_jobs=False
+        )
 
     @parameterized.expand(
         [
@@ -616,17 +668,20 @@ class TestDataflowHook(unittest.TestCase):
     )
     @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
     def test_valid_dataflow_job_name(self, expected_result, job_name, append_job_name, mock_uuid4):
-        job_name = self.dataflow_hook._build_dataflow_job_name(
+        job_name = self.dataflow_hook.build_dataflow_job_name(
             job_name=job_name, append_job_name=append_job_name
         )
 
-        assert expected_result == job_name
+        self.assertEqual(expected_result, job_name)
 
+    #
     @parameterized.expand([("1dfjob@",), ("dfjob@",), ("df^jo",)])
     def test_build_dataflow_job_name_with_invalid_value(self, job_name):
-        with pytest.raises(ValueError):
-            self.dataflow_hook._build_dataflow_job_name(job_name=job_name, append_job_name=False)
+        self.assertRaises(
+            ValueError, self.dataflow_hook.build_dataflow_job_name, job_name=job_name, append_job_name=False
+        )
 
+    #
     @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
     @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
     def test_get_job(self, mock_conn, mock_dataflowjob):
@@ -641,6 +696,7 @@ class TestDataflowHook(unittest.TestCase):
         )
         method_fetch_job_by_id.assert_called_once_with(TEST_JOB_ID)
 
+    #
     @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
     @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
     def test_fetch_job_metrics_by_id(self, mock_conn, mock_dataflowjob):
@@ -706,6 +762,34 @@ class TestDataflowHook(unittest.TestCase):
         )
         method_fetch_job_autoscaling_events_by_id.assert_called_once_with(TEST_JOB_ID)
 
+    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    def test_wait_for_done(self, mock_conn, mock_dataflowjob):
+        method_wait_for_done = mock_dataflowjob.return_value.wait_for_done
+
+        self.dataflow_hook.wait_for_done(
+            job_name="JOB_NAME",
+            project_id=TEST_PROJECT_ID,
+            job_id=TEST_JOB_ID,
+            location=TEST_LOCATION,
+            multiple_jobs=False,
+        )
+        mock_conn.assert_called_once()
+        mock_dataflowjob.assert_called_once_with(
+            dataflow=mock_conn.return_value,
+            project_number=TEST_PROJECT_ID,
+            name="JOB_NAME",
+            location=TEST_LOCATION,
+            poll_sleep=self.dataflow_hook.poll_sleep,
+            job_id=TEST_JOB_ID,
+            num_retries=self.dataflow_hook.num_retries,
+            multiple_jobs=False,
+            drain_pipeline=self.dataflow_hook.drain_pipeline,
+            cancel_timeout=self.dataflow_hook.cancel_timeout,
+            wait_until_finished=self.dataflow_hook.wait_until_finished,
+        )
+        method_wait_for_done.assert_called_once_with()
+
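
The new public wait_for_done entry point exercised by the test above can also be called directly; a rough sketch with placeholder identifiers (the connection id, project and job ids below are illustrative only, not taken from the patch):

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    hook = DataflowHook(gcp_conn_id="google_cloud_default")  # placeholder connection id
    hook.wait_for_done(
        job_name="example-df-job",                # name the job was submitted under (placeholder)
        project_id="example-project",             # placeholder GCP project
        job_id="2021-02-03_00_00_00-123456789",   # placeholder Dataflow job id
        location="us-central1",
        multiple_jobs=False,                      # wait for a single job, not a job family
    )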
 
 class TestDataflowTemplateHook(unittest.TestCase):
     def setUp(self):
@@ -1691,13 +1775,32 @@ class TestDataflow(unittest.TestCase):
     def test_data_flow_valid_job_id(self, log):
         echos = ";".join([f"echo {shlex.quote(line)}" for line in log.split("\n")])
         cmd = ["bash", "-c", echos]
-        assert _DataflowRunner(cmd).wait_for_done() == TEST_JOB_ID
+        found_job_id = None
+
+        def callback(job_id):
+            nonlocal found_job_id
+            found_job_id = job_id
+
+        BeamCommandRunner(
+            cmd, process_line_callback=process_line_and_extract_dataflow_job_id_callback(callback)
+        ).wait_for_done()
+        self.assertEqual(found_job_id, TEST_JOB_ID)
 
     def test_data_flow_missing_job_id(self):
         cmd = ['echo', 'unit testing']
-        assert _DataflowRunner(cmd).wait_for_done() is None
+        found_job_id = None
+
+        def callback(job_id):
+            nonlocal found_job_id
+            found_job_id = job_id
+
+        BeamCommandRunner(
+            cmd, process_line_callback=process_line_and_extract_dataflow_job_id_callback(callback)
+        ).wait_for_done()
+
+        self.assertEqual(found_job_id, None)
 
-    @mock.patch('airflow.providers.google.cloud.hooks.dataflow._DataflowRunner.log')
+    @mock.patch('airflow.providers.apache.beam.hooks.beam.BeamCommandRunner.log')
     @mock.patch('subprocess.Popen')
     @mock.patch('select.select')
     def test_dataflow_wait_for_done_logging(self, mock_select, mock_popen, mock_logging):
@@ -1718,7 +1821,6 @@ class TestDataflow(unittest.TestCase):
         mock_proc_poll.side_effect = [None, poll_resp_error]
         mock_proc.poll = mock_proc_poll
         mock_popen.return_value = mock_proc
-        dataflow = _DataflowRunner(['test', 'cmd'])
+        dataflow = BeamCommandRunner(['test', 'cmd'])
         mock_logging.info.assert_called_once_with('Running command: %s', 'test cmd')
-        with pytest.raises(Exception):
-            dataflow.wait_for_done()
+        self.assertRaises(Exception, dataflow.wait_for_done)
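
The tests above replace the old _DataflowRunner with the Beam command runner plus a job-id extraction callback. A minimal usage sketch of that pattern, assuming the import paths implied by the mock targets in this file (the command and the callback name below are placeholders, not part of the patch):

    from airflow.providers.apache.beam.hooks.beam import BeamCommandRunner
    from airflow.providers.google.cloud.hooks.dataflow import (
        process_line_and_extract_dataflow_job_id_callback,
    )

    captured_job_id = None

    def on_new_job_id(job_id):
        # Record the job id once the callback spots it in the runner output.
        global captured_job_id
        captured_job_id = job_id

    runner = BeamCommandRunner(
        ["python3", "-m", "example_pipeline", "--runner=DataflowRunner"],  # placeholder command
        process_line_callback=process_line_and_extract_dataflow_job_id_callback(on_new_job_id),
    )
    runner.wait_for_done()  # blocks until the subprocess exits

If the runner's output never contains a recognisable Dataflow job id, the callback is simply never invoked and captured_job_id stays None, which is what test_data_flow_missing_job_id asserts.
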
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 7e290d7..3018052 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-
+import copy
 import unittest
 from copy import deepcopy
 from unittest import mock
@@ -115,35 +115,56 @@ class TestDataflowPythonOperator(unittest.TestCase):
         assert self.dataflow.dataflow_default_options == DEFAULT_OPTIONS_PYTHON
         assert self.dataflow.options == EXPECTED_ADDITIONAL_OPTIONS
 
+    @mock.patch(
+        'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
+    )
+    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
-    def test_exec(self, gcs_hook, dataflow_mock):
+    def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_on_job_id):
         """Test DataflowHook is created and the right args are passed to
         start_python_workflow.
 
         """
-        start_python_hook = dataflow_mock.return_value.start_python_dataflow
+        start_python_mock = beam_hook_mock.return_value.start_python_pipeline
         gcs_provide_file = gcs_hook.return_value.provide_file
+        job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
         self.dataflow.execute(None)
-        assert dataflow_mock.called
+        beam_hook_mock.assert_called_once_with(runner="DataflowRunner")
+        self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow'))
+        gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback=mock.ANY)
+        dataflow_hook_mock.assert_called_once_with(
+            gcp_conn_id="google_cloud_default",
+            delegate_to=mock.ANY,
+            poll_sleep=POLL_SLEEP,
+            impersonation_chain=None,
+            drain_pipeline=False,
+            cancel_timeout=mock.ANY,
+            wait_until_finished=None,
+        )
         expected_options = {
-            'project': 'test',
-            'staging_location': 'gs://test/staging',
+            "project": dataflow_hook_mock.return_value.project_id,
+            "staging_location": 'gs://test/staging',
+            "job_name": job_name,
+            "region": TEST_LOCATION,
             'output': 'gs://test/output',
-            'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
+            'labels': {'foo': 'bar', 'airflow-version': 'v2-1-0-dev0'},
         }
-        gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
-        start_python_hook.assert_called_once_with(
-            job_name=JOB_NAME,
+        start_python_mock.assert_called_once_with(
             variables=expected_options,
-            dataflow=mock.ANY,
+            py_file=gcs_provide_file.return_value.__enter__.return_value.name,
             py_options=PY_OPTIONS,
             py_interpreter=PY_INTERPRETER,
             py_requirements=None,
             py_system_site_packages=False,
-            on_new_job_id_callback=mock.ANY,
-            project_id=None,
+            process_line_callback=mock_callback_on_job_id.return_value,
+        )
+        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+            job_id=mock.ANY,
+            job_name=job_name,
             location=TEST_LOCATION,
+            multiple_jobs=False,
         )
         assert self.dataflow.py_file.startswith('/tmp/dataflow')
 
@@ -172,110 +193,182 @@ class TestDataflowJavaOperator(unittest.TestCase):
         assert self.dataflow.options == EXPECTED_ADDITIONAL_OPTIONS
         assert self.dataflow.check_if_running == CheckJobRunning.WaitForRun
 
+    @mock.patch(
+        'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
+    )
+    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
-    def test_exec(self, gcs_hook, dataflow_mock):
+    def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_on_job_id):
         """Test DataflowHook is created and the right args are passed to
         start_java_workflow.
 
         """
-        start_java_hook = dataflow_mock.return_value.start_java_dataflow
+        start_java_mock = beam_hook_mock.return_value.start_java_pipeline
         gcs_provide_file = gcs_hook.return_value.provide_file
+        job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
         self.dataflow.check_if_running = CheckJobRunning.IgnoreJob
+
         self.dataflow.execute(None)
-        assert dataflow_mock.called
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback=mock.ANY)
         gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
-        start_java_hook.assert_called_once_with(
-            job_name=JOB_NAME,
-            variables=mock.ANY,
-            jar=mock.ANY,
+        expected_variables = {
+            'project': dataflow_hook_mock.return_value.project_id,
+            'stagingLocation': 'gs://test/staging',
+            'jobName': job_name,
+            'region': TEST_LOCATION,
+            'output': 'gs://test/output',
+            'labels': {'foo': 'bar', 'airflow-version': 'v2-1-0-dev0'},
+        }
+
+        start_java_mock.assert_called_once_with(
+            variables=expected_variables,
+            jar=gcs_provide_file.return_value.__enter__.return_value.name,
             job_class=JOB_CLASS,
-            append_job_name=True,
-            multiple_jobs=None,
-            on_new_job_id_callback=mock.ANY,
-            project_id=None,
+            process_line_callback=mock_callback_on_job_id.return_value,
+        )
+        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+            job_id=mock.ANY,
+            job_name=job_name,
             location=TEST_LOCATION,
+            multiple_jobs=None,
         )
 
+    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
-    def test_check_job_running_exec(self, gcs_hook, dataflow_mock):
+    def test_check_job_running_exec(self, gcs_hook, dataflow_mock, beam_hook_mock):
         """Test DataflowHook is created and the right args are passed to
         start_java_workflow.
 
         """
         dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
         dataflow_running.return_value = True
-        start_java_hook = dataflow_mock.return_value.start_java_dataflow
+        start_java_hook = beam_hook_mock.return_value.start_java_pipeline
         gcs_provide_file = gcs_hook.return_value.provide_file
         self.dataflow.check_if_running = True
+
         self.dataflow.execute(None)
-        assert dataflow_mock.called
-        gcs_provide_file.assert_not_called()
+
+        self.assertTrue(dataflow_mock.called)
         start_java_hook.assert_not_called()
-        dataflow_running.assert_called_once_with(
-            name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION
-        )
+        gcs_provide_file.assert_called_once()
+        variables = {
+            'project': dataflow_mock.return_value.project_id,
+            'stagingLocation': 'gs://test/staging',
+            'jobName': JOB_NAME,
+            'region': TEST_LOCATION,
+            'output': 'gs://test/output',
+            'labels': {'foo': 'bar', 'airflow-version': 'v2-1-0-dev0'},
+        }
+        dataflow_running.assert_called_once_with(name=JOB_NAME, variables=variables)
 
+    @mock.patch(
+        'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
+    )
+    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
-    def test_check_job_not_running_exec(self, gcs_hook, dataflow_mock):
+    def test_check_job_not_running_exec(
+        self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_on_job_id
+    ):
         """Test DataflowHook is created and the right args are passed to
         start_java_workflow with option to check if job is running
-
         """
-        dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
+        is_job_dataflow_running_variables = None
+
+        def set_is_job_dataflow_running_variables(*args, **kwargs):
+            nonlocal is_job_dataflow_running_variables
+            is_job_dataflow_running_variables = copy.deepcopy(kwargs.get("variables"))
+
+        dataflow_running = dataflow_hook_mock.return_value.is_job_dataflow_running
+        dataflow_running.side_effect = set_is_job_dataflow_running_variables
         dataflow_running.return_value = False
-        start_java_hook = dataflow_mock.return_value.start_java_dataflow
+        start_java_mock = beam_hook_mock.return_value.start_java_pipeline
         gcs_provide_file = gcs_hook.return_value.provide_file
         self.dataflow.check_if_running = True
+
         self.dataflow.execute(None)
-        assert dataflow_mock.called
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback=mock.ANY)
         gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
-        start_java_hook.assert_called_once_with(
-            job_name=JOB_NAME,
-            variables=mock.ANY,
-            jar=mock.ANY,
+        expected_variables = {
+            'project': dataflow_hook_mock.return_value.project_id,
+            'stagingLocation': 'gs://test/staging',
+            'jobName': JOB_NAME,
+            'region': TEST_LOCATION,
+            'output': 'gs://test/output',
+            'labels': {'foo': 'bar', 'airflow-version': 'v2-1-0-dev0'},
+        }
+        self.assertEqual(expected_variables, is_job_dataflow_running_variables)
+        job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
+        expected_variables["jobName"] = job_name
+        start_java_mock.assert_called_once_with(
+            variables=expected_variables,
+            jar=gcs_provide_file.return_value.__enter__.return_value.name,
             job_class=JOB_CLASS,
-            append_job_name=True,
-            multiple_jobs=None,
-            on_new_job_id_callback=mock.ANY,
-            project_id=None,
-            location=TEST_LOCATION,
+            process_line_callback=mock_callback_on_job_id.return_value,
         )
-        dataflow_running.assert_called_once_with(
-            name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION
+        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+            job_id=mock.ANY,
+            job_name=job_name,
+            location=TEST_LOCATION,
+            multiple_jobs=None,
         )
 
+    @mock.patch(
+        'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
+    )
+    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
-    def test_check_multiple_job_exec(self, gcs_hook, dataflow_mock):
+    def test_check_multiple_job_exec(
+        self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_on_job_id
+    ):
         """Test DataflowHook is created and the right args are passed to
-        start_java_workflow with option to check multiple jobs
-
+        start_java_workflow with option to check if job is running
         """
-        dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
+        is_job_dataflow_running_variables = None
+
+        def set_is_job_dataflow_running_variables(*args, **kwargs):
+            nonlocal is_job_dataflow_running_variables
+            is_job_dataflow_running_variables = copy.deepcopy(kwargs.get("variables"))
+
+        dataflow_running = dataflow_hook_mock.return_value.is_job_dataflow_running
+        dataflow_running.side_effect = set_is_job_dataflow_running_variables
         dataflow_running.return_value = False
-        start_java_hook = dataflow_mock.return_value.start_java_dataflow
+        start_java_mock = beam_hook_mock.return_value.start_java_pipeline
         gcs_provide_file = gcs_hook.return_value.provide_file
-        self.dataflow.multiple_jobs = True
         self.dataflow.check_if_running = True
+        self.dataflow.multiple_jobs = True
+
         self.dataflow.execute(None)
-        assert dataflow_mock.called
+
+        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback=mock.ANY)
         gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
-        start_java_hook.assert_called_once_with(
-            job_name=JOB_NAME,
-            variables=mock.ANY,
-            jar=mock.ANY,
+        expected_variables = {
+            'project': dataflow_hook_mock.return_value.project_id,
+            'stagingLocation': 'gs://test/staging',
+            'jobName': JOB_NAME,
+            'region': TEST_LOCATION,
+            'output': 'gs://test/output',
+            'labels': {'foo': 'bar', 'airflow-version': 'v2-1-0-dev0'},
+        }
+        self.assertEqual(expected_variables, is_job_dataflow_running_variables)
+        job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
+        expected_variables["jobName"] = job_name
+        start_java_mock.assert_called_once_with(
+            variables=expected_variables,
+            jar=gcs_provide_file.return_value.__enter__.return_value.name,
             job_class=JOB_CLASS,
-            append_job_name=True,
-            multiple_jobs=True,
-            on_new_job_id_callback=mock.ANY,
-            project_id=None,
-            location=TEST_LOCATION,
+            process_line_callback=mock_callback_on_job_id.return_value,
         )
-        dataflow_running.assert_called_once_with(
-            name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION
+        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+            job_id=mock.ANY,
+            job_name=job_name,
+            location=TEST_LOCATION,
+            multiple_jobs=True,
         )
 
 
diff --git a/tests/providers/google/cloud/operators/test_mlengine_utils.py b/tests/providers/google/cloud/operators/test_mlengine_utils.py
index 539ee60..c46fa62 100644
--- a/tests/providers/google/cloud/operators/test_mlengine_utils.py
+++ b/tests/providers/google/cloud/operators/test_mlengine_utils.py
@@ -106,9 +106,14 @@ class TestCreateEvaluateOps(unittest.TestCase):
             )
             assert success_message['predictionOutput'] == result
 
-        with patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') as mock_dataflow_hook:
-            hook_instance = mock_dataflow_hook.return_value
-            hook_instance.start_python_dataflow.return_value = None
+        with patch(
+            'airflow.providers.google.cloud.operators.dataflow.DataflowHook'
+        ) as mock_dataflow_hook, patch(
+            'airflow.providers.google.cloud.operators.dataflow.BeamHook'
+        ) as mock_beam_hook:
+            dataflow_hook_instance = mock_dataflow_hook.return_value
+            dataflow_hook_instance.start_python_dataflow.return_value = None
+            beam_hook_instance = mock_beam_hook.return_value
             summary.execute(None)
             mock_dataflow_hook.assert_called_once_with(
                 gcp_conn_id='google_cloud_default',
@@ -117,23 +122,28 @@ class TestCreateEvaluateOps(unittest.TestCase):
                 drain_pipeline=False,
                 cancel_timeout=600,
                 wait_until_finished=None,
+                impersonation_chain=None,
             )
-            hook_instance.start_python_dataflow.assert_called_once_with(
-                job_name='{{task.task_id}}',
+            mock_beam_hook.assert_called_once_with(runner="DataflowRunner")
+            beam_hook_instance.start_python_pipeline.assert_called_once_with(
                 variables={
                     'prediction_path': 'gs://legal-bucket/fake-output-path',
                     'labels': {'airflow-version': TEST_VERSION},
                     'metric_keys': 'err',
                     'metric_fn_encoded': self.metric_fn_encoded,
+                    'project': 'test-project',
+                    'region': 'us-central1',
+                    'job_name': mock.ANY,
                 },
-                dataflow=mock.ANY,
+                py_file=mock.ANY,
                 py_options=[],
-                py_requirements=['apache-beam[gcp]>=2.14.0'],
                 py_interpreter='python3',
+                py_requirements=['apache-beam[gcp]>=2.14.0'],
                 py_system_site_packages=False,
-                on_new_job_id_callback=ANY,
-                project_id='test-project',
-                location='us-central1',
+                process_line_callback=mock.ANY,
+            )
+            dataflow_hook_instance.wait_for_done.assert_called_once_with(
+                job_name=mock.ANY, location='us-central1', job_id=mock.ANY, multiple_jobs=False
             )
 
         with patch('airflow.providers.google.cloud.utils.mlengine_operator_utils.GCSHook') as mock_gcs_hook:
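
Read together, the operator and mlengine tests in this commit assert a two-step flow: submit through BeamHook, then poll through DataflowHook. The sketch below reconstructs that flow from the mock assertions alone; it is not the operator's actual source, and every concrete value in it is a placeholder:

    from airflow.providers.apache.beam.hooks.beam import BeamHook
    from airflow.providers.google.cloud.hooks.dataflow import (
        DataflowHook,
        process_line_and_extract_dataflow_job_id_callback,
    )

    beam_hook = BeamHook(runner="DataflowRunner")
    dataflow_hook = DataflowHook(gcp_conn_id="google_cloud_default")  # placeholder connection

    dataflow_job_id = None

    def set_current_job_id(job_id):
        # Remember the job id extracted from the runner's log output.
        global dataflow_job_id
        dataflow_job_id = job_id

    # Normalises the name and, per the tests above, appends a unique suffix by default.
    job_name = dataflow_hook.build_dataflow_job_name(job_name="example-job")
    variables = {
        "project": "example-project",                       # placeholder
        "region": "us-central1",
        "job_name": job_name,
        "staging_location": "gs://example-bucket/staging",  # placeholder
    }

    # Step 1: submit the pipeline through the Beam hook, streaming its output
    # through the job-id extraction callback.
    beam_hook.start_python_pipeline(
        variables=variables,
        py_file="/tmp/example_pipeline.py",                 # placeholder local path
        py_options=[],
        py_interpreter="python3",
        py_requirements=None,
        py_system_site_packages=False,
        process_line_callback=process_line_and_extract_dataflow_job_id_callback(set_current_job_id),
    )

    # Step 2: poll the Dataflow service until the submitted job reaches a terminal state.
    dataflow_hook.wait_for_done(
        job_id=dataflow_job_id,
        job_name=job_name,
        location="us-central1",
        multiple_jobs=False,
    )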


[airflow] 20/28: Support google-cloud-logging` >=2.0.0 (#13801)


commit 25f2db1c5dd043ee8bf9b9c7b2f09a505ce9f560
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Wed Feb 3 04:16:50 2021 +0100

    Support google-cloud-logging` >=2.0.0 (#13801)
    
    (cherry picked from commit 0e8c77b93a5ca5ecfdcd1c4bd91f54846fc15d57)
---
 airflow/providers/google/ADDITIONAL_INFO.md        |   1 +
 .../google/cloud/log/stackdriver_task_handler.py   |  72 +++++--
 setup.py                                           |   2 +-
 .../cloud/log/test_stackdriver_task_handler.py     | 225 +++++++++++++--------
 4 files changed, 200 insertions(+), 100 deletions(-)

diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md
index 9cf9853..a363051 100644
--- a/airflow/providers/google/ADDITIONAL_INFO.md
+++ b/airflow/providers/google/ADDITIONAL_INFO.md
@@ -34,6 +34,7 @@ Details are covered in the UPDATING.md files for each library, but there are som
 | [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
 | [``google-cloud-dataproc``](https://pypi.org/project/google-cloud-dataproc/) | ``>=1.0.1,<2.0.0`` | ``>=2.2.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-dataproc/blob/master/UPGRADING.md) |
 | [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
+| [``google-cloud-logging``](https://pypi.org/project/google-cloud-logging/) | ``>=1.14.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-logging/blob/master/UPGRADING.md) |
 | [``google-cloud-monitoring``](https://pypi.org/project/google-cloud-monitoring/) | ``>=0.34.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-monitoring/blob/master/UPGRADING.md) |
 | [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
 | [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
diff --git a/airflow/providers/google/cloud/log/stackdriver_task_handler.py b/airflow/providers/google/cloud/log/stackdriver_task_handler.py
index be75fcd..5479185 100644
--- a/airflow/providers/google/cloud/log/stackdriver_task_handler.py
+++ b/airflow/providers/google/cloud/log/stackdriver_task_handler.py
@@ -21,9 +21,12 @@ from urllib.parse import urlencode
 
 from cached_property import cached_property
 from google.api_core.gapic_v1.client_info import ClientInfo
+from google.auth.credentials import Credentials
 from google.cloud import logging as gcp_logging
+from google.cloud.logging import Resource
 from google.cloud.logging.handlers.transports import BackgroundThreadTransport, Transport
-from google.cloud.logging.resource import Resource
+from google.cloud.logging_v2.services.logging_service_v2 import LoggingServiceV2Client
+from google.cloud.logging_v2.types import ListLogEntriesRequest, ListLogEntriesResponse
 
 from airflow import version
 from airflow.models import TaskInstance
@@ -99,13 +102,19 @@ class StackdriverTaskHandler(logging.Handler):
         self.resource: Resource = resource
         self.labels: Optional[Dict[str, str]] = labels
         self.task_instance_labels: Optional[Dict[str, str]] = {}
+        self.task_instance_hostname = 'default-hostname'
 
     @cached_property
-    def _client(self) -> gcp_logging.Client:
-        """Google Cloud Library API client"""
+    def _credentials_and_project(self) -> Tuple[Credentials, str]:
         credentials, project = get_credentials_and_project_id(
             key_path=self.gcp_key_path, scopes=self.scopes, disable_logging=True
         )
+        return credentials, project
+
+    @property
+    def _client(self) -> gcp_logging.Client:
+        """The Cloud Library API client"""
+        credentials, project = self._credentials_and_project
         client = gcp_logging.Client(
             credentials=credentials,
             project=project,
@@ -113,6 +122,16 @@ class StackdriverTaskHandler(logging.Handler):
         )
         return client
 
+    @property
+    def _logging_service_client(self) -> LoggingServiceV2Client:
+        """The Cloud logging service v2 client."""
+        credentials, _ = self._credentials_and_project
+        client = LoggingServiceV2Client(
+            credentials=credentials,
+            client_info=ClientInfo(client_library_version='airflow_v' + version.version),
+        )
+        return client
+
     @cached_property
     def _transport(self) -> Transport:
         """Object responsible for sending data to Stackdriver"""
@@ -146,10 +165,11 @@ class StackdriverTaskHandler(logging.Handler):
         :type task_instance:  :class:`airflow.models.TaskInstance`
         """
         self.task_instance_labels = self._task_instance_to_labels(task_instance)
+        self.task_instance_hostname = task_instance.hostname
 
     def read(
         self, task_instance: TaskInstance, try_number: Optional[int] = None, metadata: Optional[Dict] = None
-    ) -> Tuple[List[str], List[Dict]]:
+    ) -> Tuple[List[Tuple[Tuple[str, str]]], List[Dict[str, str]]]:
         """
         Read logs of given task instance from Stackdriver logging.
 
@@ -160,12 +180,14 @@ class StackdriverTaskHandler(logging.Handler):
         :type try_number: Optional[int]
         :param metadata: log metadata. It is used for steaming log reading and auto-tailing.
         :type metadata: Dict
-        :return: a tuple of list of logs and list of metadata
-        :rtype: Tuple[List[str], List[Dict]]
+        :return: a tuple of (
+            list of (one element tuple with two element tuple - hostname and logs)
+            and list of metadata)
+        :rtype: Tuple[List[Tuple[Tuple[str, str]]], List[Dict[str, str]]]
         """
         if try_number is not None and try_number < 1:
-            logs = [f"Error fetching the logs. Try number {try_number} is invalid."]
-            return logs, [{"end_of_log": "true"}]
+            logs = f"Error fetching the logs. Try number {try_number} is invalid."
+            return [((self.task_instance_hostname, logs),)], [{"end_of_log": "true"}]
 
         if not metadata:
             metadata = {}
@@ -188,7 +210,7 @@ class StackdriverTaskHandler(logging.Handler):
         if next_page_token:
             new_metadata['next_page_token'] = next_page_token
 
-        return [messages], [new_metadata]
+        return [((self.task_instance_hostname, messages),)], [new_metadata]
 
     def _prepare_log_filter(self, ti_labels: Dict[str, str]) -> str:
         """
@@ -210,9 +232,10 @@ class StackdriverTaskHandler(logging.Handler):
             escaped_value = value.replace("\\", "\\\\").replace('"', '\\"')
             return f'"{escaped_value}"'
 
+        _, project = self._credentials_and_project
         log_filters = [
             f'resource.type={escale_label_value(self.resource.type)}',
-            f'logName="projects/{self._client.project}/logs/{self.name}"',
+            f'logName="projects/{project}/logs/{self.name}"',
         ]
 
         for key, value in self.resource.labels.items():
@@ -252,6 +275,8 @@ class StackdriverTaskHandler(logging.Handler):
                     log_filter=log_filter, page_token=next_page_token
                 )
                 messages.append(new_messages)
+                if not messages:
+                    break
 
             end_of_log = True
             next_page_token = None
@@ -271,15 +296,21 @@ class StackdriverTaskHandler(logging.Handler):
         :return: Downloaded logs and next page token
         :rtype: Tuple[str, str]
         """
-        entries = self._client.list_entries(filter_=log_filter, page_token=page_token)
-        page = next(entries.pages)
-        next_page_token = entries.next_page_token
+        _, project = self._credentials_and_project
+        request = ListLogEntriesRequest(
+            resource_names=[f'projects/{project}'],
+            filter=log_filter,
+            page_token=page_token,
+            order_by='timestamp asc',
+            page_size=1000,
+        )
+        response = self._logging_service_client.list_log_entries(request=request)
+        page: ListLogEntriesResponse = next(response.pages)
         messages = []
-        for entry in page:
-            if "message" in entry.payload:
-                messages.append(entry.payload["message"])
-
-        return "\n".join(messages), next_page_token
+        for entry in page.entries:
+            if "message" in entry.json_payload:
+                messages.append(entry.json_payload["message"])
+        return "\n".join(messages), page.next_page_token
 
     @classmethod
     def _task_instance_to_labels(cls, ti: TaskInstance) -> Dict[str, str]:
@@ -315,7 +346,7 @@ class StackdriverTaskHandler(logging.Handler):
         :return: URL to the external log collection service
         :rtype: str
         """
-        project_id = self._client.project
+        _, project_id = self._credentials_and_project
 
         ti_labels = self._task_instance_to_labels(task_instance)
         ti_labels[self.LABEL_TRY_NUMBER] = str(try_number)
@@ -331,3 +362,6 @@ class StackdriverTaskHandler(logging.Handler):
 
         url = f"{self.LOG_VIEWER_BASE_URL}?{urlencode(url_query_string)}"
         return url
+
+    def close(self) -> None:
+        self._transport.flush()
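
For code being ported to google-cloud-logging 2.x alongside this handler, the new read path reduces to the request/response shapes visible in the hunk above. A condensed sketch, assuming application-default credentials are available (the project id and filter are placeholders):

    from google.cloud.logging_v2.services.logging_service_v2 import LoggingServiceV2Client
    from google.cloud.logging_v2.types import ListLogEntriesRequest

    client = LoggingServiceV2Client()  # application-default credentials assumed here
    request = ListLogEntriesRequest(
        resource_names=["projects/example-project"],               # placeholder project
        filter='logName="projects/example-project/logs/airflow"',  # placeholder filter
        order_by="timestamp asc",
        page_size=1000,
    )
    response = client.list_log_entries(request=request)
    page = next(response.pages)  # first page of results
    messages = [
        entry.json_payload["message"]
        for entry in page.entries
        if "message" in entry.json_payload
    ]
    next_token = page.next_page_token  # feed back as page_token to fetch the next page
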
diff --git a/setup.py b/setup.py
index fa1e73a..7beb684 100644
--- a/setup.py
+++ b/setup.py
@@ -292,7 +292,7 @@ google = [
     'google-cloud-dlp>=0.11.0,<2.0.0',
     'google-cloud-kms>=2.0.0,<3.0.0',
     'google-cloud-language>=1.1.1,<2.0.0',
-    'google-cloud-logging>=1.14.0,<2.0.0',
+    'google-cloud-logging>=2.1.1,<3.0.0',
     'google-cloud-memcache>=0.2.0',
     'google-cloud-monitoring>=2.0.0,<3.0.0',
     'google-cloud-os-login>=2.0.0,<3.0.0',
diff --git a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
index 4159e9e..b4dbf69 100644
--- a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
+++ b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
@@ -21,7 +21,8 @@ from datetime import datetime
 from unittest import mock
 from urllib.parse import parse_qs, urlparse
 
-from google.cloud.logging.resource import Resource
+from google.cloud.logging import Resource
+from google.cloud.logging_v2.types import ListLogEntriesRequest, ListLogEntriesResponse, LogEntry
 
 from airflow.models import TaskInstance
 from airflow.models.dag import DAG
@@ -30,15 +31,27 @@ from airflow.providers.google.cloud.log.stackdriver_task_handler import Stackdri
 from airflow.utils.state import State
 
 
-def _create_list_response(messages, token):
-    page = [mock.MagicMock(payload={"message": message}) for message in messages]
-    return mock.MagicMock(pages=(n for n in [page]), next_page_token=token)
+def _create_list_log_entries_response_mock(messages, token):
+    return ListLogEntriesResponse(
+        entries=[LogEntry(json_payload={"message": message}) for message in messages], next_page_token=token
+    )
+
+
+def _remove_stackdriver_handlers():
+    for handler_ref in reversed(logging._handlerList[:]):
+        handler = handler_ref()
+        if not isinstance(handler, StackdriverTaskHandler):
+            continue
+        logging._removeHandlerRef(handler_ref)
+        del handler
 
 
 class TestStackdriverLoggingHandlerStandalone(unittest.TestCase):
     @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
     @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client')
     def test_should_pass_message_to_client(self, mock_client, mock_get_creds_and_project_id):
+        self.addCleanup(_remove_stackdriver_handlers)
+
         mock_get_creds_and_project_id.return_value = ('creds', 'project_id')
 
         transport_type = mock.MagicMock()
@@ -69,6 +82,7 @@ class TestStackdriverLoggingHandlerTask(unittest.TestCase):
         self.ti.try_number = 1
         self.ti.state = State.RUNNING
         self.addCleanup(self.dag.clear)
+        self.addCleanup(_remove_stackdriver_handlers)
 
     @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
     @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client')
@@ -118,107 +132,153 @@ class TestStackdriverLoggingHandlerTask(unittest.TestCase):
         )
 
     @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
-    @mock.patch(
-        'airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client',
-        **{'return_value.project': 'asf-project'},  # type: ignore
-    )
+    @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.LoggingServiceV2Client')
     def test_should_read_logs_for_all_try(self, mock_client, mock_get_creds_and_project_id):
-        mock_client.return_value.list_entries.return_value = _create_list_response(["MSG1", "MSG2"], None)
+        mock_client.return_value.list_log_entries.return_value.pages = iter(
+            [_create_list_log_entries_response_mock(["MSG1", "MSG2"], None)]
+        )
         mock_get_creds_and_project_id.return_value = ('creds', 'project_id')
 
         logs, metadata = self.stackdriver_task_handler.read(self.ti)
-        mock_client.return_value.list_entries.assert_called_once_with(
-            filter_='resource.type="global"\n'
-            'logName="projects/asf-project/logs/airflow"\n'
-            'labels.task_id="task_for_testing_file_log_handler"\n'
-            'labels.dag_id="dag_for_testing_file_task_handler"\n'
-            'labels.execution_date="2016-01-01T00:00:00+00:00"',
-            page_token=None,
+        mock_client.return_value.list_log_entries.assert_called_once_with(
+            request=ListLogEntriesRequest(
+                resource_names=["projects/project_id"],
+                filter=(
+                    'resource.type="global"\n'
+                    'logName="projects/project_id/logs/airflow"\n'
+                    'labels.task_id="task_for_testing_file_log_handler"\n'
+                    'labels.dag_id="dag_for_testing_file_task_handler"\n'
+                    'labels.execution_date="2016-01-01T00:00:00+00:00"'
+                ),
+                order_by='timestamp asc',
+                page_size=1000,
+                page_token=None,
+            )
         )
-        assert ['MSG1\nMSG2'] == logs
+        assert [(('default-hostname', 'MSG1\nMSG2'),)] == logs
         assert [{'end_of_log': True}] == metadata
 
     @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
-    @mock.patch(
-        'airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client',
-        **{'return_value.project': 'asf-project'},  # type: ignore
-    )
+    @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.LoggingServiceV2Client')
     def test_should_read_logs_for_task_with_quote(self, mock_client, mock_get_creds_and_project_id):
-        mock_client.return_value.list_entries.return_value = _create_list_response(["MSG1", "MSG2"], None)
+        mock_client.return_value.list_log_entries.return_value.pages = iter(
+            [_create_list_log_entries_response_mock(["MSG1", "MSG2"], None)]
+        )
         mock_get_creds_and_project_id.return_value = ('creds', 'project_id')
         self.ti.task_id = "K\"OT"
         logs, metadata = self.stackdriver_task_handler.read(self.ti)
-        mock_client.return_value.list_entries.assert_called_once_with(
-            filter_='resource.type="global"\n'
-            'logName="projects/asf-project/logs/airflow"\n'
-            'labels.task_id="K\\"OT"\n'
-            'labels.dag_id="dag_for_testing_file_task_handler"\n'
-            'labels.execution_date="2016-01-01T00:00:00+00:00"',
-            page_token=None,
+        mock_client.return_value.list_log_entries.assert_called_once_with(
+            request=ListLogEntriesRequest(
+                resource_names=["projects/project_id"],
+                filter=(
+                    'resource.type="global"\n'
+                    'logName="projects/project_id/logs/airflow"\n'
+                    'labels.task_id="K\\"OT"\n'
+                    'labels.dag_id="dag_for_testing_file_task_handler"\n'
+                    'labels.execution_date="2016-01-01T00:00:00+00:00"'
+                ),
+                order_by='timestamp asc',
+                page_size=1000,
+                page_token=None,
+            )
         )
-        assert ['MSG1\nMSG2'] == logs
+        assert [(('default-hostname', 'MSG1\nMSG2'),)] == logs
         assert [{'end_of_log': True}] == metadata
 
     @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
-    @mock.patch(
-        'airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client',
-        **{'return_value.project': 'asf-project'},  # type: ignore
-    )
+    @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.LoggingServiceV2Client')
     def test_should_read_logs_for_single_try(self, mock_client, mock_get_creds_and_project_id):
-        mock_client.return_value.list_entries.return_value = _create_list_response(["MSG1", "MSG2"], None)
+        mock_client.return_value.list_log_entries.return_value.pages = iter(
+            [_create_list_log_entries_response_mock(["MSG1", "MSG2"], None)]
+        )
         mock_get_creds_and_project_id.return_value = ('creds', 'project_id')
 
         logs, metadata = self.stackdriver_task_handler.read(self.ti, 3)
-        mock_client.return_value.list_entries.assert_called_once_with(
-            filter_='resource.type="global"\n'
-            'logName="projects/asf-project/logs/airflow"\n'
-            'labels.task_id="task_for_testing_file_log_handler"\n'
-            'labels.dag_id="dag_for_testing_file_task_handler"\n'
-            'labels.execution_date="2016-01-01T00:00:00+00:00"\n'
-            'labels.try_number="3"',
-            page_token=None,
+        mock_client.return_value.list_log_entries.assert_called_once_with(
+            request=ListLogEntriesRequest(
+                resource_names=["projects/project_id"],
+                filter=(
+                    'resource.type="global"\n'
+                    'logName="projects/project_id/logs/airflow"\n'
+                    'labels.task_id="task_for_testing_file_log_handler"\n'
+                    'labels.dag_id="dag_for_testing_file_task_handler"\n'
+                    'labels.execution_date="2016-01-01T00:00:00+00:00"\n'
+                    'labels.try_number="3"'
+                ),
+                order_by='timestamp asc',
+                page_size=1000,
+                page_token=None,
+            )
         )
-        assert ['MSG1\nMSG2'] == logs
+        assert [(('default-hostname', 'MSG1\nMSG2'),)] == logs
         assert [{'end_of_log': True}] == metadata
 
     @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
-    @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client')
+    @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.LoggingServiceV2Client')
     def test_should_read_logs_with_pagination(self, mock_client, mock_get_creds_and_project_id):
-        mock_client.return_value.list_entries.side_effect = [
-            _create_list_response(["MSG1", "MSG2"], "TOKEN1"),
-            _create_list_response(["MSG3", "MSG4"], None),
+        mock_client.return_value.list_log_entries.side_effect = [
+            mock.MagicMock(pages=iter([_create_list_log_entries_response_mock(["MSG1", "MSG2"], "TOKEN1")])),
+            mock.MagicMock(pages=iter([_create_list_log_entries_response_mock(["MSG3", "MSG4"], None)])),
         ]
         mock_get_creds_and_project_id.return_value = ('creds', 'project_id')
         logs, metadata1 = self.stackdriver_task_handler.read(self.ti, 3)
-        mock_client.return_value.list_entries.assert_called_once_with(filter_=mock.ANY, page_token=None)
-        assert ['MSG1\nMSG2'] == logs
+        mock_client.return_value.list_log_entries.assert_called_once_with(
+            request=ListLogEntriesRequest(
+                resource_names=["projects/project_id"],
+                filter=(
+                    '''resource.type="global"
+logName="projects/project_id/logs/airflow"
+labels.task_id="task_for_testing_file_log_handler"
+labels.dag_id="dag_for_testing_file_task_handler"
+labels.execution_date="2016-01-01T00:00:00+00:00"
+labels.try_number="3"'''
+                ),
+                order_by='timestamp asc',
+                page_size=1000,
+                page_token=None,
+            )
+        )
+        assert [(('default-hostname', 'MSG1\nMSG2'),)] == logs
         assert [{'end_of_log': False, 'next_page_token': 'TOKEN1'}] == metadata1
 
-        mock_client.return_value.list_entries.return_value.next_page_token = None
+        mock_client.return_value.list_log_entries.return_value.next_page_token = None
         logs, metadata2 = self.stackdriver_task_handler.read(self.ti, 3, metadata1[0])
-        mock_client.return_value.list_entries.assert_called_with(filter_=mock.ANY, page_token="TOKEN1")
-        assert ['MSG3\nMSG4'] == logs
+
+        mock_client.return_value.list_log_entries.assert_called_with(
+            request=ListLogEntriesRequest(
+                resource_names=["projects/project_id"],
+                filter=(
+                    'resource.type="global"\n'
+                    'logName="projects/project_id/logs/airflow"\n'
+                    'labels.task_id="task_for_testing_file_log_handler"\n'
+                    'labels.dag_id="dag_for_testing_file_task_handler"\n'
+                    'labels.execution_date="2016-01-01T00:00:00+00:00"\n'
+                    'labels.try_number="3"'
+                ),
+                order_by='timestamp asc',
+                page_size=1000,
+                page_token="TOKEN1",
+            )
+        )
+        assert [(('default-hostname', 'MSG3\nMSG4'),)] == logs
         assert [{'end_of_log': True}] == metadata2
 
     @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
-    @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client')
+    @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.LoggingServiceV2Client')
     def test_should_read_logs_with_download(self, mock_client, mock_get_creds_and_project_id):
-        mock_client.return_value.list_entries.side_effect = [
-            _create_list_response(["MSG1", "MSG2"], "TOKEN1"),
-            _create_list_response(["MSG3", "MSG4"], None),
+        mock_client.return_value.list_log_entries.side_effect = [
+            mock.MagicMock(pages=iter([_create_list_log_entries_response_mock(["MSG1", "MSG2"], "TOKEN1")])),
+            mock.MagicMock(pages=iter([_create_list_log_entries_response_mock(["MSG3", "MSG4"], None)])),
         ]
         mock_get_creds_and_project_id.return_value = ('creds', 'project_id')
 
         logs, metadata1 = self.stackdriver_task_handler.read(self.ti, 3, {'download_logs': True})
 
-        assert ['MSG1\nMSG2\nMSG3\nMSG4'] == logs
+        assert [(('default-hostname', 'MSG1\nMSG2\nMSG3\nMSG4'),)] == logs
         assert [{'end_of_log': True}] == metadata1
 
     @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
-    @mock.patch(
-        'airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client',
-        **{'return_value.project': 'asf-project'},  # type: ignore
-    )
+    @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.LoggingServiceV2Client')
     def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_creds_and_project_id):
         mock_get_creds_and_project_id.return_value = ('creds', 'project_id')
         resource = Resource(
@@ -226,31 +286,37 @@ class TestStackdriverLoggingHandlerTask(unittest.TestCase):
             labels={
                 "environment.name": 'test-instancce',
                 "location": 'europpe-west-3',
-                "project_id": "asf-project",
+                "project_id": "project_id",
             },
         )
         self.stackdriver_task_handler = StackdriverTaskHandler(
             transport=self.transport_mock, resource=resource
         )
 
-        entry = mock.MagicMock(payload={"message": "TEXT"})
-        page = [entry, entry]
-        mock_client.return_value.list_entries.return_value.pages = (n for n in [page])
-        mock_client.return_value.list_entries.return_value.next_page_token = None
+        entry = mock.MagicMock(json_payload={"message": "TEXT"})
+        page = mock.MagicMock(entries=[entry, entry], next_page_token=None)
+        mock_client.return_value.list_log_entries.return_value.pages = (n for n in [page])
 
         logs, metadata = self.stackdriver_task_handler.read(self.ti)
-        mock_client.return_value.list_entries.assert_called_once_with(
-            filter_='resource.type="cloud_composer_environment"\n'
-            'logName="projects/asf-project/logs/airflow"\n'
-            'resource.labels."environment.name"="test-instancce"\n'
-            'resource.labels.location="europpe-west-3"\n'
-            'resource.labels.project_id="asf-project"\n'
-            'labels.task_id="task_for_testing_file_log_handler"\n'
-            'labels.dag_id="dag_for_testing_file_task_handler"\n'
-            'labels.execution_date="2016-01-01T00:00:00+00:00"',
-            page_token=None,
+        mock_client.return_value.list_log_entries.assert_called_once_with(
+            request=ListLogEntriesRequest(
+                resource_names=["projects/project_id"],
+                filter=(
+                    'resource.type="cloud_composer_environment"\n'
+                    'logName="projects/project_id/logs/airflow"\n'
+                    'resource.labels."environment.name"="test-instancce"\n'
+                    'resource.labels.location="europpe-west-3"\n'
+                    'resource.labels.project_id="project_id"\n'
+                    'labels.task_id="task_for_testing_file_log_handler"\n'
+                    'labels.dag_id="dag_for_testing_file_task_handler"\n'
+                    'labels.execution_date="2016-01-01T00:00:00+00:00"'
+                ),
+                order_by='timestamp asc',
+                page_size=1000,
+                page_token=None,
+            )
         )
-        assert ['TEXT\nTEXT'] == logs
+        assert [(('default-hostname', 'TEXT\nTEXT'),)] == logs
         assert [{'end_of_log': True}] == metadata
 
     @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@@ -278,10 +344,9 @@ class TestStackdriverLoggingHandlerTask(unittest.TestCase):
         assert mock_client.return_value == client
 
     @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
-    @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client')
+    @mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.LoggingServiceV2Client')
     def test_should_return_valid_external_url(self, mock_client, mock_get_creds_and_project_id):
         mock_get_creds_and_project_id.return_value = ('creds', 'project_id')
-        mock_client.return_value.project = 'project_id'
 
         stackdriver_task_handler = StackdriverTaskHandler(
             gcp_key_path="KEY_PATH",

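For reference, the google-cloud-logging>=2.0 query path that the rewritten tests above mock can be exercised directly. The following is a minimal sketch, not part of the patch: it assumes application default credentials and uses placeholder project and filter values; the 2.x client takes a single ListLogEntriesRequest instead of the old filter_/page_token keyword arguments.

    from google.cloud.logging_v2.services.logging_service_v2 import LoggingServiceV2Client
    from google.cloud.logging_v2.types import ListLogEntriesRequest

    client = LoggingServiceV2Client()  # uses application default credentials
    request = ListLogEntriesRequest(
        resource_names=["projects/project_id"],  # placeholder project
        filter='resource.type="global"\nlogName="projects/project_id/logs/airflow"',
        order_by="timestamp asc",
        page_size=1000,
        page_token=None,
    )
    # The client returns a pager; each page is a ListLogEntriesResponse whose
    # entries carry the structured payload under json_payload.
    for page in client.list_log_entries(request=request).pages:
        for entry in page.entries:
            print(entry.json_payload["message"])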

[airflow] 07/28: Update compatibility with google-cloud-os-login>=2.0.0 (#13126)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit cfd5a480c2bbec7eab361516ce7f1eed34896349
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Thu Dec 17 11:00:59 2020 +0100

    Update compatibility with google-cloud-os-login>=2.0.0 (#13126)
    
    (cherry picked from commit 1259c712a42d69135dc389de88f79942c70079a3)
---
 airflow/providers/google/cloud/hooks/os_login.py   | 16 +++++++++-------
 setup.py                                           |  2 +-
 .../providers/google/cloud/hooks/test_os_login.py  | 22 ++++++++++++----------
 3 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/os_login.py b/airflow/providers/google/cloud/hooks/os_login.py
index c7a4234..361ea60 100644
--- a/airflow/providers/google/cloud/hooks/os_login.py
+++ b/airflow/providers/google/cloud/hooks/os_login.py
@@ -17,7 +17,7 @@
 
 from typing import Dict, Optional, Sequence, Union
 
-from google.cloud.oslogin_v1 import OsLoginServiceClient
+from google.cloud.oslogin_v1 import ImportSshPublicKeyResponse, OsLoginServiceClient
 
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 
@@ -54,7 +54,7 @@ class OSLoginHook(GoogleBaseHook):
     @GoogleBaseHook.fallback_to_default_project_id
     def import_ssh_public_key(
         self, user: str, ssh_public_key: Dict, project_id: str, retry=None, timeout=None, metadata=None
-    ):
+    ) -> ImportSshPublicKeyResponse:
         """
         Adds an SSH public key and returns the profile information. Default POSIX
         account information is set when no username and UID exist as part of the
@@ -74,14 +74,16 @@ class OSLoginHook(GoogleBaseHook):
         :type timeout: Optional[float]
         :param metadata: Additional metadata that is provided to the method.
         :type metadata: Optional[Sequence[Tuple[str, str]]]
-        :return:  A :class:`~google.cloud.oslogin_v1.types.ImportSshPublicKeyResponse` instance.
+        :return: A :class:`~google.cloud.oslogin_v1.ImportSshPublicKeyResponse` instance.
         """
         conn = self.get_conn()
         return conn.import_ssh_public_key(
-            parent=OsLoginServiceClient.user_path(user=user),
-            ssh_public_key=ssh_public_key,
-            project_id=project_id,
+            request=dict(
+                parent=f"users/{user}",
+                ssh_public_key=ssh_public_key,
+                project_id=project_id,
+            ),
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
diff --git a/setup.py b/setup.py
index 7071795..0586bf3 100644
--- a/setup.py
+++ b/setup.py
@@ -295,7 +295,7 @@ google = [
     'google-cloud-logging>=1.14.0,<2.0.0',
     'google-cloud-memcache>=0.2.0',
     'google-cloud-monitoring>=0.34.0,<2.0.0',
-    'google-cloud-os-login>=1.0.0,<2.0.0',
+    'google-cloud-os-login>=2.0.0,<3.0.0',
     'google-cloud-pubsub>=1.0.0,<2.0.0',
     'google-cloud-redis>=0.3.0,<2.0.0',
     'google-cloud-secret-manager>=0.2.0,<2.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_os_login.py b/tests/providers/google/cloud/hooks/test_os_login.py
index 303f1ea..d2b88e4 100644
--- a/tests/providers/google/cloud/hooks/test_os_login.py
+++ b/tests/providers/google/cloud/hooks/test_os_login.py
@@ -38,7 +38,7 @@ TEST_CREDENTIALS = mock.MagicMock()
 TEST_BODY: Dict = mock.MagicMock()
 TEST_RETRY: Retry = mock.MagicMock()
 TEST_TIMEOUT: float = 4
-TEST_METADATA: Sequence[Tuple[str, str]] = []
+TEST_METADATA: Sequence[Tuple[str, str]] = ()
 TEST_PARENT: str = "users/test-user"
 
 
@@ -67,9 +67,11 @@ class TestOSLoginHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.import_ssh_public_key.assert_called_once_with(
-            parent=TEST_PARENT,
-            ssh_public_key=TEST_BODY,
-            project_id=TEST_PROJECT_ID,
+            request=dict(
+                parent=TEST_PARENT,
+                ssh_public_key=TEST_BODY,
+                project_id=TEST_PROJECT_ID,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -101,9 +103,11 @@ class TestOSLoginHookWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.import_ssh_public_key.assert_called_once_with(
-            parent=TEST_PARENT,
-            ssh_public_key=TEST_BODY,
-            project_id=TEST_PROJECT_ID_2,
+            request=dict(
+                parent=TEST_PARENT,
+                ssh_public_key=TEST_BODY,
+                project_id=TEST_PROJECT_ID_2,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -135,9 +139,7 @@ class TestOSLoginHookWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.import_ssh_public_key.assert_called_once_with(
-            parent=TEST_PARENT,
-            ssh_public_key=TEST_BODY,
-            project_id=TEST_PROJECT_ID,
+            request=dict(parent=TEST_PARENT, ssh_public_key=TEST_BODY, project_id=TEST_PROJECT_ID),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,

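For reference, the request-based calling convention that google-cloud-os-login>=2.0 expects, and that the hook change above adopts, looks like this outside Airflow. A minimal sketch, not part of the patch, with placeholder user, key, and project values; it assumes application default credentials.

    from google.cloud.oslogin_v1 import OsLoginServiceClient

    client = OsLoginServiceClient()  # uses application default credentials
    response = client.import_ssh_public_key(
        request={
            "parent": "users/jane@example.com",  # placeholder user, same "users/{user}" form as the hook
            "ssh_public_key": {"key": "ssh-rsa AAAA... jane@example.com"},
            "project_id": "example-project",  # placeholder project
        },
        metadata=(),  # the hook now defaults to an empty tuple rather than None
    )
    print(response.login_profile.name)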

[airflow] 15/28: Support google-cloud-datacatalog>=3.0.0 (#13534)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 9ed976e3cd2e1348bb5f8695306820a711578aee
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Mon Jan 11 09:39:19 2021 +0100

    Support google-cloud-datacatalog>=3.0.0 (#13534)
    
    (cherry picked from commit 947dbb73bba736eb146f33117545a18fc2fd3c09)
---
 airflow/providers/google/ADDITIONAL_INFO.md        |   2 +-
 .../cloud/example_dags/example_datacatalog.py      |  10 +-
 .../providers/google/cloud/hooks/datacatalog.py    | 220 ++++++++++++-------
 .../google/cloud/operators/datacatalog.py          |  47 ++--
 setup.py                                           |   2 +-
 .../google/cloud/hooks/test_datacatalog.py         | 237 +++++++++++++--------
 .../google/cloud/operators/test_datacatalog.py     |  49 +++--
 7 files changed, 357 insertions(+), 210 deletions(-)

diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md
index eca05df..d80f9e1 100644
--- a/airflow/providers/google/ADDITIONAL_INFO.md
+++ b/airflow/providers/google/ADDITIONAL_INFO.md
@@ -30,7 +30,7 @@ Details are covered in the UPDATING.md files for each library, but there are som
 | Library name | Previous constraints | Current constraints | |
 | --- | --- | --- | --- |
 | [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
-| [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=1.0.0,<2.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
+| [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
 | [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
 | [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
 | [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
diff --git a/airflow/providers/google/cloud/example_dags/example_datacatalog.py b/airflow/providers/google/cloud/example_dags/example_datacatalog.py
index c8597a6..cc4b73a 100644
--- a/airflow/providers/google/cloud/example_dags/example_datacatalog.py
+++ b/airflow/providers/google/cloud/example_dags/example_datacatalog.py
@@ -19,7 +19,7 @@
 """
 Example Airflow DAG that interacts with Google Data Catalog service
 """
-from google.cloud.datacatalog_v1beta1.proto.tags_pb2 import FieldType, TagField, TagTemplateField
+from google.cloud.datacatalog_v1beta1 import FieldType, TagField, TagTemplateField
 
 from airflow import models
 from airflow.operators.bash_operator import BashOperator
@@ -91,7 +91,7 @@ with models.DAG("example_gcp_datacatalog", start_date=days_ago(1), schedule_inte
         entry_id=ENTRY_ID,
         entry={
             "display_name": "Wizard",
-            "type": "FILESET",
+            "type_": "FILESET",
             "gcs_fileset_spec": {"file_patterns": ["gs://test-datacatalog/**"]},
         },
     )
@@ -144,7 +144,7 @@ with models.DAG("example_gcp_datacatalog", start_date=days_ago(1), schedule_inte
             "display_name": "Awesome Tag Template",
             "fields": {
                 FIELD_NAME_1: TagTemplateField(
-                    display_name="first-field", type=FieldType(primitive_type="STRING")
+                    display_name="first-field", type_=dict(primitive_type="STRING")
                 )
             },
         },
@@ -172,7 +172,7 @@ with models.DAG("example_gcp_datacatalog", start_date=days_ago(1), schedule_inte
         tag_template=TEMPLATE_ID,
         tag_template_field_id=FIELD_NAME_2,
         tag_template_field=TagTemplateField(
-            display_name="second-field", type=FieldType(primitive_type="STRING")
+            display_name="second-field", type_=FieldType(primitive_type="STRING")
         ),
     )
     # [END howto_operator_gcp_datacatalog_create_tag_template_field]
@@ -305,7 +305,7 @@ with models.DAG("example_gcp_datacatalog", start_date=days_ago(1), schedule_inte
     # [START howto_operator_gcp_datacatalog_lookup_entry_result]
     lookup_entry_result = BashOperator(
         task_id="lookup_entry_result",
-        bash_command="echo \"{{ task_instance.xcom_pull('lookup_entry')['displayName'] }}\"",
+        bash_command="echo \"{{ task_instance.xcom_pull('lookup_entry')['display_name'] }}\"",
     )
     # [END howto_operator_gcp_datacatalog_lookup_entry_result]
 
diff --git a/airflow/providers/google/cloud/hooks/datacatalog.py b/airflow/providers/google/cloud/hooks/datacatalog.py
index 70b488d..0d6cc75 100644
--- a/airflow/providers/google/cloud/hooks/datacatalog.py
+++ b/airflow/providers/google/cloud/hooks/datacatalog.py
@@ -18,16 +18,18 @@
 from typing import Dict, Optional, Sequence, Tuple, Union
 
 from google.api_core.retry import Retry
-from google.cloud.datacatalog_v1beta1 import DataCatalogClient
-from google.cloud.datacatalog_v1beta1.types import (
+from google.cloud import datacatalog
+from google.cloud.datacatalog_v1beta1 import (
+    CreateTagRequest,
+    DataCatalogClient,
     Entry,
     EntryGroup,
-    FieldMask,
     SearchCatalogRequest,
     Tag,
     TagTemplate,
     TagTemplateField,
 )
+from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow import AirflowException
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
@@ -115,10 +117,13 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        parent = DataCatalogClient.entry_group_path(project_id, location, entry_group)
+        parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}"
         self.log.info('Creating a new entry: parent=%s', parent)
         result = client.create_entry(
-            parent=parent, entry_id=entry_id, entry=entry, retry=retry, timeout=timeout, metadata=metadata
+            request={'parent': parent, 'entry_id': entry_id, 'entry': entry},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
         )
         self.log.info('Created a entry: name=%s', result.name)
         return result
@@ -161,16 +166,14 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        parent = DataCatalogClient.location_path(project_id, location)
+        parent = f"projects/{project_id}/locations/{location}"
         self.log.info('Creating a new entry group: parent=%s', parent)
 
         result = client.create_entry_group(
-            parent=parent,
-            entry_group_id=entry_group_id,
-            entry_group=entry_group,
+            request={'parent': parent, 'entry_group_id': entry_group_id, 'entry_group': entry_group},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         self.log.info('Created a entry group: name=%s', result.name)
 
@@ -218,15 +221,34 @@ class CloudDataCatalogHook(GoogleBaseHook):
         """
         client = self.get_conn()
         if template_id:
-            template_path = DataCatalogClient.tag_template_path(project_id, location, template_id)
+            template_path = f"projects/{project_id}/locations/{location}/tagTemplates/{template_id}"
             if isinstance(tag, Tag):
                 tag.template = template_path
             else:
                 tag["template"] = template_path
-        parent = DataCatalogClient.entry_path(project_id, location, entry_group, entry)
+        parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}"
 
         self.log.info('Creating a new tag: parent=%s', parent)
-        result = client.create_tag(parent=parent, tag=tag, retry=retry, timeout=timeout, metadata=metadata)
+        # HACK: google-cloud-datacatalog has problems with mapping messages where the value is not a
+        # primitive type, so we need to convert it manually.
+        # See: https://github.com/googleapis/python-datacatalog/issues/84
+        if isinstance(tag, dict):
+            tag = Tag(
+                name=tag.get('name'),
+                template=tag.get('template'),
+                template_display_name=tag.get('template_display_name'),
+                column=tag.get('column'),
+                fields={
+                    k: datacatalog.TagField(**v) if isinstance(v, dict) else v
+                    for k, v in tag.get("fields", {}).items()
+                },
+            )
+        request = CreateTagRequest(
+            parent=parent,
+            tag=tag,
+        )
+
+        result = client.create_tag(request=request, retry=retry, timeout=timeout, metadata=metadata or ())
         self.log.info('Created a tag: name=%s', result.name)
 
         return result
@@ -267,17 +289,30 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        parent = DataCatalogClient.location_path(project_id, location)
+        parent = f"projects/{project_id}/locations/{location}"
 
         self.log.info('Creating a new tag template: parent=%s', parent)
+        # HACK: google-cloud-datacatalog has problems with mapping messages where the value is not a
+        # primitive type, so we need to convert it manually.
+        # See: https://github.com/googleapis/python-datacatalog/issues/84
+        if isinstance(tag_template, dict):
+            tag_template = datacatalog.TagTemplate(
+                name=tag_template.get("name"),
+                display_name=tag_template.get("display_name"),
+                fields={
+                    k: datacatalog.TagTemplateField(**v) if isinstance(v, dict) else v
+                    for k, v in tag_template.get("fields", {}).items()
+                },
+            )
 
+        request = datacatalog.CreateTagTemplateRequest(
+            parent=parent, tag_template_id=tag_template_id, tag_template=tag_template
+        )
         result = client.create_tag_template(
-            parent=parent,
-            tag_template_id=tag_template_id,
-            tag_template=tag_template,
+            request=request,
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         self.log.info('Created a tag template: name=%s', result.name)
 
@@ -325,17 +360,19 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        parent = DataCatalogClient.tag_template_path(project_id, location, tag_template)
+        parent = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}"
 
         self.log.info('Creating a new tag template field: parent=%s', parent)
 
         result = client.create_tag_template_field(
-            parent=parent,
-            tag_template_field_id=tag_template_field_id,
-            tag_template_field=tag_template_field,
+            request={
+                'parent': parent,
+                'tag_template_field_id': tag_template_field_id,
+                'tag_template_field': tag_template_field,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
         self.log.info('Created a tag template field: name=%s', result.name)
@@ -375,9 +412,9 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = DataCatalogClient.entry_path(project_id, location, entry_group, entry)
+        name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}"
         self.log.info('Deleting a entry: name=%s', name)
-        client.delete_entry(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        client.delete_entry(request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ())
         self.log.info('Deleted a entry: name=%s', name)
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -412,10 +449,12 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = DataCatalogClient.entry_group_path(project_id, location, entry_group)
+        name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}"
 
         self.log.info('Deleting a entry group: name=%s', name)
-        client.delete_entry_group(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        client.delete_entry_group(
+            request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
         self.log.info('Deleted a entry group: name=%s', name)
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -454,10 +493,12 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = DataCatalogClient.tag_path(project_id, location, entry_group, entry, tag)
+        name = (
+            f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}/tags/{tag}"
+        )
 
         self.log.info('Deleting a tag: name=%s', name)
-        client.delete_tag(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        client.delete_tag(request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ())
         self.log.info('Deleted a tag: name=%s', name)
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -495,10 +536,12 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = DataCatalogClient.tag_template_path(project_id, location, tag_template)
+        name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}"
 
         self.log.info('Deleting a tag template: name=%s', name)
-        client.delete_tag_template(name=name, force=force, retry=retry, timeout=timeout, metadata=metadata)
+        client.delete_tag_template(
+            request={'name': name, 'force': force}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
         self.log.info('Deleted a tag template: name=%s', name)
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -537,11 +580,11 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = DataCatalogClient.tag_template_field_path(project_id, location, tag_template, field)
+        name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}/fields/{field}"
 
         self.log.info('Deleting a tag template field: name=%s', name)
         client.delete_tag_template_field(
-            name=name, force=force, retry=retry, timeout=timeout, metadata=metadata
+            request={'name': name, 'force': force}, retry=retry, timeout=timeout, metadata=metadata or ()
         )
         self.log.info('Deleted a tag template field: name=%s', name)
 
@@ -578,10 +621,12 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = DataCatalogClient.entry_path(project_id, location, entry_group, entry)
+        name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}"
 
         self.log.info('Getting a entry: name=%s', name)
-        result = client.get_entry(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        result = client.get_entry(
+            request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
         self.log.info('Received a entry: name=%s', result.name)
 
         return result
@@ -607,8 +652,8 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :param read_mask: The fields to return. If not set or empty, all fields are returned.
 
             If a dict is provided, it must be of the same form as the protobuf message
-            :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
-        :type read_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+            :class:`~google.protobuf.field_mask_pb2.FieldMask`
+        :type read_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
         :param project_id: The ID of the Google Cloud project that owns the entry group.
             If set to ``None`` or missing, the default project_id from the Google Cloud connection is used.
         :type project_id: str
@@ -622,12 +667,15 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = DataCatalogClient.entry_group_path(project_id, location, entry_group)
+        name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}"
 
         self.log.info('Getting a entry group: name=%s', name)
 
         result = client.get_entry_group(
-            name=name, read_mask=read_mask, retry=retry, timeout=timeout, metadata=metadata
+            request={'name': name, 'read_mask': read_mask},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
         )
 
         self.log.info('Received a entry group: name=%s', result.name)
@@ -664,11 +712,13 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = DataCatalogClient.tag_template_path(project_id, location, tag_template)
+        name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}"
 
         self.log.info('Getting a tag template: name=%s', name)
 
-        result = client.get_tag_template(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        result = client.get_tag_template(
+            request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
 
         self.log.info('Received a tag template: name=%s', result.name)
 
@@ -712,12 +762,15 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        parent = DataCatalogClient.entry_path(project_id, location, entry_group, entry)
+        parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}"
 
         self.log.info('Listing tag on entry: entry_name=%s', parent)
 
         result = client.list_tags(
-            parent=parent, page_size=page_size, retry=retry, timeout=timeout, metadata=metadata
+            request={'parent': parent, 'page_size': page_size},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
         )
 
         self.log.info('Received tags.')
@@ -811,12 +864,18 @@ class CloudDataCatalogHook(GoogleBaseHook):
         if linked_resource:
             self.log.info('Getting entry: linked_resource=%s', linked_resource)
             result = client.lookup_entry(
-                linked_resource=linked_resource, retry=retry, timeout=timeout, metadata=metadata
+                request={'linked_resource': linked_resource},
+                retry=retry,
+                timeout=timeout,
+                metadata=metadata or (),
             )
         else:
             self.log.info('Getting entry: sql_resource=%s', sql_resource)
             result = client.lookup_entry(
-                sql_resource=sql_resource, retry=retry, timeout=timeout, metadata=metadata
+                request={'sql_resource': sql_resource},
+                retry=retry,
+                timeout=timeout,
+                metadata=metadata or (),
             )
         self.log.info('Received entry. name=%s', result.name)
 
@@ -860,18 +919,17 @@ class CloudDataCatalogHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = DataCatalogClient.tag_template_field_path(project_id, location, tag_template, field)
+        name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}/fields/{field}"
 
         self.log.info(
             'Renaming field: old_name=%s, new_tag_template_field_id=%s', name, new_tag_template_field_id
         )
 
         result = client.rename_tag_template_field(
-            name=name,
-            new_tag_template_field_id=new_tag_template_field_id,
+            request={'name': name, 'new_tag_template_field_id': new_tag_template_field_id},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
         self.log.info('Renamed tag template field.')
@@ -946,13 +1004,10 @@ class CloudDataCatalogHook(GoogleBaseHook):
             order_by,
         )
         result = client.search_catalog(
-            scope=scope,
-            query=query,
-            page_size=page_size,
-            order_by=order_by,
+            request={'scope': scope, 'query': query, 'page_size': page_size, 'order_by': order_by},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
         self.log.info('Received items.')
@@ -984,8 +1039,8 @@ class CloudDataCatalogHook(GoogleBaseHook):
             updated.
 
             If a dict is provided, it must be of the same form as the protobuf message
-            :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
-        :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+            :class:`~google.protobuf.field_mask_pb2.FieldMask`
+        :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
         :param location: Required. The location of the entry to update.
         :type location: str
         :param entry_group: The entry group ID for the entry that is being updated.
@@ -1006,7 +1061,9 @@ class CloudDataCatalogHook(GoogleBaseHook):
         """
         client = self.get_conn()
         if project_id and location and entry_group and entry_id:
-            full_entry_name = DataCatalogClient.entry_path(project_id, location, entry_group, entry_id)
+            full_entry_name = (
+                f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry_id}"
+            )
             if isinstance(entry, Entry):
                 entry.name = full_entry_name
             elif isinstance(entry, dict):
@@ -1025,7 +1082,10 @@ class CloudDataCatalogHook(GoogleBaseHook):
         if isinstance(entry, dict):
             entry = Entry(**entry)
         result = client.update_entry(
-            entry=entry, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata
+            request={'entry': entry, 'update_mask': update_mask},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
         )
 
         self.log.info('Updated entry.')
@@ -1059,7 +1119,7 @@ class CloudDataCatalogHook(GoogleBaseHook):
 
             If a dict is provided, it must be of the same form as the protobuf message
             :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
-        :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+        :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
         :param location: Required. The location of the tag to rename.
         :type location: str
         :param entry_group: The entry group ID for the tag that is being updated.
@@ -1082,7 +1142,10 @@ class CloudDataCatalogHook(GoogleBaseHook):
         """
         client = self.get_conn()
         if project_id and location and entry_group and entry and tag_id:
-            full_tag_name = DataCatalogClient.tag_path(project_id, location, entry_group, entry, tag_id)
+            full_tag_name = (
+                f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}"
+                f"/tags/{tag_id}"
+            )
             if isinstance(tag, Tag):
                 tag.name = full_tag_name
             elif isinstance(tag, dict):
@@ -1102,7 +1165,10 @@ class CloudDataCatalogHook(GoogleBaseHook):
         if isinstance(tag, dict):
             tag = Tag(**tag)
         result = client.update_tag(
-            tag=tag, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata
+            request={'tag': tag, 'update_mask': update_mask},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
         )
         self.log.info('Updated tag.')
 
@@ -1137,8 +1203,8 @@ class CloudDataCatalogHook(GoogleBaseHook):
             If absent or empty, all of the allowed fields above will be updated.
 
             If a dict is provided, it must be of the same form as the protobuf message
-            :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
-        :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+            :class:`~google.protobuf.field_mask_pb2.FieldMask`
+        :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
         :param location: Required. The location of the tag template to rename.
         :type location: str
         :param tag_template_id: Optional. The tag template ID for the entry that is being updated.
@@ -1157,8 +1223,8 @@ class CloudDataCatalogHook(GoogleBaseHook):
         """
         client = self.get_conn()
         if project_id and location and tag_template:
-            full_tag_template_name = DataCatalogClient.tag_template_path(
-                project_id, location, tag_template_id
+            full_tag_template_name = (
+                f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template_id}"
             )
             if isinstance(tag_template, TagTemplate):
                 tag_template.name = full_tag_template_name
@@ -1179,11 +1245,10 @@ class CloudDataCatalogHook(GoogleBaseHook):
         if isinstance(tag_template, dict):
             tag_template = TagTemplate(**tag_template)
         result = client.update_tag_template(
-            tag_template=tag_template,
-            update_mask=update_mask,
+            request={'tag_template': tag_template, 'update_mask': update_mask},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         self.log.info('Updated tag template.')
 
@@ -1222,8 +1287,8 @@ class CloudDataCatalogHook(GoogleBaseHook):
             Therefore, enum values can only be added, existing enum values cannot be deleted nor renamed.
 
             If a dict is provided, it must be of the same form as the protobuf message
-            :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
-        :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+            :class:`~google.protobuf.field_mask_pb2.FieldMask`
+        :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
         :param tag_template_field_name: Optional. The name of the tag template field to rename.
         :type tag_template_field_name: str
         :param location: Optional. The location of the tag to rename.
@@ -1246,19 +1311,22 @@ class CloudDataCatalogHook(GoogleBaseHook):
         """
         client = self.get_conn()
         if project_id and location and tag_template and tag_template_field_id:
-            tag_template_field_name = DataCatalogClient.tag_template_field_path(
-                project_id, location, tag_template, tag_template_field_id
+            tag_template_field_name = (
+                f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}"
+                f"/fields/{tag_template_field_id}"
             )
 
         self.log.info("Updating tag template field: name=%s", tag_template_field_name)
 
         result = client.update_tag_template_field(
-            name=tag_template_field_name,
-            tag_template_field=tag_template_field,
-            update_mask=update_mask,
+            request={
+                'name': tag_template_field_name,
+                'tag_template_field': tag_template_field,
+                'update_mask': update_mask,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         self.log.info('Updated tag template field.')
 
diff --git a/airflow/providers/google/cloud/operators/datacatalog.py b/airflow/providers/google/cloud/operators/datacatalog.py
index 00b2765..4b0da05 100644
--- a/airflow/providers/google/cloud/operators/datacatalog.py
+++ b/airflow/providers/google/cloud/operators/datacatalog.py
@@ -19,17 +19,16 @@ from typing import Dict, Optional, Sequence, Tuple, Union
 
 from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.retry import Retry
-from google.cloud.datacatalog_v1beta1 import DataCatalogClient
+from google.cloud.datacatalog_v1beta1 import DataCatalogClient, SearchCatalogResult
 from google.cloud.datacatalog_v1beta1.types import (
     Entry,
     EntryGroup,
-    FieldMask,
     SearchCatalogRequest,
     Tag,
     TagTemplate,
     TagTemplateField,
 )
-from google.protobuf.json_format import MessageToDict
+from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.datacatalog import CloudDataCatalogHook
@@ -153,7 +152,7 @@ class CloudDataCatalogCreateEntryOperator(BaseOperator):
         _, _, entry_id = result.name.rpartition("/")
         self.log.info("Current entry_id ID: %s", entry_id)
         context["task_instance"].xcom_push(key="entry_id", value=entry_id)
-        return MessageToDict(result)
+        return Entry.to_dict(result)
 
 
 class CloudDataCatalogCreateEntryGroupOperator(BaseOperator):
@@ -268,7 +267,7 @@ class CloudDataCatalogCreateEntryGroupOperator(BaseOperator):
         _, _, entry_group_id = result.name.rpartition("/")
         self.log.info("Current entry group ID: %s", entry_group_id)
         context["task_instance"].xcom_push(key="entry_group_id", value=entry_group_id)
-        return MessageToDict(result)
+        return EntryGroup.to_dict(result)
 
 
 class CloudDataCatalogCreateTagOperator(BaseOperator):
@@ -404,7 +403,7 @@ class CloudDataCatalogCreateTagOperator(BaseOperator):
         _, _, tag_id = tag.name.rpartition("/")
         self.log.info("Current Tag ID: %s", tag_id)
         context["task_instance"].xcom_push(key="tag_id", value=tag_id)
-        return MessageToDict(tag)
+        return Tag.to_dict(tag)
 
 
 class CloudDataCatalogCreateTagTemplateOperator(BaseOperator):
@@ -516,7 +515,7 @@ class CloudDataCatalogCreateTagTemplateOperator(BaseOperator):
         _, _, tag_template = result.name.rpartition("/")
         self.log.info("Current Tag ID: %s", tag_template)
         context["task_instance"].xcom_push(key="tag_template_id", value=tag_template)
-        return MessageToDict(result)
+        return TagTemplate.to_dict(result)
 
 
 class CloudDataCatalogCreateTagTemplateFieldOperator(BaseOperator):
@@ -638,7 +637,7 @@ class CloudDataCatalogCreateTagTemplateFieldOperator(BaseOperator):
 
         self.log.info("Current Tag ID: %s", self.tag_template_field_id)
         context["task_instance"].xcom_push(key="tag_template_field_id", value=self.tag_template_field_id)
-        return MessageToDict(result)
+        return TagTemplateField.to_dict(result)
 
 
 class CloudDataCatalogDeleteEntryOperator(BaseOperator):
@@ -1216,7 +1215,7 @@ class CloudDataCatalogGetEntryOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(result)
+        return Entry.to_dict(result)
 
 
 class CloudDataCatalogGetEntryGroupOperator(BaseOperator):
@@ -1234,8 +1233,8 @@ class CloudDataCatalogGetEntryGroupOperator(BaseOperator):
     :param read_mask: The fields to return. If not set or empty, all fields are returned.
 
         If a dict is provided, it must be of the same form as the protobuf message
-        :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
-    :type read_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+        :class:`~google.protobuf.field_mask_pb2.FieldMask`
+    :type read_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
     :param project_id: The ID of the Google Cloud project that owns the entry group.
         If set to ``None`` or missing, the default project_id from the Google Cloud connection is used.
     :type project_id: Optional[str]
@@ -1312,7 +1311,7 @@ class CloudDataCatalogGetEntryGroupOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(result)
+        return EntryGroup.to_dict(result)
 
 
 class CloudDataCatalogGetTagTemplateOperator(BaseOperator):
@@ -1399,7 +1398,7 @@ class CloudDataCatalogGetTagTemplateOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(result)
+        return TagTemplate.to_dict(result)
 
 
 class CloudDataCatalogListTagsOperator(BaseOperator):
@@ -1501,7 +1500,7 @@ class CloudDataCatalogListTagsOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return [MessageToDict(item) for item in result]
+        return [Tag.to_dict(item) for item in result]
 
 
 class CloudDataCatalogLookupEntryOperator(BaseOperator):
@@ -1589,7 +1588,7 @@ class CloudDataCatalogLookupEntryOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(result)
+        return Entry.to_dict(result)
 
 
 class CloudDataCatalogRenameTagTemplateFieldOperator(BaseOperator):
@@ -1809,7 +1808,7 @@ class CloudDataCatalogSearchCatalogOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return [MessageToDict(item) for item in result]
+        return [SearchCatalogResult.to_dict(item) for item in result]
 
 
 class CloudDataCatalogUpdateEntryOperator(BaseOperator):
@@ -1829,8 +1828,8 @@ class CloudDataCatalogUpdateEntryOperator(BaseOperator):
         updated.
 
         If a dict is provided, it must be of the same form as the protobuf message
-        :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
-    :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+        :class:`~google.protobuf.field_mask_pb2.FieldMask`
+    :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
     :param location: Required. The location of the entry to update.
     :type location: str
     :param entry_group: The entry group ID for the entry that is being updated.
@@ -1940,8 +1939,8 @@ class CloudDataCatalogUpdateTagOperator(BaseOperator):
         updated. Currently the only modifiable field is the field ``fields``.
 
         If a dict is provided, it must be of the same form as the protobuf message
-        :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
-    :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+        :class:`~google.protobuf.field_mask_pb2.FieldMask`
+    :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
     :param location: Required. The location of the tag to rename.
     :type location: str
     :param entry_group: The entry group ID for the tag that is being updated.
@@ -2060,8 +2059,8 @@ class CloudDataCatalogUpdateTagTemplateOperator(BaseOperator):
         If absent or empty, all of the allowed fields above will be updated.
 
         If a dict is provided, it must be of the same form as the protobuf message
-        :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
-    :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+        :class:`~google.protobuf.field_mask_pb2.FieldMask`
+    :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
     :param location: Required. The location of the tag template to rename.
     :type location: str
     :param tag_template_id: Optional. The tag template ID for the entry that is being updated.
@@ -2172,8 +2171,8 @@ class CloudDataCatalogUpdateTagTemplateFieldOperator(BaseOperator):
         Therefore, enum values can only be added, existing enum values cannot be deleted nor renamed.
 
         If a dict is provided, it must be of the same form as the protobuf message
-        :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
-    :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+        :class:`~google.protobuf.field_mask_pb2.FieldMask`
+    :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
     :param tag_template_field_name: Optional. The name of the tag template field to rename.
     :type tag_template_field_name: str
     :param location: Optional. The location of the tag to rename.
diff --git a/setup.py b/setup.py
index 75f5db5..5314814 100644
--- a/setup.py
+++ b/setup.py
@@ -287,7 +287,7 @@ google = [
     'google-cloud-bigquery-datatransfer>=3.0.0,<4.0.0',
     'google-cloud-bigtable>=1.0.0,<2.0.0',
     'google-cloud-container>=0.1.1,<2.0.0',
-    'google-cloud-datacatalog>=1.0.0,<2.0.0',
+    'google-cloud-datacatalog>=3.0.0,<4.0.0',
     'google-cloud-dataproc>=1.0.1,<2.0.0',
     'google-cloud-dlp>=0.11.0,<2.0.0',
     'google-cloud-kms>=2.0.0,<3.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_datacatalog.py b/tests/providers/google/cloud/hooks/test_datacatalog.py
index f5192c5..99d785f 100644
--- a/tests/providers/google/cloud/hooks/test_datacatalog.py
+++ b/tests/providers/google/cloud/hooks/test_datacatalog.py
@@ -22,6 +22,7 @@ from unittest import TestCase, mock
 
 import pytest
 from google.api_core.retry import Retry
+from google.cloud.datacatalog_v1beta1 import CreateTagRequest, CreateTagTemplateRequest
 from google.cloud.datacatalog_v1beta1.types import Entry, Tag, TagTemplate
 
 from airflow import AirflowException
@@ -38,7 +39,7 @@ TEST_ENTRY_ID: str = "test-entry-id"
 TEST_ENTRY: Dict = {}
 TEST_RETRY: Retry = Retry()
 TEST_TIMEOUT: float = 4
-TEST_METADATA: Sequence[Tuple[str, str]] = []
+TEST_METADATA: Sequence[Tuple[str, str]] = ()
 TEST_ENTRY_GROUP_ID: str = "test-entry-group-id"
 TEST_ENTRY_GROUP: Dict = {}
 TEST_TAG: Dict = {}
@@ -102,7 +103,7 @@ class TestCloudDataCatalog(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.lookup_entry.assert_called_once_with(
-            linked_resource=TEST_LINKED_RESOURCE,
+            request=dict(linked_resource=TEST_LINKED_RESOURCE),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -118,7 +119,10 @@ class TestCloudDataCatalog(TestCase):
             sql_resource=TEST_SQL_RESOURCE, retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA
         )
         mock_get_conn.return_value.lookup_entry.assert_called_once_with(
-            sql_resource=TEST_SQL_RESOURCE, retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA
+            request=dict(sql_resource=TEST_SQL_RESOURCE),
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
         )
 
     @mock.patch(
@@ -148,10 +152,9 @@ class TestCloudDataCatalog(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.search_catalog.assert_called_once_with(
-            scope=TEST_SCOPE,
-            query=TEST_QUERY,
-            page_size=TEST_PAGE_SIZE,
-            order_by=TEST_ORDER_BY,
+            request=dict(
+                scope=TEST_SCOPE, query=TEST_QUERY, page_size=TEST_PAGE_SIZE, order_by=TEST_ORDER_BY
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -184,9 +187,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.create_entry.assert_called_once_with(
-            parent=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1),
-            entry_id=TEST_ENTRY_ID,
-            entry=TEST_ENTRY,
+            request=dict(
+                parent=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1),
+                entry_id=TEST_ENTRY_ID,
+                entry=TEST_ENTRY,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -207,9 +212,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.create_entry_group.assert_called_once_with(
-            parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_1),
-            entry_group_id=TEST_ENTRY_GROUP_ID,
-            entry_group=TEST_ENTRY_GROUP,
+            request=dict(
+                parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_1),
+                entry_group_id=TEST_ENTRY_GROUP_ID,
+                entry_group=TEST_ENTRY_GROUP,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -232,8 +239,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.create_tag.assert_called_once_with(
-            parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
-            tag={"template": TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)},
+            request=CreateTagRequest(
+                parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+                tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)),
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -256,8 +265,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.create_tag.assert_called_once_with(
-            parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
-            tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)),
+            request=CreateTagRequest(
+                parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+                tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)),
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -278,9 +289,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.create_tag_template.assert_called_once_with(
-            parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_1),
-            tag_template_id=TEST_TAG_TEMPLATE_ID,
-            tag_template=TEST_TAG_TEMPLATE,
+            request=CreateTagTemplateRequest(
+                parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_1),
+                tag_template_id=TEST_TAG_TEMPLATE_ID,
+                tag_template=TEST_TAG_TEMPLATE,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -302,9 +315,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.create_tag_template_field.assert_called_once_with(
-            parent=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1),
-            tag_template_field_id=TEST_TAG_TEMPLATE_FIELD_ID,
-            tag_template_field=TEST_TAG_TEMPLATE_FIELD,
+            request=dict(
+                parent=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1),
+                tag_template_field_id=TEST_TAG_TEMPLATE_FIELD_ID,
+                tag_template_field=TEST_TAG_TEMPLATE_FIELD,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -325,7 +340,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.delete_entry.assert_called_once_with(
-            name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+            request=dict(
+                name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -345,7 +362,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.delete_entry_group.assert_called_once_with(
-            name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1),
+            request=dict(
+                name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1),
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -367,7 +386,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.delete_tag.assert_called_once_with(
-            name=TEST_TAG_PATH.format(TEST_PROJECT_ID_1),
+            request=dict(
+                name=TEST_TAG_PATH.format(TEST_PROJECT_ID_1),
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -388,8 +409,7 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.delete_tag_template.assert_called_once_with(
-            name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1),
-            force=TEST_FORCE,
+            request=dict(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1), force=TEST_FORCE),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -411,8 +431,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.delete_tag_template_field.assert_called_once_with(
-            name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1),
-            force=TEST_FORCE,
+            request=dict(
+                name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1),
+                force=TEST_FORCE,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -433,7 +455,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.get_entry.assert_called_once_with(
-            name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+            request=dict(
+                name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -454,8 +478,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.get_entry_group.assert_called_once_with(
-            name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1),
-            read_mask=TEST_READ_MASK,
+            request=dict(
+                name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1),
+                read_mask=TEST_READ_MASK,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -475,7 +501,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.get_tag_template.assert_called_once_with(
-            name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1),
+            request=dict(
+                name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1),
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -497,8 +525,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.list_tags.assert_called_once_with(
-            parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
-            page_size=TEST_PAGE_SIZE,
+            request=dict(
+                parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+                page_size=TEST_PAGE_SIZE,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -524,8 +554,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.list_tags.assert_called_once_with(
-            parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
-            page_size=100,
+            request=dict(
+                parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+                page_size=100,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -548,8 +580,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.rename_tag_template_field.assert_called_once_with(
-            name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1),
-            new_tag_template_field_id=TEST_NEW_TAG_TEMPLATE_FIELD_ID,
+            request=dict(
+                name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1),
+                new_tag_template_field_id=TEST_NEW_TAG_TEMPLATE_FIELD_ID,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -572,8 +606,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.update_entry.assert_called_once_with(
-            entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1)),
-            update_mask=TEST_UPDATE_MASK,
+            request=dict(
+                entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1)),
+                update_mask=TEST_UPDATE_MASK,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -597,8 +633,7 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.update_tag.assert_called_once_with(
-            tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_1)),
-            update_mask=TEST_UPDATE_MASK,
+            request=dict(tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_1)), update_mask=TEST_UPDATE_MASK),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -620,8 +655,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.update_tag_template.assert_called_once_with(
-            tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)),
-            update_mask=TEST_UPDATE_MASK,
+            request=dict(
+                tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)),
+                update_mask=TEST_UPDATE_MASK,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -644,9 +681,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.update_tag_template_field.assert_called_once_with(
-            name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1),
-            tag_template_field=TEST_TAG_TEMPLATE_FIELD,
-            update_mask=TEST_UPDATE_MASK,
+            request=dict(
+                name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1),
+                tag_template_field=TEST_TAG_TEMPLATE_FIELD,
+                update_mask=TEST_UPDATE_MASK,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -680,9 +719,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.create_entry.assert_called_once_with(
-            parent=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2),
-            entry_id=TEST_ENTRY_ID,
-            entry=TEST_ENTRY,
+            request=dict(
+                parent=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2),
+                entry_id=TEST_ENTRY_ID,
+                entry=TEST_ENTRY,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -704,9 +745,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.create_entry_group.assert_called_once_with(
-            parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_2),
-            entry_group_id=TEST_ENTRY_GROUP_ID,
-            entry_group=TEST_ENTRY_GROUP,
+            request=dict(
+                parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_2),
+                entry_group_id=TEST_ENTRY_GROUP_ID,
+                entry_group=TEST_ENTRY_GROUP,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -730,8 +773,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.create_tag.assert_called_once_with(
-            parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
-            tag={"template": TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)},
+            request=CreateTagRequest(
+                parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
+                tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)),
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -755,8 +800,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.create_tag.assert_called_once_with(
-            parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
-            tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)),
+            request=CreateTagRequest(
+                parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
+                tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)),
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -778,9 +825,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.create_tag_template.assert_called_once_with(
-            parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_2),
-            tag_template_id=TEST_TAG_TEMPLATE_ID,
-            tag_template=TEST_TAG_TEMPLATE,
+            request=CreateTagTemplateRequest(
+                parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_2),
+                tag_template_id=TEST_TAG_TEMPLATE_ID,
+                tag_template=TEST_TAG_TEMPLATE,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -803,9 +852,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.create_tag_template_field.assert_called_once_with(
-            parent=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2),
-            tag_template_field_id=TEST_TAG_TEMPLATE_FIELD_ID,
-            tag_template_field=TEST_TAG_TEMPLATE_FIELD,
+            request=dict(
+                parent=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2),
+                tag_template_field_id=TEST_TAG_TEMPLATE_FIELD_ID,
+                tag_template_field=TEST_TAG_TEMPLATE_FIELD,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -827,7 +878,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.delete_entry.assert_called_once_with(
-            name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
+            request=dict(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2)),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -848,7 +899,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.delete_entry_group.assert_called_once_with(
-            name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2),
+            request=dict(name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2)),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -871,7 +922,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.delete_tag.assert_called_once_with(
-            name=TEST_TAG_PATH.format(TEST_PROJECT_ID_2),
+            request=dict(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_2)),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -893,8 +944,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.delete_tag_template.assert_called_once_with(
-            name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2),
-            force=TEST_FORCE,
+            request=dict(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2), force=TEST_FORCE),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -917,8 +967,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.delete_tag_template_field.assert_called_once_with(
-            name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2),
-            force=TEST_FORCE,
+            request=dict(name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2), force=TEST_FORCE),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -940,7 +989,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.get_entry.assert_called_once_with(
-            name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
+            request=dict(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2)),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -962,8 +1011,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.get_entry_group.assert_called_once_with(
-            name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2),
-            read_mask=TEST_READ_MASK,
+            request=dict(
+                name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2),
+                read_mask=TEST_READ_MASK,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -984,7 +1035,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.get_tag_template.assert_called_once_with(
-            name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2),
+            request=dict(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -1007,8 +1058,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.list_tags.assert_called_once_with(
-            parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
-            page_size=TEST_PAGE_SIZE,
+            request=dict(parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), page_size=TEST_PAGE_SIZE),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -1035,8 +1085,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.list_tags.assert_called_once_with(
-            parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
-            page_size=100,
+            request=dict(parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), page_size=100),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -1060,8 +1109,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.rename_tag_template_field.assert_called_once_with(
-            name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2),
-            new_tag_template_field_id=TEST_NEW_TAG_TEMPLATE_FIELD_ID,
+            request=dict(
+                name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2),
+                new_tag_template_field_id=TEST_NEW_TAG_TEMPLATE_FIELD_ID,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -1085,8 +1136,9 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.update_entry.assert_called_once_with(
-            entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2)),
-            update_mask=TEST_UPDATE_MASK,
+            request=dict(
+                entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2)), update_mask=TEST_UPDATE_MASK
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -1111,8 +1163,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.update_tag.assert_called_once_with(
-            tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_2)),
-            update_mask=TEST_UPDATE_MASK,
+            request=dict(tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_2)), update_mask=TEST_UPDATE_MASK),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -1135,8 +1186,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.update_tag_template.assert_called_once_with(
-            tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)),
-            update_mask=TEST_UPDATE_MASK,
+            request=dict(
+                tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)),
+                update_mask=TEST_UPDATE_MASK,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -1160,9 +1213,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.update_tag_template_field.assert_called_once_with(
-            name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2),
-            tag_template_field=TEST_TAG_TEMPLATE_FIELD,
-            update_mask=TEST_UPDATE_MASK,
+            request=dict(
+                name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2),
+                tag_template_field=TEST_TAG_TEMPLATE_FIELD,
+                update_mask=TEST_UPDATE_MASK,
+            ),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
diff --git a/tests/providers/google/cloud/operators/test_datacatalog.py b/tests/providers/google/cloud/operators/test_datacatalog.py
index b575dd4..517b35c 100644
--- a/tests/providers/google/cloud/operators/test_datacatalog.py
+++ b/tests/providers/google/cloud/operators/test_datacatalog.py
@@ -87,15 +87,25 @@ TEST_TAG_PATH: str = (
 )
 
 TEST_ENTRY: Entry = Entry(name=TEST_ENTRY_PATH)
-TEST_ENTRY_DICT: Dict = dict(name=TEST_ENTRY_PATH)
+TEST_ENTRY_DICT: Dict = {
+    'description': '',
+    'display_name': '',
+    'linked_resource': '',
+    'name': TEST_ENTRY_PATH,
+}
 TEST_ENTRY_GROUP: EntryGroup = EntryGroup(name=TEST_ENTRY_GROUP_PATH)
-TEST_ENTRY_GROUP_DICT: Dict = dict(name=TEST_ENTRY_GROUP_PATH)
-TEST_TAG: EntryGroup = Tag(name=TEST_TAG_PATH)
-TEST_TAG_DICT: Dict = dict(name=TEST_TAG_PATH)
+TEST_ENTRY_GROUP_DICT: Dict = {'description': '', 'display_name': '', 'name': TEST_ENTRY_GROUP_PATH}
+TEST_TAG: Tag = Tag(name=TEST_TAG_PATH)
+TEST_TAG_DICT: Dict = {'fields': {}, 'name': TEST_TAG_PATH, 'template': '', 'template_display_name': ''}
 TEST_TAG_TEMPLATE: TagTemplate = TagTemplate(name=TEST_TAG_TEMPLATE_PATH)
-TEST_TAG_TEMPLATE_DICT: Dict = dict(name=TEST_TAG_TEMPLATE_PATH)
-TEST_TAG_TEMPLATE_FIELD: Dict = TagTemplateField(name=TEST_TAG_TEMPLATE_FIELD_ID)
-TEST_TAG_TEMPLATE_FIELD_DICT: Dict = dict(name=TEST_TAG_TEMPLATE_FIELD_ID)
+TEST_TAG_TEMPLATE_DICT: Dict = {'display_name': '', 'fields': {}, 'name': TEST_TAG_TEMPLATE_PATH}
+TEST_TAG_TEMPLATE_FIELD: TagTemplateField = TagTemplateField(name=TEST_TAG_TEMPLATE_FIELD_ID)
+TEST_TAG_TEMPLATE_FIELD_DICT: Dict = {
+    'display_name': '',
+    'is_required': False,
+    'name': TEST_TAG_TEMPLATE_FIELD_ID,
+    'order': 0,
+}
 
 
 class TestCloudDataCatalogCreateEntryOperator(TestCase):
@@ -498,7 +508,10 @@ class TestCloudDataCatalogDeleteTagTemplateFieldOperator(TestCase):
 
 
 class TestCloudDataCatalogGetEntryOperator(TestCase):
-    @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook")
+    @mock.patch(
+        "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
+        **{"return_value.get_entry.return_value": TEST_ENTRY},  # type: ignore
+    )
     def test_assert_valid_hook_call(self, mock_hook) -> None:
         task = CloudDataCatalogGetEntryOperator(
             task_id="task_id",
@@ -529,7 +542,10 @@ class TestCloudDataCatalogGetEntryOperator(TestCase):
 
 
 class TestCloudDataCatalogGetEntryGroupOperator(TestCase):
-    @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook")
+    @mock.patch(
+        "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
+        **{"return_value.get_entry_group.return_value": TEST_ENTRY_GROUP},  # type: ignore
+    )
     def test_assert_valid_hook_call(self, mock_hook) -> None:
         task = CloudDataCatalogGetEntryGroupOperator(
             task_id="task_id",
@@ -560,7 +576,10 @@ class TestCloudDataCatalogGetEntryGroupOperator(TestCase):
 
 
 class TestCloudDataCatalogGetTagTemplateOperator(TestCase):
-    @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook")
+    @mock.patch(
+        "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
+        **{"return_value.get_tag_template.return_value": TEST_TAG_TEMPLATE},  # type: ignore
+    )
     def test_assert_valid_hook_call(self, mock_hook) -> None:
         task = CloudDataCatalogGetTagTemplateOperator(
             task_id="task_id",
@@ -589,7 +608,10 @@ class TestCloudDataCatalogGetTagTemplateOperator(TestCase):
 
 
 class TestCloudDataCatalogListTagsOperator(TestCase):
-    @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook")
+    @mock.patch(
+        "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
+        **{"return_value.list_tags.return_value": [TEST_TAG]},  # type: ignore
+    )
     def test_assert_valid_hook_call(self, mock_hook) -> None:
         task = CloudDataCatalogListTagsOperator(
             task_id="task_id",
@@ -622,7 +644,10 @@ class TestCloudDataCatalogListTagsOperator(TestCase):
 
 
 class TestCloudDataCatalogLookupEntryOperator(TestCase):
-    @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook")
+    @mock.patch(
+        "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
+        **{"return_value.lookup_entry.return_value": TEST_ENTRY},  # type: ignore
+    )
     def test_assert_valid_hook_call(self, mock_hook) -> None:
         task = CloudDataCatalogLookupEntryOperator(
             task_id="task_id",

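For reference, a minimal sketch (not part of the commit above) of the call style that google-cloud-datacatalog>=3.0.0 expects and that the hook now follows: parameters are wrapped in a single ``request`` mapping and results are serialized with the proto-plus ``to_dict`` helper. The client construction and the resource name below are illustrative placeholders only.

    from google.cloud.datacatalog_v1beta1 import DataCatalogClient
    from google.cloud.datacatalog_v1beta1.types import Entry

    # Assumes default application credentials are available in the environment.
    client = DataCatalogClient()
    entry = client.get_entry(
        request={"name": "projects/example-project/locations/us/entryGroups/example/entries/example"}
    )
    # Proto-plus messages are converted with <MessageClass>.to_dict(instance),
    # which is why MessageToDict(result) becomes Entry.to_dict(result) above.
    entry_dict = Entry.to_dict(entry)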

[airflow] 24/28: Remove testfixtures module that is only used once (#14318)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 12b5dcae25d8037fafd4a35240ef044e66fcdf51
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Mon Feb 22 20:17:31 2021 +0000

    Remove testfixtures module that is only used once (#14318)
    
    This is only used in a single test; everywhere else we use Pytest or
    unittest's built-in features.
    
    (cherry picked from commit 3a046faaeb457572b1484faf158cc96eb81df44a)
---
 setup.py                                         |  1 -
 tests/providers/amazon/aws/hooks/test_glacier.py | 65 ++++++++++--------------
 2 files changed, 27 insertions(+), 39 deletions(-)

diff --git a/setup.py b/setup.py
index 92eb113..ad4fdd5 100644
--- a/setup.py
+++ b/setup.py
@@ -508,7 +508,6 @@ devel = [
     'pywinrm',
     'qds-sdk>=1.9.6',
     'requests_mock',
-    'testfixtures',
     'wheel',
     'yamllint',
 ]
diff --git a/tests/providers/amazon/aws/hooks/test_glacier.py b/tests/providers/amazon/aws/hooks/test_glacier.py
index c1c86a5..c22620f 100644
--- a/tests/providers/amazon/aws/hooks/test_glacier.py
+++ b/tests/providers/amazon/aws/hooks/test_glacier.py
@@ -19,8 +19,6 @@
 import unittest
 from unittest import mock
 
-from testfixtures import LogCapture
-
 from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
 
 CREDENTIALS = "aws_conn"
@@ -52,26 +50,20 @@ class TestAmazonGlacierHook(unittest.TestCase):
         # given
         job_id = {"jobId": "1234abcd"}
         # when
-        with LogCapture() as log:
+        with self.assertLogs() as log:
             mock_conn.return_value.initiate_job.return_value = job_id
             self.hook.retrieve_inventory(VAULT_NAME)
             # then
-            log.check(
-                (
-                    'airflow.providers.amazon.aws.hooks.glacier.GlacierHook',
-                    'INFO',
-                    f"Retrieving inventory for vault: {VAULT_NAME}",
-                ),
-                (
-                    'airflow.providers.amazon.aws.hooks.glacier.GlacierHook',
-                    'INFO',
-                    f"Initiated inventory-retrieval job for: {VAULT_NAME}",
-                ),
-                (
-                    'airflow.providers.amazon.aws.hooks.glacier.GlacierHook',
-                    'INFO',
-                    f"Retrieval Job ID: {job_id.get('jobId')}",
-                ),
+            self.assertEqual(
+                log.output,
+                [
+                    'INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:'
+                    + f"Retrieving inventory for vault: {VAULT_NAME}",
+                    'INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:'
+                    + f"Initiated inventory-retrieval job for: {VAULT_NAME}",
+                    'INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:'
+                    + f"Retrieval Job ID: {job_id.get('jobId')}",
+                ],
             )
 
     @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
@@ -86,16 +78,16 @@ class TestAmazonGlacierHook(unittest.TestCase):
     @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
     def test_retrieve_inventory_results_should_log_mgs(self, mock_conn):
         # when
-        with LogCapture() as log:
+        with self.assertLogs() as log:
             mock_conn.return_value.get_job_output.return_value = REQUEST_RESULT
             self.hook.retrieve_inventory_results(VAULT_NAME, JOB_ID)
             # then
-            log.check(
-                (
-                    'airflow.providers.amazon.aws.hooks.glacier.GlacierHook',
-                    'INFO',
-                    f"Retrieving the job results for vault: {VAULT_NAME}...",
-                ),
+            self.assertEqual(
+                log.output,
+                [
+                    'INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:'
+                    + f"Retrieving the job results for vault: {VAULT_NAME}...",
+                ],
             )
 
     @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
@@ -110,19 +102,16 @@ class TestAmazonGlacierHook(unittest.TestCase):
     @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
     def test_describe_job_should_log_mgs(self, mock_conn):
         # when
-        with LogCapture() as log:
+        with self.assertLogs() as log:
             mock_conn.return_value.describe_job.return_value = JOB_STATUS
             self.hook.describe_job(VAULT_NAME, JOB_ID)
             # then
-            log.check(
-                (
-                    'airflow.providers.amazon.aws.hooks.glacier.GlacierHook',
-                    'INFO',
-                    f"Retrieving status for vault: {VAULT_NAME} and job {JOB_ID}",
-                ),
-                (
-                    'airflow.providers.amazon.aws.hooks.glacier.GlacierHook',
-                    'INFO',
-                    f"Job status: {JOB_STATUS.get('Action')}, code status: {JOB_STATUS.get('StatusCode')}",
-                ),
+            self.assertEqual(
+                log.output,
+                [
+                    'INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:'
+                    + f"Retrieving status for vault: {VAULT_NAME} and job {JOB_ID}",
+                    'INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:'
+                    + f"Job status: {JOB_STATUS.get('Action')}, code status: {JOB_STATUS.get('StatusCode')}",
+                ],
             )

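Below is a minimal, self-contained sketch (not taken from the commit above) of the stdlib pattern it adopts, unittest's ``assertLogs`` replacing ``testfixtures.LogCapture``; the logger name and message are illustrative only.

    import logging
    import unittest


    class ExampleLoggingTest(unittest.TestCase):
        def test_logs_are_captured(self):
            logger = logging.getLogger("example.logger")
            # assertLogs captures records for the named logger at INFO and above
            with self.assertLogs("example.logger", level="INFO") as log:
                logger.info("processing %s", "vault-1")
            # log.output holds "LEVEL:logger_name:message" strings
            self.assertEqual(log.output, ["INFO:example.logger:processing vault-1"])


    if __name__ == "__main__":
        unittest.main()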

[airflow] 02/28: Minor doc fixes (#14547)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit a671f37fc888e5fed2620b4cb99e9222e83e9a21
Author: Xiaodong DENG <xd...@apache.org>
AuthorDate: Mon Mar 1 21:31:58 2021 +0100

    Minor doc fixes (#14547)
    
    (cherry picked from commit 391baee4047127fe722eeb7f4aec219c86a89295)
---
 docs/apache-airflow/installation.rst | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/docs/apache-airflow/installation.rst b/docs/apache-airflow/installation.rst
index f8042ac..eac6894 100644
--- a/docs/apache-airflow/installation.rst
+++ b/docs/apache-airflow/installation.rst
@@ -63,7 +63,7 @@ issues from ``pip`` 20.3.0 release have been fixed in 20.3.3). In order to insta
 either downgrade pip to version 20.2.4 ``pip install --upgrade pip==20.2.4`` or, in case you use Pip 20.3, you need to add option
 ``--use-deprecated legacy-resolver`` to your pip install command.
 
-While they are some successes with using other tools like `poetry <https://python-poetry.org/>`_ or
+While there are some successes with using other tools like `poetry <https://python-poetry.org/>`_ or
 `pip-tools <https://pypi.org/project/pip-tools/>`_, they do not share the same workflow as
 ``pip`` - especially when it comes to constraint vs. requirements management.
 Installing via ``Poetry`` or ``pip-tools`` is not currently supported. If you wish to install airflow
@@ -81,8 +81,8 @@ environment. For instance, if you don't need connectivity with Postgres,
 you won't have to go through the trouble of installing the ``postgres-devel``
 yum package, or whatever equivalent applies on the distribution you are using.
 
-Most of the extra dependencies are linked to a corresponding providers package. For example "amazon" extra
-has a corresponding ``apache-airflow-providers-amazon`` providers package to be installed. When you install
+Most of the extra dependencies are linked to a corresponding provider package. For example "amazon" extra
+has a corresponding ``apache-airflow-providers-amazon`` provider package to be installed. When you install
 Airflow with such extras, the necessary provider packages are installed automatically (latest versions from
 PyPI for those packages). However you can freely upgrade and install provider packages independently from
 the main Airflow installation.
@@ -96,7 +96,7 @@ Provider packages
 
 Unlike Apache Airflow 1.10, the Airflow 2.0 is delivered in multiple, separate, but connected packages.
 The core of Airflow scheduling system is delivered as ``apache-airflow`` package and there are around
-60 providers packages which can be installed separately as so called ``Airflow Provider packages``.
+60 provider packages which can be installed separately as so called ``Airflow Provider packages``.
 The default Airflow installation doesn't have many integrations and you have to install them yourself.
 
 You can even develop and install your own providers for Airflow. For more information,
@@ -164,9 +164,9 @@ In order to have repeatable installation, starting from **Airflow 1.10.10** and
 ``constraints-master``, ``constraints-2-0`` and ``constraints-1-10`` orphan branches and then we create a tag
 for each released version e.g. ``constraints-2.0.1``. This way, we keep a tested and working set of dependencies.
 
-Those "known-to-be-working" constraints are per major/minor python version. You can use them as constraint
+Those "known-to-be-working" constraints are per major/minor Python version. You can use them as constraint
 files when installing Airflow from PyPI. Note that you have to specify correct Airflow version
-and python versions in the URL.
+and Python versions in the URL.
 
 You can create the URL to the file substituting the variables in the template below.
 

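As an illustration of the substitution described above (a sketch, not part of the documentation change; the version numbers are examples and the URL layout is the standard Airflow constraints location on GitHub):

    AIRFLOW_VERSION = "2.0.1"
    PYTHON_VERSION = "3.8"
    CONSTRAINT_URL = (
        "https://raw.githubusercontent.com/apache/airflow/"
        f"constraints-{AIRFLOW_VERSION}/constraints-{PYTHON_VERSION}.txt"
    )
    # Then install with: pip install "apache-airflow==2.0.1" --constraint "<CONSTRAINT_URL>"
    print(CONSTRAINT_URL)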

[airflow] 27/28: Add Azure Data Factory hook (#11015)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 0d199d375a73c2a1c4e05f677915ed61f96344a2
Author: flvndh <17...@users.noreply.github.com>
AuthorDate: Fri Feb 26 17:28:21 2021 +0100

    Add Azure Data Factory hook (#11015)
    
    fixes #10995
    
    (cherry picked from commit 11d03d2f63d88a284d6aaded5f9ab6642a60561b)
---
 .../microsoft/azure/hooks/azure_data_factory.py    | 716 +++++++++++++++++++++
 airflow/providers/microsoft/azure/provider.yaml    |   8 +
 .../integration-logos/azure/Azure Data Factory.svg |   1 +
 docs/spelling_wordlist.txt                         |   1 +
 setup.py                                           |   1 +
 .../azure/hooks/test_azure_data_factory.py         | 439 +++++++++++++
 6 files changed, 1166 insertions(+)

diff --git a/airflow/providers/microsoft/azure/hooks/azure_data_factory.py b/airflow/providers/microsoft/azure/hooks/azure_data_factory.py
new file mode 100644
index 0000000..d6c686b
--- /dev/null
+++ b/airflow/providers/microsoft/azure/hooks/azure_data_factory.py
@@ -0,0 +1,716 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import inspect
+from functools import wraps
+from typing import Any, Callable, Optional
+
+from azure.mgmt.datafactory import DataFactoryManagementClient
+from azure.mgmt.datafactory.models import (
+    CreateRunResponse,
+    Dataset,
+    DatasetResource,
+    Factory,
+    LinkedService,
+    LinkedServiceResource,
+    PipelineResource,
+    PipelineRun,
+    Trigger,
+    TriggerResource,
+)
+from msrestazure.azure_operation import AzureOperationPoller
+
+from airflow.exceptions import AirflowException
+from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
+
+
+def provide_targeted_factory(func: Callable) -> Callable:
+    """
+    Provide the targeted factory to the decorated function in case it isn't specified.
+
+    If ``resource_group_name`` or ``factory_name`` is not provided, it defaults to the value specified in
+    the connection extras.
+    """
+    signature = inspect.signature(func)
+
+    @wraps(func)
+    def wrapper(*args, **kwargs) -> Callable:
+        bound_args = signature.bind(*args, **kwargs)
+
+        def bind_argument(arg, default_key):
+            if arg not in bound_args.arguments:
+                self = args[0]
+                conn = self.get_connection(self.conn_id)
+                default_value = conn.extra_dejson.get(default_key)
+
+                if not default_value:
+                    raise AirflowException("Could not determine the targeted data factory.")
+
+                bound_args.arguments[arg] = conn.extra_dejson[default_key]
+
+        bind_argument("resource_group_name", "resourceGroup")
+        bind_argument("factory_name", "factory")
+
+        return func(*bound_args.args, **bound_args.kwargs)
+
+    return wrapper
+
+
+class AzureDataFactoryHook(AzureBaseHook):  # pylint: disable=too-many-public-methods
+    """
+    A hook to interact with Azure Data Factory.
+
+    :param conn_id: The Azure Data Factory connection id.
+    """
+
+    def __init__(self, conn_id: str = "azure_data_factory_default"):
+        super().__init__(sdk_client=DataFactoryManagementClient, conn_id=conn_id)
+        self._conn: Optional[DataFactoryManagementClient] = None
+
+    def get_conn(self) -> DataFactoryManagementClient:
+        if not self._conn:
+            self._conn = super().get_conn()
+
+        return self._conn
+
+    @provide_targeted_factory
+    def get_factory(
+        self, resource_group_name: Optional[str] = None, factory_name: Optional[str] = None, **config: Any
+    ) -> Factory:
+        """
+        Get the factory.
+
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :return: The factory.
+        """
+        return self.get_conn().factories.get(resource_group_name, factory_name, **config)
+
+    def _factory_exists(self, resource_group_name, factory_name) -> bool:
+        """Return whether or not the factory already exists."""
+        factories = {
+            factory.name for factory in self.get_conn().factories.list_by_resource_group(resource_group_name)
+        }
+
+        return factory_name in factories
+
+    @provide_targeted_factory
+    def update_factory(
+        self,
+        factory: Factory,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> Factory:
+        """
+        Update the factory.
+
+        :param factory: The factory resource definition.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :raise AirflowException: If the factory does not exist.
+        :return: The factory.
+        """
+        if not self._factory_exists(resource_group_name, factory_name):
+            raise AirflowException(f"Factory {factory_name!r} does not exist.")
+
+        return self.get_conn().factories.create_or_update(
+            resource_group_name, factory_name, factory, **config
+        )
+
+    @provide_targeted_factory
+    def create_factory(
+        self,
+        factory: Factory,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> Factory:
+        """
+        Create the factory.
+
+        :param factory: The factory resource definition.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :raise AirflowException: If the factory already exists.
+        :return: The factory.
+        """
+        if self._factory_exists(resource_group_name, factory_name):
+            raise AirflowException(f"Factory {factory_name!r} already exists.")
+
+        return self.get_conn().factories.create_or_update(
+            resource_group_name, factory_name, factory, **config
+        )
+
+    @provide_targeted_factory
+    def delete_factory(
+        self, resource_group_name: Optional[str] = None, factory_name: Optional[str] = None, **config: Any
+    ) -> None:
+        """
+        Delete the factory.
+
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        """
+        self.get_conn().factories.delete(resource_group_name, factory_name, **config)
+
+    @provide_targeted_factory
+    def get_linked_service(
+        self,
+        linked_service_name: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> LinkedServiceResource:
+        """
+        Get the linked service.
+
+        :param linked_service_name: The linked service name.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :return: The linked service.
+        """
+        return self.get_conn().linked_services.get(
+            resource_group_name, factory_name, linked_service_name, **config
+        )
+
+    def _linked_service_exists(self, resource_group_name, factory_name, linked_service_name) -> bool:
+        """Return whether or not the linked service already exists."""
+        linked_services = {
+            linked_service.name
+            for linked_service in self.get_conn().linked_services.list_by_factory(
+                resource_group_name, factory_name
+            )
+        }
+
+        return linked_service_name in linked_services
+
+    @provide_targeted_factory
+    def update_linked_service(
+        self,
+        linked_service_name: str,
+        linked_service: LinkedService,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> LinkedServiceResource:
+        """
+        Update the linked service.
+
+        :param linked_service_name: The linked service name.
+        :param linked_service: The linked service resource definition.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :raise AirflowException: If the linked service does not exist.
+        :return: The linked service.
+        """
+        if not self._linked_service_exists(resource_group_name, factory_name, linked_service_name):
+            raise AirflowException(f"Linked service {linked_service_name!r} does not exist.")
+
+        return self.get_conn().linked_services.create_or_update(
+            resource_group_name, factory_name, linked_service_name, linked_service, **config
+        )
+
+    @provide_targeted_factory
+    def create_linked_service(
+        self,
+        linked_service_name: str,
+        linked_service: LinkedService,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> LinkedServiceResource:
+        """
+        Create the linked service.
+
+        :param linked_service_name: The linked service name.
+        :param linked_service: The linked service resource definition.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :raise AirflowException: If the linked service already exists.
+        :return: The linked service.
+        """
+        if self._linked_service_exists(resource_group_name, factory_name, linked_service_name):
+            raise AirflowException(f"Linked service {linked_service_name!r} already exists.")
+
+        return self.get_conn().linked_services.create_or_update(
+            resource_group_name, factory_name, linked_service_name, linked_service, **config
+        )
+
+    @provide_targeted_factory
+    def delete_linked_service(
+        self,
+        linked_service_name: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> None:
+        """
+        Delete the linked service.
+
+        :param linked_service_name: The linked service name.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        """
+        self.get_conn().linked_services.delete(
+            resource_group_name, factory_name, linked_service_name, **config
+        )
+
+    @provide_targeted_factory
+    def get_dataset(
+        self,
+        dataset_name: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> DatasetResource:
+        """
+        Get the dataset.
+
+        :param dataset_name: The dataset name.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :return: The dataset.
+        """
+        return self.get_conn().datasets.get(resource_group_name, factory_name, dataset_name, **config)
+
+    def _dataset_exists(self, resource_group_name, factory_name, dataset_name) -> bool:
+        """Return whether or not the dataset already exists."""
+        datasets = {
+            dataset.name
+            for dataset in self.get_conn().datasets.list_by_factory(resource_group_name, factory_name)
+        }
+
+        return dataset_name in datasets
+
+    @provide_targeted_factory
+    def update_dataset(
+        self,
+        dataset_name: str,
+        dataset: Dataset,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> DatasetResource:
+        """
+        Update the dataset.
+
+        :param dataset_name: The dataset name.
+        :param dataset: The dataset resource definition.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :raise AirflowException: If the dataset does not exist.
+        :return: The dataset.
+        """
+        if not self._dataset_exists(resource_group_name, factory_name, dataset_name):
+            raise AirflowException(f"Dataset {dataset_name!r} does not exist.")
+
+        return self.get_conn().datasets.create_or_update(
+            resource_group_name, factory_name, dataset_name, dataset, **config
+        )
+
+    @provide_targeted_factory
+    def create_dataset(
+        self,
+        dataset_name: str,
+        dataset: Dataset,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> DatasetResource:
+        """
+        Create the dataset.
+
+        :param dataset_name: The dataset name.
+        :param dataset: The dataset resource definition.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :raise AirflowException: If the dataset already exists.
+        :return: The dataset.
+        """
+        if self._dataset_exists(resource_group_name, factory_name, dataset_name):
+            raise AirflowException(f"Dataset {dataset_name!r} already exists.")
+
+        return self.get_conn().datasets.create_or_update(
+            resource_group_name, factory_name, dataset_name, dataset, **config
+        )
+
+    @provide_targeted_factory
+    def delete_dataset(
+        self,
+        dataset_name: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> None:
+        """
+        Delete the dataset.
+
+        :param dataset_name: The dataset name.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        """
+        self.get_conn().datasets.delete(resource_group_name, factory_name, dataset_name, **config)
+
+    @provide_targeted_factory
+    def get_pipeline(
+        self,
+        pipeline_name: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> PipelineResource:
+        """
+        Get the pipeline.
+
+        :param pipeline_name: The pipeline name.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :return: The pipeline.
+        """
+        return self.get_conn().pipelines.get(resource_group_name, factory_name, pipeline_name, **config)
+
+    def _pipeline_exists(self, resource_group_name, factory_name, pipeline_name) -> bool:
+        """Return whether or not the pipeline already exists."""
+        pipelines = {
+            pipeline.name
+            for pipeline in self.get_conn().pipelines.list_by_factory(resource_group_name, factory_name)
+        }
+
+        return pipeline_name in pipelines
+
+    @provide_targeted_factory
+    def update_pipeline(
+        self,
+        pipeline_name: str,
+        pipeline: PipelineResource,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> PipelineResource:
+        """
+        Update the pipeline.
+
+        :param pipeline_name: The pipeline name.
+        :param pipeline: The pipeline resource definition.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :raise AirflowException: If the pipeline does not exist.
+        :return: The pipeline.
+        """
+        if not self._pipeline_exists(resource_group_name, factory_name, pipeline_name):
+            raise AirflowException(f"Pipeline {pipeline_name!r} does not exist.")
+
+        return self.get_conn().pipelines.create_or_update(
+            resource_group_name, factory_name, pipeline_name, pipeline, **config
+        )
+
+    @provide_targeted_factory
+    def create_pipeline(
+        self,
+        pipeline_name: str,
+        pipeline: PipelineResource,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> PipelineResource:
+        """
+        Create the pipeline.
+
+        :param pipeline_name: The pipeline name.
+        :param pipeline: The pipeline resource definition.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :raise AirflowException: If the pipeline already exists.
+        :return: The pipeline.
+        """
+        if self._pipeline_exists(resource_group_name, factory_name, pipeline_name):
+            raise AirflowException(f"Pipeline {pipeline_name!r} already exists.")
+
+        return self.get_conn().pipelines.create_or_update(
+            resource_group_name, factory_name, pipeline_name, pipeline, **config
+        )
+
+    @provide_targeted_factory
+    def delete_pipeline(
+        self,
+        pipeline_name: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> None:
+        """
+        Delete the pipeline.
+
+        :param pipeline_name: The pipeline name.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        """
+        self.get_conn().pipelines.delete(resource_group_name, factory_name, pipeline_name, **config)
+
+    @provide_targeted_factory
+    def run_pipeline(
+        self,
+        pipeline_name: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> CreateRunResponse:
+        """
+        Run a pipeline.
+
+        :param pipeline_name: The pipeline name.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :return: The pipeline run.
+        """
+        return self.get_conn().pipelines.create_run(
+            resource_group_name, factory_name, pipeline_name, **config
+        )
+
+    @provide_targeted_factory
+    def get_pipeline_run(
+        self,
+        run_id: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> PipelineRun:
+        """
+        Get the pipeline run.
+
+        :param run_id: The pipeline run identifier.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :return: The pipeline run.
+        """
+        return self.get_conn().pipeline_runs.get(resource_group_name, factory_name, run_id, **config)
+
+    @provide_targeted_factory
+    def cancel_pipeline_run(
+        self,
+        run_id: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> None:
+        """
+        Cancel the pipeline run.
+
+        :param run_id: The pipeline run identifier.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        """
+        self.get_conn().pipeline_runs.cancel(resource_group_name, factory_name, run_id, **config)
+
+    @provide_targeted_factory
+    def get_trigger(
+        self,
+        trigger_name: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> TriggerResource:
+        """
+        Get the trigger.
+
+        :param trigger_name: The trigger name.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :return: The trigger.
+        """
+        return self.get_conn().triggers.get(resource_group_name, factory_name, trigger_name, **config)
+
+    def _trigger_exists(self, resource_group_name, factory_name, trigger_name) -> bool:
+        """Return whether or not the trigger already exists."""
+        triggers = {
+            trigger.name
+            for trigger in self.get_conn().triggers.list_by_factory(resource_group_name, factory_name)
+        }
+
+        return trigger_name in triggers
+
+    @provide_targeted_factory
+    def update_trigger(
+        self,
+        trigger_name: str,
+        trigger: Trigger,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> TriggerResource:
+        """
+        Update the trigger.
+
+        :param trigger_name: The trigger name.
+        :param trigger: The trigger resource definition.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :raise AirflowException: If the trigger does not exist.
+        :return: The trigger.
+        """
+        if not self._trigger_exists(resource_group_name, factory_name, trigger_name):
+            raise AirflowException(f"Trigger {trigger_name!r} does not exist.")
+
+        return self.get_conn().triggers.create_or_update(
+            resource_group_name, factory_name, trigger_name, trigger, **config
+        )
+
+    @provide_targeted_factory
+    def create_trigger(
+        self,
+        trigger_name: str,
+        trigger: Trigger,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> TriggerResource:
+        """
+        Create the trigger.
+
+        :param trigger_name: The trigger name.
+        :param trigger: The trigger resource definition.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :raise AirflowException: If the trigger already exists.
+        :return: The trigger.
+        """
+        if self._trigger_exists(resource_group_name, factory_name, trigger_name):
+            raise AirflowException(f"Trigger {trigger_name!r} already exists.")
+
+        return self.get_conn().triggers.create_or_update(
+            resource_group_name, factory_name, trigger_name, trigger, **config
+        )
+
+    @provide_targeted_factory
+    def delete_trigger(
+        self,
+        trigger_name: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> None:
+        """
+        Delete the trigger.
+
+        :param trigger_name: The trigger name.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        """
+        self.get_conn().triggers.delete(resource_group_name, factory_name, trigger_name, **config)
+
+    @provide_targeted_factory
+    def start_trigger(
+        self,
+        trigger_name: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> AzureOperationPoller:
+        """
+        Start the trigger.
+
+        :param trigger_name: The trigger name.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :return: An Azure operation poller.
+        """
+        return self.get_conn().triggers.start(resource_group_name, factory_name, trigger_name, **config)
+
+    @provide_targeted_factory
+    def stop_trigger(
+        self,
+        trigger_name: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> AzureOperationPoller:
+        """
+        Stop the trigger.
+
+        :param trigger_name: The trigger name.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        :return: An Azure operation poller.
+        """
+        return self.get_conn().triggers.stop(resource_group_name, factory_name, trigger_name, **config)
+
+    @provide_targeted_factory
+    def rerun_trigger(
+        self,
+        trigger_name: str,
+        run_id: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> None:
+        """
+        Rerun the trigger.
+
+        :param trigger_name: The trigger name.
+        :param run_id: The trigger run identifier.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        """
+        return self.get_conn().trigger_runs.rerun(
+            resource_group_name, factory_name, trigger_name, run_id, **config
+        )
+
+    @provide_targeted_factory
+    def cancel_trigger(
+        self,
+        trigger_name: str,
+        run_id: str,
+        resource_group_name: Optional[str] = None,
+        factory_name: Optional[str] = None,
+        **config: Any,
+    ) -> None:
+        """
+        Cancel the trigger run.
+
+        :param trigger_name: The trigger name.
+        :param run_id: The trigger run identifier.
+        :param resource_group_name: The resource group name.
+        :param factory_name: The factory name.
+        :param config: Extra parameters for the ADF client.
+        """
+        self.get_conn().trigger_runs.cancel(resource_group_name, factory_name, trigger_name, run_id, **config)
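
For orientation, here is a minimal usage sketch of the hook added above. It assumes an Airflow connection of type `azure_data_factory` (see the tests later in this commit for the expected extras); the connection id, resource group, factory, and pipeline names below are placeholders, not part of the change.

```python
from airflow.providers.microsoft.azure.hooks.azure_data_factory import AzureDataFactoryHook

# Placeholder connection id and Azure resource names, purely for illustration.
hook = AzureDataFactoryHook(conn_id="azure_data_factory_default")

# resource_group_name/factory_name may be omitted when the connection extras
# define "resourceGroup" and "factory" (provide_targeted_factory fills them in).
run = hook.run_pipeline(
    "copy_pipeline",
    resource_group_name="my-resource-group",
    factory_name="my-data-factory",
)

# The create-run response exposes the run id, which can then be polled.
pipeline_run = hook.get_pipeline_run(
    run.run_id,
    resource_group_name="my-resource-group",
    factory_name="my-data-factory",
)
print(pipeline_run.status)
```
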
diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml
index fa0d112..da7b330 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -54,6 +54,10 @@ integrations:
   - integration-name: Microsoft Azure FileShare
     external-doc-url: https://cloud.google.com/storage/
     tags: [azure]
+  - integration-name: Microsoft Azure Data Factory
+    external-doc-url: https://azure.microsoft.com/en-us/services/data-factory/
+    logo: /integration-logos/azure/Azure Data Factory.svg
+    tags: [azure]
   - integration-name: Microsoft Azure
     external-doc-url: https://azure.microsoft.com/
     tags: [azure]
@@ -113,6 +117,9 @@ hooks:
   - integration-name: Microsoft Azure Blob Storage
     python-modules:
       - airflow.providers.microsoft.azure.hooks.wasb
+  - integration-name: Microsoft Azure Data Factory
+    python-modules:
+      - airflow.providers.microsoft.azure.hooks.azure_data_factory
 
 transfers:
   - source-integration-name: Local
@@ -138,3 +145,4 @@ hook-class-names:
   - airflow.providers.microsoft.azure.hooks.azure_data_lake.AzureDataLakeHook
   - airflow.providers.microsoft.azure.hooks.azure_container_instance.AzureContainerInstanceHook
   - airflow.providers.microsoft.azure.hooks.wasb.WasbHook
+  - airflow.providers.microsoft.azure.hooks.azure_data_factory.AzureDataFactoryHook
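
A sketch of the connection shape this provider entry expects, mirroring the test fixtures later in this commit; all values are placeholders.

```python
import json

from airflow.models.connection import Connection
from airflow.utils import db

# "resourceGroup" and "factory" are optional defaults that the hook falls back
# to when a method is called without explicit resource_group_name/factory_name.
db.merge_conn(
    Connection(
        conn_id="azure_data_factory_default",
        conn_type="azure_data_factory",
        login="<client-id>",
        password="<client-secret>",
        extra=json.dumps(
            {
                "tenantId": "<tenant-id>",
                "subscriptionId": "<subscription-id>",
                "resourceGroup": "<resource-group>",
                "factory": "<factory-name>",
            }
        ),
    )
)
```
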
diff --git a/docs/integration-logos/azure/Azure Data Factory.svg b/docs/integration-logos/azure/Azure Data Factory.svg
new file mode 100644
index 0000000..481d3d4
--- /dev/null
+++ b/docs/integration-logos/azure/Azure Data Factory.svg	
@@ -0,0 +1 @@
+<svg id="f9ed9690-6753-43a7-8b32-d66ac7b8a99a" xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 18 18"><defs><linearGradient id="f710a364-083f-494c-9d96-89b92ee2d5a8" x1="0.5" y1="9.77" x2="9" y2="9.77" gradientUnits="userSpaceOnUse"><stop offset="0" stop-color="#005ba1" /><stop offset="0.07" stop-color="#0060a9" /><stop offset="0.36" stop-color="#0071c8" /><stop offset="0.52" stop-color="#0078d4" /><stop offset="0.64" stop-color="#0074cd" /><stop offset="0.81" stop [...]
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 0e89285..238021e 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1062,6 +1062,7 @@ png
 podName
 podSpec
 podspec
+poller
 polyfill
 postMessage
 postfix
diff --git a/setup.py b/setup.py
index 4ee7a5c..0846ec9 100644
--- a/setup.py
+++ b/setup.py
@@ -217,6 +217,7 @@ azure = [
     'azure-keyvault>=4.1.0',
     'azure-kusto-data>=0.0.43,<0.1',
     'azure-mgmt-containerinstance>=1.5.0,<2.0',
+    'azure-mgmt-datafactory>=0.13.0',
     'azure-mgmt-datalake-store>=0.5.0',
     'azure-mgmt-resource>=2.2.0',
     'azure-storage-blob>=12.7.0',
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
new file mode 100644
index 0000000..ea445ec
--- /dev/null
+++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
@@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=redefined-outer-name,unused-argument
+
+import json
+from unittest.mock import MagicMock, Mock
+
+import pytest
+from pytest import fixture
+
+from airflow.exceptions import AirflowException
+from airflow.models.connection import Connection
+from airflow.providers.microsoft.azure.hooks.azure_data_factory import (
+    AzureDataFactoryHook,
+    provide_targeted_factory,
+)
+from airflow.utils import db
+
+DEFAULT_RESOURCE_GROUP = "defaultResourceGroup"
+RESOURCE_GROUP = "testResourceGroup"
+
+DEFAULT_FACTORY = "defaultFactory"
+FACTORY = "testFactory"
+
+MODEL = object()
+NAME = "testName"
+ID = "testId"
+
+
+def setup_module():
+    connection = Connection(
+        conn_id="azure_data_factory_test",
+        conn_type="azure_data_factory",
+        login="clientId",
+        password="clientSecret",
+        extra=json.dumps(
+            {
+                "tenantId": "tenantId",
+                "subscriptionId": "subscriptionId",
+                "resourceGroup": DEFAULT_RESOURCE_GROUP,
+                "factory": DEFAULT_FACTORY,
+            }
+        ),
+    )
+
+    db.merge_conn(connection)
+
+
+@fixture
+def hook():
+    client = AzureDataFactoryHook(conn_id="azure_data_factory_test")
+    client._conn = MagicMock(
+        spec=[
+            "factories",
+            "linked_services",
+            "datasets",
+            "pipelines",
+            "pipeline_runs",
+            "triggers",
+            "trigger_runs",
+        ]
+    )
+
+    return client
+
+
+def parametrize(explicit_factory, implicit_factory):
+    def wrapper(func):
+        return pytest.mark.parametrize(
+            ("user_args", "sdk_args"),
+            (explicit_factory, implicit_factory),
+            ids=("explicit factory", "implicit factory"),
+        )(func)
+
+    return wrapper
+
+
+def test_provide_targeted_factory():
+    def echo(_, resource_group_name=None, factory_name=None):
+        return resource_group_name, factory_name
+
+    conn = MagicMock()
+    hook = MagicMock()
+    hook.get_connection.return_value = conn
+
+    conn.extra_dejson = {}
+    assert provide_targeted_factory(echo)(hook, RESOURCE_GROUP, FACTORY) == (RESOURCE_GROUP, FACTORY)
+
+    conn.extra_dejson = {"resourceGroup": DEFAULT_RESOURCE_GROUP, "factory": DEFAULT_FACTORY}
+    assert provide_targeted_factory(echo)(hook) == (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)
+
+    with pytest.raises(AirflowException):
+        conn.extra_dejson = {}
+        provide_targeted_factory(echo)(hook)
+
+
+@parametrize(
+    explicit_factory=((RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY)),
+    implicit_factory=((), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)),
+)
+def test_get_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.get_factory(*user_args)
+
+    hook._conn.factories.get.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, MODEL)),
+    implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)),
+)
+def test_create_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.create_factory(*user_args)
+
+    hook._conn.factories.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, MODEL)),
+    implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)),
+)
+def test_update_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook._factory_exists = Mock(return_value=True)
+    hook.update_factory(*user_args)
+
+    hook._conn.factories.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, MODEL)),
+    implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)),
+)
+def test_update_factory_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook._factory_exists = Mock(return_value=False)
+
+    with pytest.raises(AirflowException, match=r"Factory .+ does not exist"):
+        hook.update_factory(*user_args)
+
+
+@parametrize(
+    explicit_factory=((RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY)),
+    implicit_factory=((), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)),
+)
+def test_delete_factory(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.delete_factory(*user_args)
+
+    hook._conn.factories.delete.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_get_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.get_linked_service(*user_args)
+
+    hook._conn.linked_services.get.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_create_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.create_linked_service(*user_args)
+
+    hook._conn.linked_services.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook._linked_service_exists = Mock(return_value=True)
+    hook.update_linked_service(*user_args)
+
+    hook._conn.linked_services.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_linked_service_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook._linked_service_exists = Mock(return_value=False)
+
+    with pytest.raises(AirflowException, match=r"Linked service .+ does not exist"):
+        hook.update_linked_service(*user_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_delete_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.delete_linked_service(*user_args)
+
+    hook._conn.linked_services.delete.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_get_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.get_dataset(*user_args)
+
+    hook._conn.datasets.get.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_create_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.create_dataset(*user_args)
+
+    hook._conn.datasets.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook._dataset_exists = Mock(return_value=True)
+    hook.update_dataset(*user_args)
+
+    hook._conn.datasets.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_dataset_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook._dataset_exists = Mock(return_value=False)
+
+    with pytest.raises(AirflowException, match=r"Dataset .+ does not exist"):
+        hook.update_dataset(*user_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_delete_dataset(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.delete_dataset(*user_args)
+
+    hook._conn.datasets.delete.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_get_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.get_pipeline(*user_args)
+
+    hook._conn.pipelines.get.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_create_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.create_pipeline(*user_args)
+
+    hook._conn.pipelines.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook._pipeline_exists = Mock(return_value=True)
+    hook.update_pipeline(*user_args)
+
+    hook._conn.pipelines.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_pipeline_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook._pipeline_exists = Mock(return_value=False)
+
+    with pytest.raises(AirflowException, match=r"Pipeline .+ does not exist"):
+        hook.update_pipeline(*user_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_delete_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.delete_pipeline(*user_args)
+
+    hook._conn.pipelines.delete.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_run_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.run_pipeline(*user_args)
+
+    hook._conn.pipelines.create_run.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, ID)),
+    implicit_factory=((ID,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, ID)),
+)
+def test_get_pipeline_run(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.get_pipeline_run(*user_args)
+
+    hook._conn.pipeline_runs.get.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, ID)),
+    implicit_factory=((ID,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, ID)),
+)
+def test_cancel_pipeline_run(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.cancel_pipeline_run(*user_args)
+
+    hook._conn.pipeline_runs.cancel.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_get_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.get_trigger(*user_args)
+
+    hook._conn.triggers.get.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_create_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.create_trigger(*user_args)
+
+    hook._conn.triggers.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook._trigger_exists = Mock(return_value=True)
+    hook.update_trigger(*user_args)
+
+    hook._conn.triggers.create_or_update.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)),
+    implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)),
+)
+def test_update_trigger_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook._trigger_exists = Mock(return_value=False)
+
+    with pytest.raises(AirflowException, match=r"Trigger .+ does not exist"):
+        hook.update_trigger(*user_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_delete_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.delete_trigger(*user_args)
+
+    hook._conn.triggers.delete.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_start_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.start_trigger(*user_args)
+
+    hook._conn.triggers.start.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)),
+    implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)),
+)
+def test_stop_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.stop_trigger(*user_args)
+
+    hook._conn.triggers.stop.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, ID)),
+    implicit_factory=((NAME, ID), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, ID)),
+)
+def test_rerun_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.rerun_trigger(*user_args)
+
+    hook._conn.trigger_runs.rerun.assert_called_with(*sdk_args)
+
+
+@parametrize(
+    explicit_factory=((NAME, ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, ID)),
+    implicit_factory=((NAME, ID), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, ID)),
+)
+def test_cancel_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
+    hook.cancel_trigger(*user_args)
+
+    hook._conn.trigger_runs.cancel.assert_called_with(*sdk_args)
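
The "explicit factory"/"implicit factory" parametrization ids above describe the fallback tested for provide_targeted_factory: when the resource group and factory are omitted, they are read from the connection extras. In user code the two calls below are therefore equivalent for a connection whose extras define them (names are placeholders).

```python
from airflow.providers.microsoft.azure.hooks.azure_data_factory import AzureDataFactoryHook

hook = AzureDataFactoryHook(conn_id="azure_data_factory_default")

# Explicit targeting.
hook.get_pipeline(
    "copy_pipeline", resource_group_name="my-resource-group", factory_name="my-data-factory"
)

# Implicit targeting: falls back to the "resourceGroup"/"factory" connection extras.
hook.get_pipeline("copy_pipeline")
```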


[airflow] 11/28: Support google-cloud-redis>=2.0.0 (#13117)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3ef6d6f13a2dfe14c883e75b35c0874717c284cd
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Tue Dec 22 16:25:04 2020 +0100

    Support google-cloud-redis>=2.0.0 (#13117)
    
    (cherry picked from commit 0b626c8042b304a52d6c481fa6eb689d655f33d3)
---
 airflow/providers/google/ADDITIONAL_INFO.md        |  64 +++++++++
 .../example_dags/example_cloud_memorystore.py      |   4 +-
 .../google/cloud/hooks/cloud_memorystore.py        | 144 ++++++++++++++-------
 .../google/cloud/operators/cloud_memorystore.py    |  11 +-
 setup.py                                           |   2 +-
 .../google/cloud/hooks/test_cloud_memorystore.py   |  57 ++++----
 .../cloud/operators/test_cloud_memorystore.py      |   4 +-
 7 files changed, 208 insertions(+), 78 deletions(-)

diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md
new file mode 100644
index 0000000..b54b240
--- /dev/null
+++ b/airflow/providers/google/ADDITIONAL_INFO.md
@@ -0,0 +1,64 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied.  See the License for the
+ specific language governing permissions and limitations
+ under the License.
+ -->
+
+# Migration Guide
+
+## 2.0.0
+
+### Update ``google-cloud-*`` libraries
+
+This release of the provider package contains third-party library updates, which may require updating your DAG files or custom hooks and operators if you use objects from those libraries. Updating these libraries is necessary to use the new features they provide and to obtain bug fixes that are only available in the newer versions.
+
+Details are covered in the UPGRADING.md file of each library (linked in the table below), but a few points deserve particular attention.
+
+| Library name | Previous constraints | Current constraints | Upgrade guide |
+| --- | --- | --- | --- |
+| [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=1.0.0,<2.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
+| [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
+| [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
+| [``google-cloud-kms``](https://pypi.org/project/google-cloud-kms/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
+
+
+### The field names use the snake_case convention
+
+If your DAG uses an object from the above-mentioned libraries passed by XCom, you need to update the naming convention of the fields that are read. Previously the fields used the camelCase convention; now the snake_case convention is used.
+
+**Before:**
+
+```python
+set_acl_permission = GCSBucketCreateAclEntryOperator(
+    task_id="gcs-set-acl-permission",
+    bucket=BUCKET_NAME,
+    entity="user-{{ task_instance.xcom_pull('get-instance')['persistenceIamIdentity']"
+    ".split(':', 2)[1] }}",
+    role="OWNER",
+)
+```
+
+**After:**
+
+```python
+set_acl_permission = GCSBucketCreateAclEntryOperator(
+    task_id="gcs-set-acl-permission",
+    bucket=BUCKET_NAME,
+    entity="user-{{ task_instance.xcom_pull('get-instance')['persistence_iam_identity']"
+    ".split(':', 2)[1] }}",
+    role="OWNER",
+)
+```
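
The hook and operator changes that follow apply the google-cloud-redis 2.0 migration described above. As a rough sketch of the pattern (the project, location, and instance id are placeholders):

```python
# Enums that used to be imported from google.cloud.redis_v1.gapic.enums are now
# importable from the package root.
from google.cloud.redis_v1 import CloudRedisClient, FailoverInstanceRequest

client = CloudRedisClient()  # uses application-default credentials
name = "projects/<project>/locations/<location>/instances/<instance-id>"

# Per-field keyword arguments move into a single `request` mapping; the hook
# below also swaps the old path helpers for plain f-strings and passes empty
# metadata as an empty tuple.
instance = client.get_instance(request={"name": name}, metadata=())

# Nested enums are reached through the request class, e.g.
# FailoverInstanceRequest.DataProtectionMode.LIMITED_DATA_LOSS.
```
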
diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py b/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py
index 441c165..acb50b4 100644
--- a/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py
+++ b/airflow/providers/google/cloud/example_dags/example_cloud_memorystore.py
@@ -22,7 +22,7 @@ import os
 from urllib.parse import urlparse
 
 from google.cloud.memcache_v1beta2.types import cloud_memcache
-from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest, Instance
+from google.cloud.redis_v1 import FailoverInstanceRequest, Instance
 
 from airflow import models
 from airflow.operators.bash import BashOperator
@@ -161,7 +161,7 @@ with models.DAG(
     set_acl_permission = GCSBucketCreateAclEntryOperator(
         task_id="gcs-set-acl-permission",
         bucket=BUCKET_NAME,
-        entity="user-{{ task_instance.xcom_pull('get-instance')['persistenceIamIdentity']"
+        entity="user-{{ task_instance.xcom_pull('get-instance')['persistence_iam_identity']"
         ".split(':', 2)[1] }}",
         role="OWNER",
     )
diff --git a/airflow/providers/google/cloud/hooks/cloud_memorystore.py b/airflow/providers/google/cloud/hooks/cloud_memorystore.py
index bfc01f9..caf1cd6 100644
--- a/airflow/providers/google/cloud/hooks/cloud_memorystore.py
+++ b/airflow/providers/google/cloud/hooks/cloud_memorystore.py
@@ -23,10 +23,14 @@ from google.api_core.exceptions import NotFound
 from google.api_core.retry import Retry
 from google.cloud.memcache_v1beta2 import CloudMemcacheClient
 from google.cloud.memcache_v1beta2.types import cloud_memcache
-from google.cloud.redis_v1 import CloudRedisClient
-from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest
-from google.cloud.redis_v1.types import FieldMask, InputConfig, Instance, OutputConfig
-from google.protobuf.json_format import ParseDict
+from google.cloud.redis_v1 import (
+    CloudRedisClient,
+    FailoverInstanceRequest,
+    InputConfig,
+    Instance,
+    OutputConfig,
+)
+from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow import version
 from airflow.exceptions import AirflowException
@@ -70,7 +74,7 @@ class CloudMemorystoreHook(GoogleBaseHook):
         )
         self._client: Optional[CloudRedisClient] = None
 
-    def get_conn(self):
+    def get_conn(self) -> CloudRedisClient:
         """Retrieves client library object that allow access to Cloud Memorystore service."""
         if not self._client:
             self._client = CloudRedisClient(credentials=self._get_credentials())
@@ -143,35 +147,36 @@ class CloudMemorystoreHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        parent = CloudRedisClient.location_path(project_id, location)
-        instance_name = CloudRedisClient.instance_path(project_id, location, instance_id)
+        if isinstance(instance, dict):
+            instance = Instance(**instance)
+        elif not isinstance(instance, Instance):
+            raise AirflowException("instance is not instance of Instance type or python dict")
+
+        parent = f"projects/{project_id}/locations/{location}"
+        instance_name = f"projects/{project_id}/locations/{location}/instances/{instance_id}"
         try:
+            self.log.info("Fetching instance: %s", instance_name)
             instance = client.get_instance(
-                name=instance_name, retry=retry, timeout=timeout, metadata=metadata
+                request={'name': instance_name}, retry=retry, timeout=timeout, metadata=metadata or ()
             )
             self.log.info("Instance exists. Skipping creation.")
             return instance
         except NotFound:
             self.log.info("Instance not exists.")
 
-        if isinstance(instance, dict):
-            instance = ParseDict(instance, Instance())
-        elif not isinstance(instance, Instance):
-            raise AirflowException("instance is not instance of Instance type or python dict")
-
         self._append_label(instance, "airflow-version", "v" + version.version)
 
         result = client.create_instance(
-            parent=parent,
-            instance_id=instance_id,
-            instance=instance,
+            request={'parent': parent, 'instance_id': instance_id, 'instance': instance},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         result.result()
         self.log.info("Instance created.")
-        return client.get_instance(name=instance_name, retry=retry, timeout=timeout, metadata=metadata)
+        return client.get_instance(
+            request={'name': instance_name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
 
     @GoogleBaseHook.fallback_to_default_project_id
     def delete_instance(
@@ -203,15 +208,25 @@ class CloudMemorystoreHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = CloudRedisClient.instance_path(project_id, location, instance)
+        name = f"projects/{project_id}/locations/{location}/instances/{instance}"
         self.log.info("Fetching Instance: %s", name)
-        instance = client.get_instance(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        instance = client.get_instance(
+            request={'name': name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
+        )
 
         if not instance:
             return
 
         self.log.info("Deleting Instance: %s", name)
-        result = client.delete_instance(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        result = client.delete_instance(
+            request={'name': name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
+        )
         result.result()
         self.log.info("Instance deleted: %s", name)
 
@@ -253,10 +268,13 @@ class CloudMemorystoreHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = CloudRedisClient.instance_path(project_id, location, instance)
+        name = f"projects/{project_id}/locations/{location}/instances/{instance}"
         self.log.info("Exporting Instance: %s", name)
         result = client.export_instance(
-            name=name, output_config=output_config, retry=retry, timeout=timeout, metadata=metadata
+            request={'name': name, 'output_config': output_config},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
         )
         result.result()
         self.log.info("Instance exported: %s", name)
@@ -297,15 +315,14 @@ class CloudMemorystoreHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = CloudRedisClient.instance_path(project_id, location, instance)
+        name = f"projects/{project_id}/locations/{location}/instances/{instance}"
         self.log.info("Failovering Instance: %s", name)
 
         result = client.failover_instance(
-            name=name,
-            data_protection_mode=data_protection_mode,
+            request={'name': name, 'data_protection_mode': data_protection_mode},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         result.result()
         self.log.info("Instance failovered: %s", name)
@@ -340,8 +357,13 @@ class CloudMemorystoreHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = CloudRedisClient.instance_path(project_id, location, instance)
-        result = client.get_instance(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        name = f"projects/{project_id}/locations/{location}/instances/{instance}"
+        result = client.get_instance(
+            request={'name': name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
+        )
         self.log.info("Fetched Instance: %s", name)
         return result
 
@@ -384,10 +406,13 @@ class CloudMemorystoreHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        name = CloudRedisClient.instance_path(project_id, location, instance)
+        name = f"projects/{project_id}/locations/{location}/instances/{instance}"
         self.log.info("Importing Instance: %s", name)
         result = client.import_instance(
-            name=name, input_config=input_config, retry=retry, timeout=timeout, metadata=metadata
+            request={'name': name, 'input_config': input_config},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
         )
         result.result()
         self.log.info("Instance imported: %s", name)
@@ -428,9 +453,12 @@ class CloudMemorystoreHook(GoogleBaseHook):
         :type metadata: Sequence[Tuple[str, str]]
         """
         client = self.get_conn()
-        parent = CloudRedisClient.location_path(project_id, location)
+        parent = f"projects/{project_id}/locations/{location}"
         result = client.list_instances(
-            parent=parent, page_size=page_size, retry=retry, timeout=timeout, metadata=metadata
+            request={'parent': parent, 'page_size': page_size},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
         )
         self.log.info("Fetched instances")
         return result
@@ -485,17 +513,20 @@ class CloudMemorystoreHook(GoogleBaseHook):
         client = self.get_conn()
 
         if isinstance(instance, dict):
-            instance = ParseDict(instance, Instance())
+            instance = Instance(**instance)
         elif not isinstance(instance, Instance):
             raise AirflowException("instance is not instance of Instance type or python dict")
 
         if location and instance_id:
-            name = CloudRedisClient.instance_path(project_id, location, instance_id)
+            name = f"projects/{project_id}/locations/{location}/instances/{instance_id}"
             instance.name = name
 
         self.log.info("Updating instances: %s", instance.name)
         result = client.update_instance(
-            update_mask=update_mask, instance=instance, retry=retry, timeout=timeout, metadata=metadata
+            request={'update_mask': update_mask, 'instance': instance},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
         )
         result.result()
         self.log.info("Instance updated: %s", instance.name)
@@ -610,7 +641,12 @@ class CloudMemorystoreMemcachedHook(GoogleBaseHook):
 
         self.log.info("Applying update to instance: %s", instance_id)
         result = client.apply_parameters(
-            name=name, node_ids=node_ids, apply_all=apply_all, retry=retry, timeout=timeout, metadata=metadata
+            name=name,
+            node_ids=node_ids,
+            apply_all=apply_all,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
         )
         result.result()
         self.log.info("Instance updated: %s", instance_id)
@@ -688,11 +724,16 @@ class CloudMemorystoreMemcachedHook(GoogleBaseHook):
             resource=instance,
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         result.result()
         self.log.info("Instance created.")
-        return client.get_instance(name=instance_name, retry=retry, timeout=timeout, metadata=metadata)
+        return client.get_instance(
+            name=instance_name,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
+        )
 
     @GoogleBaseHook.fallback_to_default_project_id
     def delete_instance(
@@ -727,13 +768,23 @@ class CloudMemorystoreMemcachedHook(GoogleBaseHook):
         metadata = metadata or ()
         name = CloudMemcacheClient.instance_path(project_id, location, instance)
         self.log.info("Fetching Instance: %s", name)
-        instance = client.get_instance(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        instance = client.get_instance(
+            name=name,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
+        )
 
         if not instance:
             return
 
         self.log.info("Deleting Instance: %s", name)
-        result = client.delete_instance(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        result = client.delete_instance(
+            name=name,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
+        )
         result.result()
         self.log.info("Instance deleted: %s", name)
 
@@ -808,7 +859,12 @@ class CloudMemorystoreMemcachedHook(GoogleBaseHook):
         parent = path_template.expand(
             "projects/{project}/locations/{location}", project=project_id, location=location
         )
-        result = client.list_instances(parent=parent, retry=retry, timeout=timeout, metadata=metadata)
+        result = client.list_instances(
+            parent=parent,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata or (),
+        )
         self.log.info("Fetched instances")
         return result
 
@@ -871,7 +927,7 @@ class CloudMemorystoreMemcachedHook(GoogleBaseHook):
 
         self.log.info("Updating instances: %s", instance.name)
         result = client.update_instance(
-            update_mask=update_mask, resource=instance, retry=retry, timeout=timeout, metadata=metadata
+            update_mask=update_mask, resource=instance, retry=retry, timeout=timeout, metadata=metadata or ()
         )
         result.result()
         self.log.info("Instance updated: %s", instance.name)
@@ -934,7 +990,7 @@ class CloudMemorystoreMemcachedHook(GoogleBaseHook):
             parameters=parameters,
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         result.result()
         self.log.info("Update staged for instance: %s", instance_id)
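
The operator diff that follows drops google.protobuf.json_format.MessageToDict in favour of the proto-plus Instance.to_dict helper, which is also what switches the XCom payload keys to snake_case. A minimal illustration (the instance name is a placeholder):

```python
from google.cloud.redis_v1 import Instance

instance = Instance(name="projects/p/locations/l/instances/i")

# proto-plus messages serialize via the class method; keys come out snake_case,
# e.g. "persistence_iam_identity" rather than "persistenceIamIdentity".
as_dict = Instance.to_dict(instance)
print(as_dict["name"])
```
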
diff --git a/airflow/providers/google/cloud/operators/cloud_memorystore.py b/airflow/providers/google/cloud/operators/cloud_memorystore.py
index 0ac2640..64a6251 100644
--- a/airflow/providers/google/cloud/operators/cloud_memorystore.py
+++ b/airflow/providers/google/cloud/operators/cloud_memorystore.py
@@ -20,9 +20,8 @@ from typing import Dict, Optional, Sequence, Tuple, Union
 
 from google.api_core.retry import Retry
 from google.cloud.memcache_v1beta2.types import cloud_memcache
-from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest
-from google.cloud.redis_v1.types import FieldMask, InputConfig, Instance, OutputConfig
-from google.protobuf.json_format import MessageToDict
+from google.cloud.redis_v1 import FailoverInstanceRequest, InputConfig, Instance, OutputConfig
+from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.cloud_memorystore import (
@@ -134,7 +133,7 @@ class CloudMemorystoreCreateInstanceOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(result)
+        return Instance.to_dict(result)
 
 
 class CloudMemorystoreDeleteInstanceOperator(BaseOperator):
@@ -492,7 +491,7 @@ class CloudMemorystoreGetInstanceOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(result)
+        return Instance.to_dict(result)
 
 
 class CloudMemorystoreImportOperator(BaseOperator):
@@ -677,7 +676,7 @@ class CloudMemorystoreListInstancesOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        instances = [MessageToDict(a) for a in result]
+        instances = [Instance.to_dict(a) for a in result]
         return instances
 
 
diff --git a/setup.py b/setup.py
index ff9fd71..ae18e57 100644
--- a/setup.py
+++ b/setup.py
@@ -297,7 +297,7 @@ google = [
     'google-cloud-monitoring>=0.34.0,<2.0.0',
     'google-cloud-os-login>=2.0.0,<3.0.0',
     'google-cloud-pubsub>=2.0.0,<3.0.0',
-    'google-cloud-redis>=0.3.0,<2.0.0',
+    'google-cloud-redis>=2.0.0,<3.0.0',
     'google-cloud-secret-manager>=0.2.0,<2.0.0',
     'google-cloud-spanner>=1.10.0,<2.0.0',
     'google-cloud-speech>=0.36.3,<2.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_cloud_memorystore.py b/tests/providers/google/cloud/hooks/test_cloud_memorystore.py
index 40de3b8..9e6f442 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_memorystore.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_memorystore.py
@@ -85,7 +85,10 @@ class TestCloudMemorystoreWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.get_instance.assert_called_once_with(
-            name=TEST_NAME_DEFAULT_PROJECT_ID, retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA
+            request=dict(name=TEST_NAME_DEFAULT_PROJECT_ID),
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
         )
         assert Instance(name=TEST_NAME) == result
 
@@ -116,13 +119,15 @@ class TestCloudMemorystoreWithDefaultProjectIdHook(TestCase):
             ]
         )
         mock_get_conn.return_value.create_instance.assert_called_once_with(
-            instance=Instance(
-                name=TEST_NAME,
-                labels={"airflow-version": "v" + version.version.replace(".", "-").replace("+", "-")},
+            request=dict(
+                parent=TEST_PARENT_DEFAULT_PROJECT_ID,
+                instance=Instance(
+                    name=TEST_NAME,
+                    labels={"airflow-version": "v" + version.version.replace(".", "-").replace("+", "-")},
+                ),
+                instance_id=TEST_INSTANCE_ID,
             ),
-            instance_id=TEST_INSTANCE_ID,
             metadata=TEST_METADATA,
-            parent=TEST_PARENT_DEFAULT_PROJECT_ID,
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
         )
@@ -143,7 +148,10 @@ class TestCloudMemorystoreWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.delete_instance.assert_called_once_with(
-            name=TEST_NAME_DEFAULT_PROJECT_ID, retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA
+            request=dict(name=TEST_NAME_DEFAULT_PROJECT_ID),
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
         )
 
     @mock.patch(
@@ -161,7 +169,10 @@ class TestCloudMemorystoreWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.get_instance.assert_called_once_with(
-            name=TEST_NAME_DEFAULT_PROJECT_ID, retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA
+            request=dict(name=TEST_NAME_DEFAULT_PROJECT_ID),
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
         )
 
     @mock.patch(
@@ -179,8 +190,7 @@ class TestCloudMemorystoreWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.list_instances.assert_called_once_with(
-            parent=TEST_PARENT_DEFAULT_PROJECT_ID,
-            page_size=TEST_PAGE_SIZE,
+            request=dict(parent=TEST_PARENT_DEFAULT_PROJECT_ID, page_size=TEST_PAGE_SIZE),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -203,8 +213,7 @@ class TestCloudMemorystoreWithDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.update_instance.assert_called_once_with(
-            update_mask=TEST_UPDATE_MASK,
-            instance=Instance(name=TEST_NAME_DEFAULT_PROJECT_ID),
+            request=dict(update_mask=TEST_UPDATE_MASK, instance=Instance(name=TEST_NAME_DEFAULT_PROJECT_ID)),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -234,7 +243,7 @@ class TestCloudMemorystoreWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.get_instance.assert_called_once_with(
-            name="projects/test-project-id/locations/test-location/instances/test-instance-id",
+            request=dict(name="projects/test-project-id/locations/test-location/instances/test-instance-id"),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -275,13 +284,15 @@ class TestCloudMemorystoreWithoutDefaultProjectIdHook(TestCase):
         )
 
         mock_get_conn.return_value.create_instance.assert_called_once_with(
-            instance=Instance(
-                name=TEST_NAME,
-                labels={"airflow-version": "v" + version.version.replace(".", "-").replace("+", "-")},
+            request=dict(
+                parent=TEST_PARENT,
+                instance=Instance(
+                    name=TEST_NAME,
+                    labels={"airflow-version": "v" + version.version.replace(".", "-").replace("+", "-")},
+                ),
+                instance_id=TEST_INSTANCE_ID,
             ),
-            instance_id=TEST_INSTANCE_ID,
             metadata=TEST_METADATA,
-            parent=TEST_PARENT,
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
         )
@@ -316,7 +327,7 @@ class TestCloudMemorystoreWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.delete_instance.assert_called_once_with(
-            name=TEST_NAME, retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA
+            request=dict(name=TEST_NAME), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA
         )
 
     @mock.patch(
@@ -347,7 +358,7 @@ class TestCloudMemorystoreWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.get_instance.assert_called_once_with(
-            name=TEST_NAME, retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA
+            request=dict(name=TEST_NAME), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA
         )
 
     @mock.patch(
@@ -378,8 +389,7 @@ class TestCloudMemorystoreWithoutDefaultProjectIdHook(TestCase):
             metadata=TEST_METADATA,
         )
         mock_get_conn.return_value.list_instances.assert_called_once_with(
-            parent=TEST_PARENT,
-            page_size=TEST_PAGE_SIZE,
+            request=dict(parent=TEST_PARENT, page_size=TEST_PAGE_SIZE),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
@@ -413,8 +423,7 @@ class TestCloudMemorystoreWithoutDefaultProjectIdHook(TestCase):
             project_id=TEST_PROJECT_ID,
         )
         mock_get_conn.return_value.update_instance.assert_called_once_with(
-            update_mask=TEST_UPDATE_MASK,
-            instance=Instance(name=TEST_NAME),
+            request=dict(update_mask={'paths': ['memory_size_gb']}, instance=Instance(name=TEST_NAME)),
             retry=TEST_RETRY,
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
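 
The assertions above reflect the request-style call signature of the 2.x client: RPC parameters are grouped into a single request mapping, while retry, timeout, and metadata remain top-level keyword arguments. A minimal sketch, assuming valid Google Cloud credentials and an existing instance (the project/location/instance IDs are placeholders):

    from google.cloud.redis_v1 import CloudRedisClient

    client = CloudRedisClient()
    instance = client.get_instance(
        request={"name": "projects/p/locations/l/instances/i"},
        timeout=30.0,
    )
    print(instance.memory_size_gb)
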
diff --git a/tests/providers/google/cloud/operators/test_cloud_memorystore.py b/tests/providers/google/cloud/operators/test_cloud_memorystore.py
index 8ef60bd..6db8a3a 100644
--- a/tests/providers/google/cloud/operators/test_cloud_memorystore.py
+++ b/tests/providers/google/cloud/operators/test_cloud_memorystore.py
@@ -20,7 +20,7 @@ from unittest import TestCase, mock
 
 from google.api_core.retry import Retry
 from google.cloud.memcache_v1beta2.types import cloud_memcache
-from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest
+from google.cloud.redis_v1 import FailoverInstanceRequest
 from google.cloud.redis_v1.types import Instance
 
 from airflow.providers.google.cloud.operators.cloud_memorystore import (
@@ -78,6 +78,7 @@ class TestCloudMemorystoreCreateInstanceOperator(TestCase):
             gcp_conn_id=TEST_GCP_CONN_ID,
             impersonation_chain=TEST_IMPERSONATION_CHAIN,
         )
+        mock_hook.return_value.create_instance.return_value = Instance(name=TEST_NAME)
         task.execute(mock.MagicMock())
         mock_hook.assert_called_once_with(
             gcp_conn_id=TEST_GCP_CONN_ID,
@@ -199,6 +200,7 @@ class TestCloudMemorystoreGetInstanceOperator(TestCase):
             gcp_conn_id=TEST_GCP_CONN_ID,
             impersonation_chain=TEST_IMPERSONATION_CHAIN,
         )
+        mock_hook.return_value.get_instance.return_value = Instance(name=TEST_NAME)
         task.execute(mock.MagicMock())
         mock_hook.assert_called_once_with(
             gcp_conn_id=TEST_GCP_CONN_ID,

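The two added mock return values exist because the operators now pass the hook result through Instance.to_dict(), which expects a real proto-plus Instance rather than a bare MagicMock. A self-contained sketch of that pattern (placeholder values, no network access needed):

    from unittest import mock

    from google.cloud.redis_v1 import Instance

    hook = mock.MagicMock()
    hook.get_instance.return_value = Instance(name="test-name")

    # What the operator effectively does with the result before returning it:
    result = Instance.to_dict(hook.get_instance())
    assert result["name"] == "test-name"
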

[airflow] 14/28: Salesforce provider requires tableau (#13593)

Posted by po...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 56cc293debe91977605273c0ea666660032c29af
Author: Daniel Standish <ds...@pax.com>
AuthorDate: Sun Jan 10 02:20:34 2021 -0800

    Salesforce provider requires tableau (#13593)
    
    Co-authored-by: Daniel Standish <ds...@users.noreply.github.com>
    (cherry picked from commit 46edea3411498a4c2e1d8840ba0dcd93daa1e25f)
---
 setup.py | 1 +
 1 file changed, 1 insertion(+)
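 
For background (not part of the commit): at this point in the branch the Tableau hook still ships inside the Salesforce provider, so the salesforce extra has to pull in tableauserverclient. Illustrative import only; the module path is assumed from that pre-split layout.

    # Fails with ImportError if tableauserverclient is not installed.
    from airflow.providers.salesforce.hooks.tableau import TableauHook  # noqa: F401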

diff --git a/setup.py b/setup.py
index 628ecd1..75f5db5 100644
--- a/setup.py
+++ b/setup.py
@@ -403,6 +403,7 @@ redis = [
 ]
 salesforce = [
     'simple-salesforce>=1.0.0',
+    'tableauserverclient',
 ]
 samba = [
     'pysmbclient>=0.1.3',