You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/03/03 09:32:50 UTC
[airflow] 26/41: Support google-cloud-datacatalog>=3.0.0 (#13534)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 02cb5e1af6c4d1b2823729d3f2801fd9d05bdf43
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Mon Jan 11 09:39:19 2021 +0100
Support google-cloud-datacatalog>=3.0.0 (#13534)
(cherry picked from commit 947dbb73bba736eb146f33117545a18fc2fd3c09)
---
airflow/providers/google/ADDITIONAL_INFO.md | 2 +-
.../cloud/example_dags/example_datacatalog.py | 10 +-
.../providers/google/cloud/hooks/datacatalog.py | 220 ++++++++++++-------
.../google/cloud/operators/datacatalog.py | 47 ++--
setup.py | 2 +-
.../google/cloud/hooks/test_datacatalog.py | 237 +++++++++++++--------
.../google/cloud/operators/test_datacatalog.py | 49 +++--
7 files changed, 357 insertions(+), 210 deletions(-)
diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md
index eca05df..d80f9e1 100644
--- a/airflow/providers/google/ADDITIONAL_INFO.md
+++ b/airflow/providers/google/ADDITIONAL_INFO.md
@@ -30,7 +30,7 @@ Details are covered in the UPDATING.md files for each library, but there are som
| Library name | Previous constraints | Current constraints | |
| --- | --- | --- | --- |
| [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
-| [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=1.0.0,<2.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
+| [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
| [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
| [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
| [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
diff --git a/airflow/providers/google/cloud/example_dags/example_datacatalog.py b/airflow/providers/google/cloud/example_dags/example_datacatalog.py
index c8597a6..cc4b73a 100644
--- a/airflow/providers/google/cloud/example_dags/example_datacatalog.py
+++ b/airflow/providers/google/cloud/example_dags/example_datacatalog.py
@@ -19,7 +19,7 @@
"""
Example Airflow DAG that interacts with Google Data Catalog service
"""
-from google.cloud.datacatalog_v1beta1.proto.tags_pb2 import FieldType, TagField, TagTemplateField
+from google.cloud.datacatalog_v1beta1 import FieldType, TagField, TagTemplateField
from airflow import models
from airflow.operators.bash_operator import BashOperator
@@ -91,7 +91,7 @@ with models.DAG("example_gcp_datacatalog", start_date=days_ago(1), schedule_inte
entry_id=ENTRY_ID,
entry={
"display_name": "Wizard",
- "type": "FILESET",
+ "type_": "FILESET",
"gcs_fileset_spec": {"file_patterns": ["gs://test-datacatalog/**"]},
},
)
@@ -144,7 +144,7 @@ with models.DAG("example_gcp_datacatalog", start_date=days_ago(1), schedule_inte
"display_name": "Awesome Tag Template",
"fields": {
FIELD_NAME_1: TagTemplateField(
- display_name="first-field", type=FieldType(primitive_type="STRING")
+ display_name="first-field", type_=dict(primitive_type="STRING")
)
},
},
@@ -172,7 +172,7 @@ with models.DAG("example_gcp_datacatalog", start_date=days_ago(1), schedule_inte
tag_template=TEMPLATE_ID,
tag_template_field_id=FIELD_NAME_2,
tag_template_field=TagTemplateField(
- display_name="second-field", type=FieldType(primitive_type="STRING")
+ display_name="second-field", type_=FieldType(primitive_type="STRING")
),
)
# [END howto_operator_gcp_datacatalog_create_tag_template_field]
@@ -305,7 +305,7 @@ with models.DAG("example_gcp_datacatalog", start_date=days_ago(1), schedule_inte
# [START howto_operator_gcp_datacatalog_lookup_entry_result]
lookup_entry_result = BashOperator(
task_id="lookup_entry_result",
- bash_command="echo \"{{ task_instance.xcom_pull('lookup_entry')['displayName'] }}\"",
+ bash_command="echo \"{{ task_instance.xcom_pull('lookup_entry')['display_name'] }}\"",
)
# [END howto_operator_gcp_datacatalog_lookup_entry_result]
diff --git a/airflow/providers/google/cloud/hooks/datacatalog.py b/airflow/providers/google/cloud/hooks/datacatalog.py
index 70b488d..0d6cc75 100644
--- a/airflow/providers/google/cloud/hooks/datacatalog.py
+++ b/airflow/providers/google/cloud/hooks/datacatalog.py
@@ -18,16 +18,18 @@
from typing import Dict, Optional, Sequence, Tuple, Union
from google.api_core.retry import Retry
-from google.cloud.datacatalog_v1beta1 import DataCatalogClient
-from google.cloud.datacatalog_v1beta1.types import (
+from google.cloud import datacatalog
+from google.cloud.datacatalog_v1beta1 import (
+ CreateTagRequest,
+ DataCatalogClient,
Entry,
EntryGroup,
- FieldMask,
SearchCatalogRequest,
Tag,
TagTemplate,
TagTemplateField,
)
+from google.protobuf.field_mask_pb2 import FieldMask
from airflow import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
@@ -115,10 +117,13 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- parent = DataCatalogClient.entry_group_path(project_id, location, entry_group)
+ parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}"
self.log.info('Creating a new entry: parent=%s', parent)
result = client.create_entry(
- parent=parent, entry_id=entry_id, entry=entry, retry=retry, timeout=timeout, metadata=metadata
+ request={'parent': parent, 'entry_id': entry_id, 'entry': entry},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata or (),
)
self.log.info('Created a entry: name=%s', result.name)
return result
@@ -161,16 +166,14 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- parent = DataCatalogClient.location_path(project_id, location)
+ parent = f"projects/{project_id}/locations/{location}"
self.log.info('Creating a new entry group: parent=%s', parent)
result = client.create_entry_group(
- parent=parent,
- entry_group_id=entry_group_id,
- entry_group=entry_group,
+ request={'parent': parent, 'entry_group_id': entry_group_id, 'entry_group': entry_group},
retry=retry,
timeout=timeout,
- metadata=metadata,
+ metadata=metadata or (),
)
self.log.info('Created a entry group: name=%s', result.name)
@@ -218,15 +221,34 @@ class CloudDataCatalogHook(GoogleBaseHook):
"""
client = self.get_conn()
if template_id:
- template_path = DataCatalogClient.tag_template_path(project_id, location, template_id)
+ template_path = f"projects/{project_id}/locations/{location}/tagTemplates/{template_id}"
if isinstance(tag, Tag):
tag.template = template_path
else:
tag["template"] = template_path
- parent = DataCatalogClient.entry_path(project_id, location, entry_group, entry)
+ parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}"
self.log.info('Creating a new tag: parent=%s', parent)
- result = client.create_tag(parent=parent, tag=tag, retry=retry, timeout=timeout, metadata=metadata)
+ # HACK: google-cloud-datacatalog has problems with mapping messages where the value is not a
+ # primitive type, so we need to convert it manually.
+ # See: https://github.com/googleapis/python-datacatalog/issues/84
+ if isinstance(tag, dict):
+ tag = Tag(
+ name=tag.get('name'),
+ template=tag.get('template'),
+ template_display_name=tag.get('template_display_name'),
+ column=tag.get('column'),
+ fields={
+ k: datacatalog.TagField(**v) if isinstance(v, dict) else v
+ for k, v in tag.get("fields", {}).items()
+ },
+ )
+ request = CreateTagRequest(
+ parent=parent,
+ tag=tag,
+ )
+
+ result = client.create_tag(request=request, retry=retry, timeout=timeout, metadata=metadata or ())
self.log.info('Created a tag: name=%s', result.name)
return result
@@ -267,17 +289,30 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- parent = DataCatalogClient.location_path(project_id, location)
+ parent = f"projects/{project_id}/locations/{location}"
self.log.info('Creating a new tag template: parent=%s', parent)
+ # HACK: google-cloud-datacatalog has problems with mapping messages where the value is not a
+ # primitive type, so we need to convert it manually.
+ # See: https://github.com/googleapis/python-datacatalog/issues/84
+ if isinstance(tag_template, dict):
+ tag_template = datacatalog.TagTemplate(
+ name=tag_template.get("name"),
+ display_name=tag_template.get("display_name"),
+ fields={
+ k: datacatalog.TagTemplateField(**v) if isinstance(v, dict) else v
+ for k, v in tag_template.get("fields", {}).items()
+ },
+ )
+ request = datacatalog.CreateTagTemplateRequest(
+ parent=parent, tag_template_id=tag_template_id, tag_template=tag_template
+ )
result = client.create_tag_template(
- parent=parent,
- tag_template_id=tag_template_id,
- tag_template=tag_template,
+ request=request,
retry=retry,
timeout=timeout,
- metadata=metadata,
+ metadata=metadata or (),
)
self.log.info('Created a tag template: name=%s', result.name)
@@ -325,17 +360,19 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- parent = DataCatalogClient.tag_template_path(project_id, location, tag_template)
+ parent = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}"
self.log.info('Creating a new tag template field: parent=%s', parent)
result = client.create_tag_template_field(
- parent=parent,
- tag_template_field_id=tag_template_field_id,
- tag_template_field=tag_template_field,
+ request={
+ 'parent': parent,
+ 'tag_template_field_id': tag_template_field_id,
+ 'tag_template_field': tag_template_field,
+ },
retry=retry,
timeout=timeout,
- metadata=metadata,
+ metadata=metadata or (),
)
self.log.info('Created a tag template field: name=%s', result.name)
@@ -375,9 +412,9 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- name = DataCatalogClient.entry_path(project_id, location, entry_group, entry)
+ name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}"
self.log.info('Deleting a entry: name=%s', name)
- client.delete_entry(name=name, retry=retry, timeout=timeout, metadata=metadata)
+ client.delete_entry(request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ())
self.log.info('Deleted a entry: name=%s', name)
@GoogleBaseHook.fallback_to_default_project_id
@@ -412,10 +449,12 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- name = DataCatalogClient.entry_group_path(project_id, location, entry_group)
+ name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}"
self.log.info('Deleting a entry group: name=%s', name)
- client.delete_entry_group(name=name, retry=retry, timeout=timeout, metadata=metadata)
+ client.delete_entry_group(
+ request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+ )
self.log.info('Deleted a entry group: name=%s', name)
@GoogleBaseHook.fallback_to_default_project_id
@@ -454,10 +493,12 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- name = DataCatalogClient.tag_path(project_id, location, entry_group, entry, tag)
+ name = (
+ f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}/tags/{tag}"
+ )
self.log.info('Deleting a tag: name=%s', name)
- client.delete_tag(name=name, retry=retry, timeout=timeout, metadata=metadata)
+ client.delete_tag(request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ())
self.log.info('Deleted a tag: name=%s', name)
@GoogleBaseHook.fallback_to_default_project_id
@@ -495,10 +536,12 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- name = DataCatalogClient.tag_template_path(project_id, location, tag_template)
+ name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}"
self.log.info('Deleting a tag template: name=%s', name)
- client.delete_tag_template(name=name, force=force, retry=retry, timeout=timeout, metadata=metadata)
+ client.delete_tag_template(
+ request={'name': name, 'force': force}, retry=retry, timeout=timeout, metadata=metadata or ()
+ )
self.log.info('Deleted a tag template: name=%s', name)
@GoogleBaseHook.fallback_to_default_project_id
@@ -537,11 +580,11 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- name = DataCatalogClient.tag_template_field_path(project_id, location, tag_template, field)
+ name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}/fields/{field}"
self.log.info('Deleting a tag template field: name=%s', name)
client.delete_tag_template_field(
- name=name, force=force, retry=retry, timeout=timeout, metadata=metadata
+ request={'name': name, 'force': force}, retry=retry, timeout=timeout, metadata=metadata or ()
)
self.log.info('Deleted a tag template field: name=%s', name)
@@ -578,10 +621,12 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- name = DataCatalogClient.entry_path(project_id, location, entry_group, entry)
+ name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}"
self.log.info('Getting a entry: name=%s', name)
- result = client.get_entry(name=name, retry=retry, timeout=timeout, metadata=metadata)
+ result = client.get_entry(
+ request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+ )
self.log.info('Received a entry: name=%s', result.name)
return result
@@ -607,8 +652,8 @@ class CloudDataCatalogHook(GoogleBaseHook):
:param read_mask: The fields to return. If not set or empty, all fields are returned.
If a dict is provided, it must be of the same form as the protobuf message
- :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
- :type read_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+ :class:`~google.protobuf.field_mask_pb2.FieldMask`
+ :type read_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
:param project_id: The ID of the Google Cloud project that owns the entry group.
If set to ``None`` or missing, the default project_id from the Google Cloud connection is used.
:type project_id: str
@@ -622,12 +667,15 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- name = DataCatalogClient.entry_group_path(project_id, location, entry_group)
+ name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}"
self.log.info('Getting a entry group: name=%s', name)
result = client.get_entry_group(
- name=name, read_mask=read_mask, retry=retry, timeout=timeout, metadata=metadata
+ request={'name': name, 'read_mask': read_mask},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata or (),
)
self.log.info('Received a entry group: name=%s', result.name)
@@ -664,11 +712,13 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- name = DataCatalogClient.tag_template_path(project_id, location, tag_template)
+ name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}"
self.log.info('Getting a tag template: name=%s', name)
- result = client.get_tag_template(name=name, retry=retry, timeout=timeout, metadata=metadata)
+ result = client.get_tag_template(
+ request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+ )
self.log.info('Received a tag template: name=%s', result.name)
@@ -712,12 +762,15 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- parent = DataCatalogClient.entry_path(project_id, location, entry_group, entry)
+ parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}"
self.log.info('Listing tag on entry: entry_name=%s', parent)
result = client.list_tags(
- parent=parent, page_size=page_size, retry=retry, timeout=timeout, metadata=metadata
+ request={'parent': parent, 'page_size': page_size},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata or (),
)
self.log.info('Received tags.')
@@ -811,12 +864,18 @@ class CloudDataCatalogHook(GoogleBaseHook):
if linked_resource:
self.log.info('Getting entry: linked_resource=%s', linked_resource)
result = client.lookup_entry(
- linked_resource=linked_resource, retry=retry, timeout=timeout, metadata=metadata
+ request={'linked_resource': linked_resource},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata or (),
)
else:
self.log.info('Getting entry: sql_resource=%s', sql_resource)
result = client.lookup_entry(
- sql_resource=sql_resource, retry=retry, timeout=timeout, metadata=metadata
+ request={'sql_resource': sql_resource},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata or (),
)
self.log.info('Received entry. name=%s', result.name)
@@ -860,18 +919,17 @@ class CloudDataCatalogHook(GoogleBaseHook):
:type metadata: Sequence[Tuple[str, str]]
"""
client = self.get_conn()
- name = DataCatalogClient.tag_template_field_path(project_id, location, tag_template, field)
+ name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}/fields/{field}"
self.log.info(
'Renaming field: old_name=%s, new_tag_template_field_id=%s', name, new_tag_template_field_id
)
result = client.rename_tag_template_field(
- name=name,
- new_tag_template_field_id=new_tag_template_field_id,
+ request={'name': name, 'new_tag_template_field_id': new_tag_template_field_id},
retry=retry,
timeout=timeout,
- metadata=metadata,
+ metadata=metadata or (),
)
self.log.info('Renamed tag template field.')
@@ -946,13 +1004,10 @@ class CloudDataCatalogHook(GoogleBaseHook):
order_by,
)
result = client.search_catalog(
- scope=scope,
- query=query,
- page_size=page_size,
- order_by=order_by,
+ request={'scope': scope, 'query': query, 'page_size': page_size, 'order_by': order_by},
retry=retry,
timeout=timeout,
- metadata=metadata,
+ metadata=metadata or (),
)
self.log.info('Received items.')
@@ -984,8 +1039,8 @@ class CloudDataCatalogHook(GoogleBaseHook):
updated.
If a dict is provided, it must be of the same form as the protobuf message
- :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
- :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+ :class:`~google.protobuf.field_mask_pb2.FieldMask`
+ :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
:param location: Required. The location of the entry to update.
:type location: str
:param entry_group: The entry group ID for the entry that is being updated.
@@ -1006,7 +1061,9 @@ class CloudDataCatalogHook(GoogleBaseHook):
"""
client = self.get_conn()
if project_id and location and entry_group and entry_id:
- full_entry_name = DataCatalogClient.entry_path(project_id, location, entry_group, entry_id)
+ full_entry_name = (
+ f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry_id}"
+ )
if isinstance(entry, Entry):
entry.name = full_entry_name
elif isinstance(entry, dict):
@@ -1025,7 +1082,10 @@ class CloudDataCatalogHook(GoogleBaseHook):
if isinstance(entry, dict):
entry = Entry(**entry)
result = client.update_entry(
- entry=entry, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata
+ request={'entry': entry, 'update_mask': update_mask},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata or (),
)
self.log.info('Updated entry.')
@@ -1059,7 +1119,7 @@ class CloudDataCatalogHook(GoogleBaseHook):
If a dict is provided, it must be of the same form as the protobuf message
:class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
- :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+ :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
:param location: Required. The location of the tag to rename.
:type location: str
:param entry_group: The entry group ID for the tag that is being updated.
@@ -1082,7 +1142,10 @@ class CloudDataCatalogHook(GoogleBaseHook):
"""
client = self.get_conn()
if project_id and location and entry_group and entry and tag_id:
- full_tag_name = DataCatalogClient.tag_path(project_id, location, entry_group, entry, tag_id)
+ full_tag_name = (
+ f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}"
+ f"/tags/{tag_id}"
+ )
if isinstance(tag, Tag):
tag.name = full_tag_name
elif isinstance(tag, dict):
@@ -1102,7 +1165,10 @@ class CloudDataCatalogHook(GoogleBaseHook):
if isinstance(tag, dict):
tag = Tag(**tag)
result = client.update_tag(
- tag=tag, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata
+ request={'tag': tag, 'update_mask': update_mask},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata or (),
)
self.log.info('Updated tag.')
@@ -1137,8 +1203,8 @@ class CloudDataCatalogHook(GoogleBaseHook):
If absent or empty, all of the allowed fields above will be updated.
If a dict is provided, it must be of the same form as the protobuf message
- :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
- :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+ :class:`~google.protobuf.field_mask_pb2.FieldMask`
+ :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
:param location: Required. The location of the tag template to rename.
:type location: str
:param tag_template_id: Optional. The tag template ID for the entry that is being updated.
@@ -1157,8 +1223,8 @@ class CloudDataCatalogHook(GoogleBaseHook):
"""
client = self.get_conn()
if project_id and location and tag_template:
- full_tag_template_name = DataCatalogClient.tag_template_path(
- project_id, location, tag_template_id
+ full_tag_template_name = (
+ f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template_id}"
)
if isinstance(tag_template, TagTemplate):
tag_template.name = full_tag_template_name
@@ -1179,11 +1245,10 @@ class CloudDataCatalogHook(GoogleBaseHook):
if isinstance(tag_template, dict):
tag_template = TagTemplate(**tag_template)
result = client.update_tag_template(
- tag_template=tag_template,
- update_mask=update_mask,
+ request={'tag_template': tag_template, 'update_mask': update_mask},
retry=retry,
timeout=timeout,
- metadata=metadata,
+ metadata=metadata or (),
)
self.log.info('Updated tag template.')
@@ -1222,8 +1287,8 @@ class CloudDataCatalogHook(GoogleBaseHook):
Therefore, enum values can only be added, existing enum values cannot be deleted nor renamed.
If a dict is provided, it must be of the same form as the protobuf message
- :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
- :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+ :class:`~google.protobuf.field_mask_pb2.FieldMask`
+ :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
:param tag_template_field_name: Optional. The name of the tag template field to rename.
:type tag_template_field_name: str
:param location: Optional. The location of the tag to rename.
@@ -1246,19 +1311,22 @@ class CloudDataCatalogHook(GoogleBaseHook):
"""
client = self.get_conn()
if project_id and location and tag_template and tag_template_field_id:
- tag_template_field_name = DataCatalogClient.tag_template_field_path(
- project_id, location, tag_template, tag_template_field_id
+ tag_template_field_name = (
+ f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}"
+ f"/fields/{tag_template_field_id}"
)
self.log.info("Updating tag template field: name=%s", tag_template_field_name)
result = client.update_tag_template_field(
- name=tag_template_field_name,
- tag_template_field=tag_template_field,
- update_mask=update_mask,
+ request={
+ 'name': tag_template_field_name,
+ 'tag_template_field': tag_template_field,
+ 'update_mask': update_mask,
+ },
retry=retry,
timeout=timeout,
- metadata=metadata,
+ metadata=metadata or (),
)
self.log.info('Updated tag template field.')
diff --git a/airflow/providers/google/cloud/operators/datacatalog.py b/airflow/providers/google/cloud/operators/datacatalog.py
index 00b2765..4b0da05 100644
--- a/airflow/providers/google/cloud/operators/datacatalog.py
+++ b/airflow/providers/google/cloud/operators/datacatalog.py
@@ -19,17 +19,16 @@ from typing import Dict, Optional, Sequence, Tuple, Union
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.retry import Retry
-from google.cloud.datacatalog_v1beta1 import DataCatalogClient
+from google.cloud.datacatalog_v1beta1 import DataCatalogClient, SearchCatalogResult
from google.cloud.datacatalog_v1beta1.types import (
Entry,
EntryGroup,
- FieldMask,
SearchCatalogRequest,
Tag,
TagTemplate,
TagTemplateField,
)
-from google.protobuf.json_format import MessageToDict
+from google.protobuf.field_mask_pb2 import FieldMask
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.datacatalog import CloudDataCatalogHook
@@ -153,7 +152,7 @@ class CloudDataCatalogCreateEntryOperator(BaseOperator):
_, _, entry_id = result.name.rpartition("/")
self.log.info("Current entry_id ID: %s", entry_id)
context["task_instance"].xcom_push(key="entry_id", value=entry_id)
- return MessageToDict(result)
+ return Entry.to_dict(result)
class CloudDataCatalogCreateEntryGroupOperator(BaseOperator):
@@ -268,7 +267,7 @@ class CloudDataCatalogCreateEntryGroupOperator(BaseOperator):
_, _, entry_group_id = result.name.rpartition("/")
self.log.info("Current entry group ID: %s", entry_group_id)
context["task_instance"].xcom_push(key="entry_group_id", value=entry_group_id)
- return MessageToDict(result)
+ return EntryGroup.to_dict(result)
class CloudDataCatalogCreateTagOperator(BaseOperator):
@@ -404,7 +403,7 @@ class CloudDataCatalogCreateTagOperator(BaseOperator):
_, _, tag_id = tag.name.rpartition("/")
self.log.info("Current Tag ID: %s", tag_id)
context["task_instance"].xcom_push(key="tag_id", value=tag_id)
- return MessageToDict(tag)
+ return Tag.to_dict(tag)
class CloudDataCatalogCreateTagTemplateOperator(BaseOperator):
@@ -516,7 +515,7 @@ class CloudDataCatalogCreateTagTemplateOperator(BaseOperator):
_, _, tag_template = result.name.rpartition("/")
self.log.info("Current Tag ID: %s", tag_template)
context["task_instance"].xcom_push(key="tag_template_id", value=tag_template)
- return MessageToDict(result)
+ return TagTemplate.to_dict(result)
class CloudDataCatalogCreateTagTemplateFieldOperator(BaseOperator):
@@ -638,7 +637,7 @@ class CloudDataCatalogCreateTagTemplateFieldOperator(BaseOperator):
self.log.info("Current Tag ID: %s", self.tag_template_field_id)
context["task_instance"].xcom_push(key="tag_template_field_id", value=self.tag_template_field_id)
- return MessageToDict(result)
+ return TagTemplateField.to_dict(result)
class CloudDataCatalogDeleteEntryOperator(BaseOperator):
@@ -1216,7 +1215,7 @@ class CloudDataCatalogGetEntryOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- return MessageToDict(result)
+ return Entry.to_dict(result)
class CloudDataCatalogGetEntryGroupOperator(BaseOperator):
@@ -1234,8 +1233,8 @@ class CloudDataCatalogGetEntryGroupOperator(BaseOperator):
:param read_mask: The fields to return. If not set or empty, all fields are returned.
If a dict is provided, it must be of the same form as the protobuf message
- :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
- :type read_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+ :class:`~google.protobuf.field_mask_pb2.FieldMask`
+ :type read_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
:param project_id: The ID of the Google Cloud project that owns the entry group.
If set to ``None`` or missing, the default project_id from the Google Cloud connection is used.
:type project_id: Optional[str]
@@ -1312,7 +1311,7 @@ class CloudDataCatalogGetEntryGroupOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- return MessageToDict(result)
+ return EntryGroup.to_dict(result)
class CloudDataCatalogGetTagTemplateOperator(BaseOperator):
@@ -1399,7 +1398,7 @@ class CloudDataCatalogGetTagTemplateOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- return MessageToDict(result)
+ return TagTemplate.to_dict(result)
class CloudDataCatalogListTagsOperator(BaseOperator):
@@ -1501,7 +1500,7 @@ class CloudDataCatalogListTagsOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- return [MessageToDict(item) for item in result]
+ return [Tag.to_dict(item) for item in result]
class CloudDataCatalogLookupEntryOperator(BaseOperator):
@@ -1589,7 +1588,7 @@ class CloudDataCatalogLookupEntryOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- return MessageToDict(result)
+ return Entry.to_dict(result)
class CloudDataCatalogRenameTagTemplateFieldOperator(BaseOperator):
@@ -1809,7 +1808,7 @@ class CloudDataCatalogSearchCatalogOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- return [MessageToDict(item) for item in result]
+ return [SearchCatalogResult.to_dict(item) for item in result]
class CloudDataCatalogUpdateEntryOperator(BaseOperator):
@@ -1829,8 +1828,8 @@ class CloudDataCatalogUpdateEntryOperator(BaseOperator):
updated.
If a dict is provided, it must be of the same form as the protobuf message
- :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
- :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+ :class:`~google.protobuf.field_mask_pb2.FieldMask`
+ :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
:param location: Required. The location of the entry to update.
:type location: str
:param entry_group: The entry group ID for the entry that is being updated.
@@ -1940,8 +1939,8 @@ class CloudDataCatalogUpdateTagOperator(BaseOperator):
updated. Currently the only modifiable field is the field ``fields``.
If a dict is provided, it must be of the same form as the protobuf message
- :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
- :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+ :class:`~google.protobuf.field_mask_pb2.FieldMask`
+ :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
:param location: Required. The location of the tag to rename.
:type location: str
:param entry_group: The entry group ID for the tag that is being updated.
@@ -2060,8 +2059,8 @@ class CloudDataCatalogUpdateTagTemplateOperator(BaseOperator):
If absent or empty, all of the allowed fields above will be updated.
If a dict is provided, it must be of the same form as the protobuf message
- :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
- :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+ :class:`~google.protobuf.field_mask_pb2.FieldMask`
+ :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
:param location: Required. The location of the tag template to rename.
:type location: str
:param tag_template_id: Optional. The tag template ID for the entry that is being updated.
@@ -2172,8 +2171,8 @@ class CloudDataCatalogUpdateTagTemplateFieldOperator(BaseOperator):
Therefore, enum values can only be added, existing enum values cannot be deleted nor renamed.
If a dict is provided, it must be of the same form as the protobuf message
- :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask`
- :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask]
+ :class:`~google.protobuf.field_mask_pb2.FieldMask`
+ :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
:param tag_template_field_name: Optional. The name of the tag template field to rename.
:type tag_template_field_name: str
:param location: Optional. The location of the tag to rename.
diff --git a/setup.py b/setup.py
index 75f5db5..5314814 100644
--- a/setup.py
+++ b/setup.py
@@ -287,7 +287,7 @@ google = [
'google-cloud-bigquery-datatransfer>=3.0.0,<4.0.0',
'google-cloud-bigtable>=1.0.0,<2.0.0',
'google-cloud-container>=0.1.1,<2.0.0',
- 'google-cloud-datacatalog>=1.0.0,<2.0.0',
+ 'google-cloud-datacatalog>=3.0.0,<4.0.0',
'google-cloud-dataproc>=1.0.1,<2.0.0',
'google-cloud-dlp>=0.11.0,<2.0.0',
'google-cloud-kms>=2.0.0,<3.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_datacatalog.py b/tests/providers/google/cloud/hooks/test_datacatalog.py
index f5192c5..99d785f 100644
--- a/tests/providers/google/cloud/hooks/test_datacatalog.py
+++ b/tests/providers/google/cloud/hooks/test_datacatalog.py
@@ -22,6 +22,7 @@ from unittest import TestCase, mock
import pytest
from google.api_core.retry import Retry
+from google.cloud.datacatalog_v1beta1 import CreateTagRequest, CreateTagTemplateRequest
from google.cloud.datacatalog_v1beta1.types import Entry, Tag, TagTemplate
from airflow import AirflowException
@@ -38,7 +39,7 @@ TEST_ENTRY_ID: str = "test-entry-id"
TEST_ENTRY: Dict = {}
TEST_RETRY: Retry = Retry()
TEST_TIMEOUT: float = 4
-TEST_METADATA: Sequence[Tuple[str, str]] = []
+TEST_METADATA: Sequence[Tuple[str, str]] = ()
TEST_ENTRY_GROUP_ID: str = "test-entry-group-id"
TEST_ENTRY_GROUP: Dict = {}
TEST_TAG: Dict = {}
@@ -102,7 +103,7 @@ class TestCloudDataCatalog(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.lookup_entry.assert_called_once_with(
- linked_resource=TEST_LINKED_RESOURCE,
+ request=dict(linked_resource=TEST_LINKED_RESOURCE),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -118,7 +119,10 @@ class TestCloudDataCatalog(TestCase):
sql_resource=TEST_SQL_RESOURCE, retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA
)
mock_get_conn.return_value.lookup_entry.assert_called_once_with(
- sql_resource=TEST_SQL_RESOURCE, retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA
+ request=dict(sql_resource=TEST_SQL_RESOURCE),
+ retry=TEST_RETRY,
+ timeout=TEST_TIMEOUT,
+ metadata=TEST_METADATA,
)
@mock.patch(
@@ -148,10 +152,9 @@ class TestCloudDataCatalog(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.search_catalog.assert_called_once_with(
- scope=TEST_SCOPE,
- query=TEST_QUERY,
- page_size=TEST_PAGE_SIZE,
- order_by=TEST_ORDER_BY,
+ request=dict(
+ scope=TEST_SCOPE, query=TEST_QUERY, page_size=TEST_PAGE_SIZE, order_by=TEST_ORDER_BY
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -184,9 +187,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.create_entry.assert_called_once_with(
- parent=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1),
- entry_id=TEST_ENTRY_ID,
- entry=TEST_ENTRY,
+ request=dict(
+ parent=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1),
+ entry_id=TEST_ENTRY_ID,
+ entry=TEST_ENTRY,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -207,9 +212,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.create_entry_group.assert_called_once_with(
- parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_1),
- entry_group_id=TEST_ENTRY_GROUP_ID,
- entry_group=TEST_ENTRY_GROUP,
+ request=dict(
+ parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_1),
+ entry_group_id=TEST_ENTRY_GROUP_ID,
+ entry_group=TEST_ENTRY_GROUP,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -232,8 +239,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.create_tag.assert_called_once_with(
- parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
- tag={"template": TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)},
+ request=CreateTagRequest(
+ parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+ tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)),
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -256,8 +265,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.create_tag.assert_called_once_with(
- parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
- tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)),
+ request=CreateTagRequest(
+ parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+ tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)),
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -278,9 +289,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.create_tag_template.assert_called_once_with(
- parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_1),
- tag_template_id=TEST_TAG_TEMPLATE_ID,
- tag_template=TEST_TAG_TEMPLATE,
+ request=CreateTagTemplateRequest(
+ parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_1),
+ tag_template_id=TEST_TAG_TEMPLATE_ID,
+ tag_template=TEST_TAG_TEMPLATE,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -302,9 +315,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.create_tag_template_field.assert_called_once_with(
- parent=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1),
- tag_template_field_id=TEST_TAG_TEMPLATE_FIELD_ID,
- tag_template_field=TEST_TAG_TEMPLATE_FIELD,
+ request=dict(
+ parent=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1),
+ tag_template_field_id=TEST_TAG_TEMPLATE_FIELD_ID,
+ tag_template_field=TEST_TAG_TEMPLATE_FIELD,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -325,7 +340,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.delete_entry.assert_called_once_with(
- name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+ request=dict(
+ name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -345,7 +362,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.delete_entry_group.assert_called_once_with(
- name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1),
+ request=dict(
+ name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1),
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -367,7 +386,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.delete_tag.assert_called_once_with(
- name=TEST_TAG_PATH.format(TEST_PROJECT_ID_1),
+ request=dict(
+ name=TEST_TAG_PATH.format(TEST_PROJECT_ID_1),
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -388,8 +409,7 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.delete_tag_template.assert_called_once_with(
- name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1),
- force=TEST_FORCE,
+ request=dict(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1), force=TEST_FORCE),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -411,8 +431,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.delete_tag_template_field.assert_called_once_with(
- name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1),
- force=TEST_FORCE,
+ request=dict(
+ name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1),
+ force=TEST_FORCE,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -433,7 +455,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.get_entry.assert_called_once_with(
- name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+ request=dict(
+ name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -454,8 +478,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.get_entry_group.assert_called_once_with(
- name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1),
- read_mask=TEST_READ_MASK,
+ request=dict(
+ name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1),
+ read_mask=TEST_READ_MASK,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -475,7 +501,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.get_tag_template.assert_called_once_with(
- name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1),
+ request=dict(
+ name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1),
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -497,8 +525,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.list_tags.assert_called_once_with(
- parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
- page_size=TEST_PAGE_SIZE,
+ request=dict(
+ parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+ page_size=TEST_PAGE_SIZE,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -524,8 +554,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.list_tags.assert_called_once_with(
- parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
- page_size=100,
+ request=dict(
+ parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1),
+ page_size=100,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -548,8 +580,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.rename_tag_template_field.assert_called_once_with(
- name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1),
- new_tag_template_field_id=TEST_NEW_TAG_TEMPLATE_FIELD_ID,
+ request=dict(
+ name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1),
+ new_tag_template_field_id=TEST_NEW_TAG_TEMPLATE_FIELD_ID,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -572,8 +606,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.update_entry.assert_called_once_with(
- entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1)),
- update_mask=TEST_UPDATE_MASK,
+ request=dict(
+ entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1)),
+ update_mask=TEST_UPDATE_MASK,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -597,8 +633,7 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.update_tag.assert_called_once_with(
- tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_1)),
- update_mask=TEST_UPDATE_MASK,
+ request=dict(tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_1)), update_mask=TEST_UPDATE_MASK),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -620,8 +655,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.update_tag_template.assert_called_once_with(
- tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)),
- update_mask=TEST_UPDATE_MASK,
+ request=dict(
+ tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)),
+ update_mask=TEST_UPDATE_MASK,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -644,9 +681,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.update_tag_template_field.assert_called_once_with(
- name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1),
- tag_template_field=TEST_TAG_TEMPLATE_FIELD,
- update_mask=TEST_UPDATE_MASK,
+ request=dict(
+ name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1),
+ tag_template_field=TEST_TAG_TEMPLATE_FIELD,
+ update_mask=TEST_UPDATE_MASK,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -680,9 +719,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.create_entry.assert_called_once_with(
- parent=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2),
- entry_id=TEST_ENTRY_ID,
- entry=TEST_ENTRY,
+ request=dict(
+ parent=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2),
+ entry_id=TEST_ENTRY_ID,
+ entry=TEST_ENTRY,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -704,9 +745,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.create_entry_group.assert_called_once_with(
- parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_2),
- entry_group_id=TEST_ENTRY_GROUP_ID,
- entry_group=TEST_ENTRY_GROUP,
+ request=dict(
+ parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_2),
+ entry_group_id=TEST_ENTRY_GROUP_ID,
+ entry_group=TEST_ENTRY_GROUP,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -730,8 +773,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.create_tag.assert_called_once_with(
- parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
- tag={"template": TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)},
+ request=CreateTagRequest(
+ parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
+ tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)),
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -755,8 +800,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.create_tag.assert_called_once_with(
- parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
- tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)),
+ request=CreateTagRequest(
+ parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
+ tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)),
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -778,9 +825,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.create_tag_template.assert_called_once_with(
- parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_2),
- tag_template_id=TEST_TAG_TEMPLATE_ID,
- tag_template=TEST_TAG_TEMPLATE,
+ request=CreateTagTemplateRequest(
+ parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_2),
+ tag_template_id=TEST_TAG_TEMPLATE_ID,
+ tag_template=TEST_TAG_TEMPLATE,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -803,9 +852,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.create_tag_template_field.assert_called_once_with(
- parent=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2),
- tag_template_field_id=TEST_TAG_TEMPLATE_FIELD_ID,
- tag_template_field=TEST_TAG_TEMPLATE_FIELD,
+ request=dict(
+ parent=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2),
+ tag_template_field_id=TEST_TAG_TEMPLATE_FIELD_ID,
+ tag_template_field=TEST_TAG_TEMPLATE_FIELD,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -827,7 +878,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.delete_entry.assert_called_once_with(
- name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
+ request=dict(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2)),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -848,7 +899,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.delete_entry_group.assert_called_once_with(
- name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2),
+ request=dict(name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2)),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -871,7 +922,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.delete_tag.assert_called_once_with(
- name=TEST_TAG_PATH.format(TEST_PROJECT_ID_2),
+ request=dict(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_2)),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -893,8 +944,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.delete_tag_template.assert_called_once_with(
- name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2),
- force=TEST_FORCE,
+ request=dict(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2), force=TEST_FORCE),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -917,8 +967,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.delete_tag_template_field.assert_called_once_with(
- name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2),
- force=TEST_FORCE,
+ request=dict(name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2), force=TEST_FORCE),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -940,7 +989,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.get_entry.assert_called_once_with(
- name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
+ request=dict(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2)),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -962,8 +1011,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.get_entry_group.assert_called_once_with(
- name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2),
- read_mask=TEST_READ_MASK,
+ request=dict(
+ name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2),
+ read_mask=TEST_READ_MASK,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -984,7 +1035,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.get_tag_template.assert_called_once_with(
- name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2),
+ request=dict(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -1007,8 +1058,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.list_tags.assert_called_once_with(
- parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
- page_size=TEST_PAGE_SIZE,
+ request=dict(parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), page_size=TEST_PAGE_SIZE),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -1035,8 +1085,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.list_tags.assert_called_once_with(
- parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2),
- page_size=100,
+ request=dict(parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), page_size=100),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -1060,8 +1109,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.rename_tag_template_field.assert_called_once_with(
- name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2),
- new_tag_template_field_id=TEST_NEW_TAG_TEMPLATE_FIELD_ID,
+ request=dict(
+ name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2),
+ new_tag_template_field_id=TEST_NEW_TAG_TEMPLATE_FIELD_ID,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -1085,8 +1136,9 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.update_entry.assert_called_once_with(
- entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2)),
- update_mask=TEST_UPDATE_MASK,
+ request=dict(
+ entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2)), update_mask=TEST_UPDATE_MASK
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -1111,8 +1163,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.update_tag.assert_called_once_with(
- tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_2)),
- update_mask=TEST_UPDATE_MASK,
+ request=dict(tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_2)), update_mask=TEST_UPDATE_MASK),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -1135,8 +1186,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.update_tag_template.assert_called_once_with(
- tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)),
- update_mask=TEST_UPDATE_MASK,
+ request=dict(
+ tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)),
+ update_mask=TEST_UPDATE_MASK,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
@@ -1160,9 +1213,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase):
metadata=TEST_METADATA,
)
mock_get_conn.return_value.update_tag_template_field.assert_called_once_with(
- name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2),
- tag_template_field=TEST_TAG_TEMPLATE_FIELD,
- update_mask=TEST_UPDATE_MASK,
+ request=dict(
+ name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2),
+ tag_template_field=TEST_TAG_TEMPLATE_FIELD,
+ update_mask=TEST_UPDATE_MASK,
+ ),
retry=TEST_RETRY,
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
diff --git a/tests/providers/google/cloud/operators/test_datacatalog.py b/tests/providers/google/cloud/operators/test_datacatalog.py
index b575dd4..517b35c 100644
--- a/tests/providers/google/cloud/operators/test_datacatalog.py
+++ b/tests/providers/google/cloud/operators/test_datacatalog.py
@@ -87,15 +87,25 @@ TEST_TAG_PATH: str = (
)
TEST_ENTRY: Entry = Entry(name=TEST_ENTRY_PATH)
-TEST_ENTRY_DICT: Dict = dict(name=TEST_ENTRY_PATH)
+TEST_ENTRY_DICT: Dict = {
+ 'description': '',
+ 'display_name': '',
+ 'linked_resource': '',
+ 'name': TEST_ENTRY_PATH,
+}
TEST_ENTRY_GROUP: EntryGroup = EntryGroup(name=TEST_ENTRY_GROUP_PATH)
-TEST_ENTRY_GROUP_DICT: Dict = dict(name=TEST_ENTRY_GROUP_PATH)
-TEST_TAG: EntryGroup = Tag(name=TEST_TAG_PATH)
-TEST_TAG_DICT: Dict = dict(name=TEST_TAG_PATH)
+TEST_ENTRY_GROUP_DICT: Dict = {'description': '', 'display_name': '', 'name': TEST_ENTRY_GROUP_PATH}
+TEST_TAG: Tag = Tag(name=TEST_TAG_PATH)
+TEST_TAG_DICT: Dict = {'fields': {}, 'name': TEST_TAG_PATH, 'template': '', 'template_display_name': ''}
TEST_TAG_TEMPLATE: TagTemplate = TagTemplate(name=TEST_TAG_TEMPLATE_PATH)
-TEST_TAG_TEMPLATE_DICT: Dict = dict(name=TEST_TAG_TEMPLATE_PATH)
-TEST_TAG_TEMPLATE_FIELD: Dict = TagTemplateField(name=TEST_TAG_TEMPLATE_FIELD_ID)
-TEST_TAG_TEMPLATE_FIELD_DICT: Dict = dict(name=TEST_TAG_TEMPLATE_FIELD_ID)
+TEST_TAG_TEMPLATE_DICT: Dict = {'display_name': '', 'fields': {}, 'name': TEST_TAG_TEMPLATE_PATH}
+TEST_TAG_TEMPLATE_FIELD: TagTemplateField = TagTemplateField(name=TEST_TAG_TEMPLATE_FIELD_ID)
+TEST_TAG_TEMPLATE_FIELD_DICT: Dict = {
+ 'display_name': '',
+ 'is_required': False,
+ 'name': TEST_TAG_TEMPLATE_FIELD_ID,
+ 'order': 0,
+}
class TestCloudDataCatalogCreateEntryOperator(TestCase):
@@ -498,7 +508,10 @@ class TestCloudDataCatalogDeleteTagTemplateFieldOperator(TestCase):
class TestCloudDataCatalogGetEntryOperator(TestCase):
- @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook")
+ @mock.patch(
+ "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
+ **{"return_value.get_entry.return_value": TEST_ENTRY}, # type: ignore
+ )
def test_assert_valid_hook_call(self, mock_hook) -> None:
task = CloudDataCatalogGetEntryOperator(
task_id="task_id",
@@ -529,7 +542,10 @@ class TestCloudDataCatalogGetEntryOperator(TestCase):
class TestCloudDataCatalogGetEntryGroupOperator(TestCase):
- @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook")
+ @mock.patch(
+ "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
+ **{"return_value.get_entry_group.return_value": TEST_ENTRY_GROUP}, # type: ignore
+ )
def test_assert_valid_hook_call(self, mock_hook) -> None:
task = CloudDataCatalogGetEntryGroupOperator(
task_id="task_id",
@@ -560,7 +576,10 @@ class TestCloudDataCatalogGetEntryGroupOperator(TestCase):
class TestCloudDataCatalogGetTagTemplateOperator(TestCase):
- @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook")
+ @mock.patch(
+ "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
+ **{"return_value.get_tag_template.return_value": TEST_TAG_TEMPLATE}, # type: ignore
+ )
def test_assert_valid_hook_call(self, mock_hook) -> None:
task = CloudDataCatalogGetTagTemplateOperator(
task_id="task_id",
@@ -589,7 +608,10 @@ class TestCloudDataCatalogGetTagTemplateOperator(TestCase):
class TestCloudDataCatalogListTagsOperator(TestCase):
- @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook")
+ @mock.patch(
+ "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
+ **{"return_value.list_tags.return_value": [TEST_TAG]}, # type: ignore
+ )
def test_assert_valid_hook_call(self, mock_hook) -> None:
task = CloudDataCatalogListTagsOperator(
task_id="task_id",
@@ -622,7 +644,10 @@ class TestCloudDataCatalogListTagsOperator(TestCase):
class TestCloudDataCatalogLookupEntryOperator(TestCase):
- @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook")
+ @mock.patch(
+ "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
+ **{"return_value.lookup_entry.return_value": TEST_ENTRY}, # type: ignore
+ )
def test_assert_valid_hook_call(self, mock_hook) -> None:
task = CloudDataCatalogLookupEntryOperator(
task_id="task_id",