You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2023/12/07 09:07:37 UTC
(airflow) branch main updated: Add object methods in weaviate hook (#35934)
This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b2464477c4 Add object methods in weaviate hook (#35934)
b2464477c4 is described below
commit b2464477c472894f142c1a85f04a92af033e700e
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Thu Dec 7 10:07:29 2023 +0100
Add object methods in weaviate hook (#35934)
* Add object methods in weaviate hook
This PR adds methods to the Weaviate hook to get, create, update,
and delete objects.
* fixup! Add object methods in weaviate hook
* Add typing
* Improve get_or_create_object method
* Add validate_object method
* Improve get_all_objects to return a dataframe if requested
* Use self.conn
* Add object exists method
* fixup! Add object exists method
* Fix docs
* Update get_or_create_object method
---
airflow/providers/weaviate/hooks/weaviate.py | 155 ++++++++++++++++++++-
airflow/providers/weaviate/provider.yaml | 1 +
generated/provider_dependencies.json | 1 +
tests/providers/weaviate/hooks/test_weaviate.py | 178 ++++++++++++++++++++++++
4 files changed, 334 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py
index 66d820bbbe..ae32727ffb 100644
--- a/airflow/providers/weaviate/hooks/weaviate.py
+++ b/airflow/providers/weaviate/hooks/weaviate.py
@@ -19,14 +19,23 @@ from __future__ import annotations
import warnings
from functools import cached_property
-from typing import Any
+from typing import TYPE_CHECKING, Sequence
from weaviate import Client as WeaviateClient
from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword
+from weaviate.exceptions import ObjectAlreadyExistsException
+from weaviate.util import generate_uuid5
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
+if TYPE_CHECKING:
+ from typing import Any
+
+ import pandas as pd
+ from weaviate import ConsistencyLevel
+ from weaviate.types import UUID
+
class WeaviateHook(BaseHook):
"""
@@ -200,3 +209,147 @@ class WeaviateHook(BaseHook):
.do()
)
return results
+
def create_object(
    self, data_object: dict | str, class_name: str, **kwargs
) -> str | dict[str, Any] | None:
    """Create a new object.

    If no ``uuid`` keyword argument is supplied, a deterministic UUID is derived from
    ``data_object`` with ``generate_uuid5``, so creating the same object twice targets
    the same UUID.

    :param data_object: Object to be added. If type is str it should be either a URL or a file.
    :param class_name: Class name associated with the object given.
    :param kwargs: Additional parameters to be passed to weaviate_client.data_object.create().
        An explicitly passed ``uuid`` (including ``None``) is forwarded unchanged.
    :return: The value returned by weaviate_client.data_object.create(), or None if an
        object with the same UUID already exists.
    """
    client = self.conn
    # Derive the UUID only when the caller did not supply one. A ``dict.pop`` default
    # would be evaluated (stringifying and hashing the whole object) on every call.
    if "uuid" in kwargs:
        uuid = kwargs.pop("uuid")
    else:
        uuid = generate_uuid5(data_object)
    try:
        return client.data_object.create(data_object, class_name, uuid=uuid, **kwargs)
    except ObjectAlreadyExistsException:
        self.log.warning("Object with the UUID %s already exists", uuid)
        return None
+
def get_or_create_object(
    self,
    data_object: dict | str | None = None,
    class_name: str | None = None,
    vector: Sequence | None = None,
    consistency_level: ConsistencyLevel | None = None,
    tenant: str | None = None,
    **kwargs,
) -> str | dict[str, Any] | None:
    """Get an existing object, or create it if it does not exist yet.

    The lookup is done by UUID: the ``uuid`` keyword argument if one is given, otherwise
    the same deterministic ``generate_uuid5(data_object)`` that create_object() assigns.
    A lookup without a UUID makes weaviate_client.data_object.get() return a listing of
    the class. That listing is truthy even when empty and would be mistaken for a match.

    :param data_object: Object to be added. If type is str it should be either a URL or a file. This is required
        to create a new object.
    :param class_name: Class name associated with the object given. This is required to create a new object.
    :param vector: Vector associated with the object given. This argument is only used when creating object.
    :param consistency_level: Consistency level to be used. Applies to both create and get operations.
    :param tenant: Tenant to be used. Applies to both create and get operations.
    :param kwargs: Additional parameters to be passed to weaviate_client.data_object.get().
        Of these, only ``uuid`` is also used when creating the object.
    :return: The object found by the lookup, or the return value of create_object() when
        nothing was found.
    :raises ValueError: if the object does not exist and data_object or class_name is missing.
    """
    # Ensure the lookup targets a single object rather than listing the whole class.
    if kwargs.get("uuid") is None and data_object:
        kwargs["uuid"] = generate_uuid5(data_object)
    obj = self.get_object(
        class_name=class_name, consistency_level=consistency_level, tenant=tenant, **kwargs
    )
    if obj:
        return obj
    if not (data_object and class_name):
        raise ValueError("data_object and class_name are required to create a new object")
    # data_object is truthy here, so kwargs["uuid"] was either supplied or derived above.
    return self.create_object(
        data_object,
        class_name,
        vector=vector,
        uuid=kwargs["uuid"],
        consistency_level=consistency_level,
        tenant=tenant,
    )
+
def get_object(self, **kwargs) -> dict[str, Any] | None:
    """Fetch a single object, or a listing of objects, from Weaviate.

    Every keyword argument is handed straight to weaviate_client.data_object.get(),
    and its result is returned unchanged.

    :param kwargs: parameters accepted by weaviate_client.data_object.get(), for example
        ``uuid``, ``class_name`` or ``tenant``.
    """
    return self.conn.data_object.get(**kwargs)
+
def get_all_objects(
    self, after: str | UUID | None = None, as_dataframe: bool = False, **kwargs
) -> list[dict[str, Any]] | pd.DataFrame:
    """Get all objects from Weaviate by following the ``after`` cursor page by page.

    Pages are requested through get_object() until a page has no objects (or the call
    returns a falsy value). The ``id`` of the last object of each page becomes the
    ``after`` cursor for the next request.

    :param after: uuid of the object to start listing from
    :param as_dataframe: if True, returns a pandas dataframe built from the collected objects
    :param kwargs: parameters to be passed to weaviate_client.data_object.get()
    :return: list of object dicts, or a ``pandas.DataFrame`` when ``as_dataframe`` is True
    """
    all_objects: list[dict[str, Any]] = []
    while True:
        results = self.get_object(after=after, **kwargs) or {}
        page = results.get("objects")
        if not page:
            break
        all_objects.extend(page)
        # Cursor for the next page: the id of the last object received.
        after = page[-1]["id"]
    if as_dataframe:
        import pandas

        return pandas.DataFrame(all_objects)
    return all_objects
+
def delete_object(self, uuid: UUID | str, **kwargs) -> None:
    """Remove the object identified by ``uuid`` from Weaviate.

    :param uuid: uuid of the object to be deleted
    :param kwargs: Optional parameters forwarded to weaviate_client.data_object.delete()
    """
    data_api = self.conn.data_object
    data_api.delete(uuid, **kwargs)
+
def update_object(self, data_object: dict | str, class_name: str, uuid: UUID | str, **kwargs) -> None:
    """Partially update an existing object in Weaviate.

    :param data_object: The fields to change on the stored object. Fields not specified in the
        'data_object' remain unchanged. Fields that are None will not be changed.
        If type is str it should be either an URL or a file.
    :param class_name: Class name associated with the object given.
    :param uuid: uuid of the object to be updated
    :param kwargs: Optional parameters forwarded to weaviate_client.data_object.update()
    """
    data_api = self.conn.data_object
    data_api.update(data_object, class_name, uuid, **kwargs)
+
def replace_object(self, data_object: dict | str, class_name: str, uuid: UUID | str, **kwargs) -> None:
    """Overwrite an existing object in Weaviate with new content.

    Unlike update_object(), fields missing from ``data_object`` are not preserved.

    :param data_object: The complete new content of the object. Fields not specified in the
        'data_object' will be set to None. If type is str it should be either an URL or a file.
    :param class_name: Class name associated with the object given.
    :param uuid: uuid of the object to be replaced
    :param kwargs: Optional parameters forwarded to weaviate_client.data_object.replace()
    """
    data_api = self.conn.data_object
    data_api.replace(data_object, class_name, uuid, **kwargs)
+
def validate_object(self, data_object: dict | str, class_name: str, **kwargs) -> dict[str, Any]:
    """Validate an object in weaviate.

    :param data_object: The object to be validated. If type is str it should be either an URL or a file.
    :param class_name: Class name associated with the object given.
    :param kwargs: Optional parameters to be passed to weaviate_client.data_object.validate()
    :return: The result of weaviate_client.data_object.validate(). It is returned, rather than
        discarded, so callers can act on the validation outcome.
    """
    client = self.conn
    return client.data_object.validate(data_object, class_name, **kwargs)
+
def object_exists(self, uuid: str | UUID, **kwargs) -> bool:
    """Report whether an object with the given UUID is stored in Weaviate.

    :param uuid: The UUID of the object that may or may not exist within Weaviate.
    :param kwargs: Optional parameters forwarded to weaviate_client.data_object.exists()
    :return: The value reported by weaviate_client.data_object.exists().
    """
    data_api = self.conn.data_object
    return data_api.exists(uuid, **kwargs)
diff --git a/airflow/providers/weaviate/provider.yaml b/airflow/providers/weaviate/provider.yaml
index 015fa66c6f..86d60ba708 100644
--- a/airflow/providers/weaviate/provider.yaml
+++ b/airflow/providers/weaviate/provider.yaml
@@ -39,6 +39,7 @@ integrations:
dependencies:
- apache-airflow>=2.6.0
- weaviate-client>=3.24.2
+ - pandas>=0.17.1
hooks:
- integration-name: Weaviate
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index c47736d9ac..24906c94cf 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -949,6 +949,7 @@
"weaviate": {
"deps": [
"apache-airflow>=2.6.0",
+ "pandas>=0.17.1",
"weaviate-client>=3.24.2"
],
"cross-providers-deps": [],
diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py
index 0274004fc0..6f4be77429 100644
--- a/tests/providers/weaviate/hooks/test_weaviate.py
+++ b/tests/providers/weaviate/hooks/test_weaviate.py
@@ -20,6 +20,7 @@ from unittest import mock
from unittest.mock import MagicMock, Mock, patch
import pytest
+from weaviate import ObjectAlreadyExistsException
from airflow.models import Connection
from airflow.providers.weaviate.hooks.weaviate import WeaviateHook
@@ -179,6 +180,183 @@ class TestWeaviateHook:
additional_headers={},
)
@mock.patch("airflow.providers.weaviate.hooks.weaviate.generate_uuid5")
def test_create_object(self, mock_gen_uuid, weaviate_hook):
    """
    Test the create_object method of WeaviateHook.

    The UUID must be derived from the data object, and the client's return value passed through.
    """
    mock_client = MagicMock()
    weaviate_hook.get_conn = MagicMock(return_value=mock_client)
    return_value = weaviate_hook.create_object({"name": "Test"}, "TestClass")
    mock_gen_uuid.assert_called_once_with({"name": "Test"})
    mock_client.data_object.create.assert_called_once_with(
        {"name": "Test"}, "TestClass", uuid=mock_gen_uuid.return_value
    )
    # A bare ``assert return_value`` always passes for a MagicMock; check identity instead.
    assert return_value is mock_client.data_object.create.return_value
+
def test_create_object_already_exists_return_none(self, weaviate_hook):
    """
    Test that create_object returns None when Weaviate reports the object already exists.
    """
    client = MagicMock()
    client.data_object.create.side_effect = ObjectAlreadyExistsException
    weaviate_hook.get_conn = MagicMock(return_value=client)

    assert weaviate_hook.create_object({"name": "Test"}, "TestClass") is None
+
def test_get_object(self, weaviate_hook):
    """
    Test that get_object forwards its keyword arguments to data_object.get().
    """
    client = MagicMock()
    weaviate_hook.get_conn = MagicMock(return_value=client)

    weaviate_hook.get_object(class_name="TestClass", uuid="uuid")

    client.data_object.get.assert_called_once_with(class_name="TestClass", uuid="uuid")
+
def test_get_of_get_or_create_object(self, weaviate_hook):
    """
    Test the get part of get_or_create_object method of WeaviateHook.

    When the lookup finds an object, it is returned as-is and nothing is created.
    """
    mock_client = MagicMock()
    weaviate_hook.get_conn = MagicMock(return_value=mock_client)
    weaviate_hook.create_object = MagicMock()
    result = weaviate_hook.get_or_create_object(data_object={"name": "Test"}, class_name="TestClass")
    mock_client.data_object.get.assert_called_once()
    # Only check the arguments this test is about; how the lookup UUID is passed is an
    # implementation detail of get_or_create_object.
    get_kwargs = mock_client.data_object.get.call_args.kwargs
    assert get_kwargs["class_name"] == "TestClass"
    assert get_kwargs["consistency_level"] is None
    assert get_kwargs["tenant"] is None
    assert result is mock_client.data_object.get.return_value
    weaviate_hook.create_object.assert_not_called()
+
@mock.patch("airflow.providers.weaviate.hooks.weaviate.generate_uuid5")
def test_create_of_get_or_create_object(self, mock_gen_uuid, weaviate_hook):
    """
    Test that get_or_create_object falls back to create_object when nothing is found.
    """
    weaviate_hook.get_conn = MagicMock(return_value=MagicMock())
    weaviate_hook.get_object = MagicMock(return_value=None)
    weaviate_hook.create_object = create_mock = MagicMock()

    weaviate_hook.get_or_create_object(data_object={"name": "Test"}, class_name="TestClass")

    create_mock.assert_called_once_with(
        {"name": "Test"},
        "TestClass",
        uuid=mock_gen_uuid.return_value,
        consistency_level=None,
        tenant=None,
        vector=None,
    )
+
def test_create_of_get_or_create_object_raises_valueerror(self, weaviate_hook):
    """
    Test that get_or_create_object raises ValueError when it has to create an object
    but data_object or class_name is missing.
    """
    weaviate_hook.get_conn = MagicMock(return_value=MagicMock())
    weaviate_hook.get_object = MagicMock(return_value=None)
    weaviate_hook.create_object = MagicMock()

    for data_object, class_name in ((None, "TestClass"), ({"name": "Test"}, None)):
        with pytest.raises(ValueError):
            weaviate_hook.get_or_create_object(data_object=data_object, class_name=class_name)
+
def test_get_all_objects(self, weaviate_hook):
    """
    Test that get_all_objects pages through get_object, using the last id as the cursor.
    """
    weaviate_hook.get_conn = MagicMock(return_value=MagicMock())
    weaviate_hook.get_object = MagicMock(
        side_effect=[
            {"deprecations": None, "objects": [{"name": "Test1", "id": 2}, {"name": "Test2", "id": 3}]},
            {"deprecations": None, "objects": []},
        ]
    )

    collected = weaviate_hook.get_all_objects(class_name="TestClass")

    expected_calls = [
        mock.call(after=None, class_name="TestClass"),
        mock.call(after=3, class_name="TestClass"),
    ]
    assert weaviate_hook.get_object.call_args_list == expected_calls
    assert collected == [{"name": "Test1", "id": 2}, {"name": "Test2", "id": 3}]
+
def test_get_all_objects_returns_dataframe(self, weaviate_hook):
    """
    Test that get_all_objects returns a pandas DataFrame when as_dataframe=True.
    """
    weaviate_hook.get_conn = MagicMock(return_value=MagicMock())
    weaviate_hook.get_object = MagicMock(
        side_effect=[
            {"deprecations": None, "objects": [{"name": "Test1", "id": 2}, {"name": "Test2", "id": 3}]},
            {"deprecations": None, "objects": []},
        ]
    )

    frame = weaviate_hook.get_all_objects(class_name="TestClass", as_dataframe=True)

    expected_calls = [
        mock.call(after=None, class_name="TestClass"),
        mock.call(after=3, class_name="TestClass"),
    ]
    assert weaviate_hook.get_object.call_args_list == expected_calls
    import pandas

    assert isinstance(frame, pandas.DataFrame)
+
def test_delete_object(self, weaviate_hook):
    """
    Test that delete_object passes the uuid positionally and forwards extra kwargs.
    """
    client = MagicMock()
    weaviate_hook.get_conn = MagicMock(return_value=client)

    weaviate_hook.delete_object(uuid="uuid", class_name="TestClass")

    client.data_object.delete.assert_called_once_with("uuid", class_name="TestClass")
+
def test_update_object(self, weaviate_hook):
    """
    Test that update_object passes data_object, class_name and uuid positionally to the client.
    """
    client = MagicMock()
    weaviate_hook.get_conn = MagicMock(return_value=client)
    payload = {"name": "Test"}

    weaviate_hook.update_object(uuid="uuid", class_name="TestClass", data_object=payload, tenant="2d")

    client.data_object.update.assert_called_once_with(payload, "TestClass", "uuid", tenant="2d")
+
def test_validate_object(self, weaviate_hook):
    """
    Test that validate_object forwards data_object, class_name and extra kwargs to the client.
    """
    client = MagicMock()
    weaviate_hook.get_conn = MagicMock(return_value=client)

    weaviate_hook.validate_object(class_name="TestClass", data_object={"name": "Test"}, uuid="2d")

    client.data_object.validate.assert_called_once_with({"name": "Test"}, "TestClass", uuid="2d")
+
def test_replace_object(self, weaviate_hook):
    """
    Test that replace_object passes data_object, class_name and uuid positionally to the client.
    """
    client = MagicMock()
    weaviate_hook.get_conn = MagicMock(return_value=client)
    payload = {"name": "Test"}

    weaviate_hook.replace_object(uuid="uuid", class_name="TestClass", data_object=payload, tenant="2d")

    client.data_object.replace.assert_called_once_with(payload, "TestClass", "uuid", tenant="2d")
+
def test_object_exists(self, weaviate_hook):
    """
    Test that object_exists passes the uuid positionally and forwards extra kwargs.
    """
    client = MagicMock()
    weaviate_hook.get_conn = MagicMock(return_value=client)

    weaviate_hook.object_exists(class_name="TestClass", uuid="2d")

    client.data_object.exists.assert_called_once_with("2d", class_name="TestClass")
+
def test_create_class(weaviate_hook):
"""