You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2023/12/07 09:07:37 UTC

(airflow) branch main updated: Add object methods in weaviate hook (#35934)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b2464477c4 Add object methods in weaviate hook (#35934)
b2464477c4 is described below

commit b2464477c472894f142c1a85f04a92af033e700e
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Thu Dec 7 10:07:29 2023 +0100

    Add object methods in weaviate hook (#35934)
    
    * Add object methods in weaviate hook
    
    This PR adds methods to the weaviate hook to help get/create/update
    and delete objects
    
    * fixup! Add object methods in weaviate hook
    
    * Add typing
    
    * improve get_or_create_object method
    
    * add object validate
    
    * Improve get_all_objects to return a dataframe if requested
    
    * use conn
    
    * Add object exists method
    
    * fixup! Add object exists method
    
    * Fix docs
    
    * Update get_or_create_object method
---
 airflow/providers/weaviate/hooks/weaviate.py    | 155 ++++++++++++++++++++-
 airflow/providers/weaviate/provider.yaml        |   1 +
 generated/provider_dependencies.json            |   1 +
 tests/providers/weaviate/hooks/test_weaviate.py | 178 ++++++++++++++++++++++++
 4 files changed, 334 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py
index 66d820bbbe..ae32727ffb 100644
--- a/airflow/providers/weaviate/hooks/weaviate.py
+++ b/airflow/providers/weaviate/hooks/weaviate.py
@@ -19,14 +19,23 @@ from __future__ import annotations
 
 import warnings
 from functools import cached_property
-from typing import Any
+from typing import TYPE_CHECKING, Sequence
 
 from weaviate import Client as WeaviateClient
 from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword
+from weaviate.exceptions import ObjectAlreadyExistsException
+from weaviate.util import generate_uuid5
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
 
+if TYPE_CHECKING:
+    from typing import Any
+
+    import pandas as pd
+    from weaviate import ConsistencyLevel
+    from weaviate.types import UUID
+
 
 class WeaviateHook(BaseHook):
     """
@@ -200,3 +209,147 @@ class WeaviateHook(BaseHook):
             .do()
         )
         return results
+
+    def create_object(
+        self, data_object: dict | str, class_name: str, **kwargs
+    ) -> str | dict[str, Any] | None:
+        """Create a new object.
+
+        :param data_object: Object to be added. If type is str it should be either a URL or a file.
+        :param class_name: Class name associated with the object given.
+        :param kwargs: Additional parameters to be passed to weaviate_client.data_object.create()
+        """
+        client = self.conn
+        # generate deterministic uuid if not provided
+        uuid = kwargs.pop("uuid", generate_uuid5(data_object))
+        try:
+            return client.data_object.create(data_object, class_name, uuid=uuid, **kwargs)
+        except ObjectAlreadyExistsException:
+            self.log.warning("Object with the UUID %s already exists", uuid)
+            return None
+
+    def get_or_create_object(
+        self,
+        data_object: dict | str | None = None,
+        class_name: str | None = None,
+        vector: Sequence | None = None,
+        consistency_level: ConsistencyLevel | None = None,
+        tenant: str | None = None,
+        **kwargs,
+    ) -> str | dict[str, Any] | None:
+        """Get or Create a new object.
+
+        Returns the object if it already exists.
+
+        :param data_object: Object to be added. If type is str it should be either a URL or a file. This is required
+            to create a new object.
+        :param class_name: Class name associated with the object given. This is required to create a new object.
+        :param vector: Vector associated with the object given. This argument is only used when creating an object.
+        :param consistency_level: Consistency level to be used. Applies to both create and get operations.
+        :param tenant: Tenant to be used. Applies to both create and get operations.
+        :param kwargs: Additional parameters to be passed to weaviate_client.data_object.create() and
+            weaviate_client.data_object.get()
+        """
+        obj = self.get_object(
+            class_name=class_name, consistency_level=consistency_level, tenant=tenant, **kwargs
+        )
+        if not obj:
+            if not (data_object and class_name):
+                raise ValueError("data_object and class_name are required to create a new object")
+            uuid = kwargs.pop("uuid", generate_uuid5(data_object))
+            return self.create_object(
+                data_object,
+                class_name,
+                vector=vector,
+                uuid=uuid,
+                consistency_level=consistency_level,
+                tenant=tenant,
+            )
+        return obj
+
+    def get_object(self, **kwargs) -> dict[str, Any] | None:
+        """Get objects or an object from weaviate.
+
+        :param kwargs: parameters to be passed to weaviate_client.data_object.get() or
+            weaviate_client.data_object.get_by_id()
+        """
+        client = self.conn
+        return client.data_object.get(**kwargs)
+
+    def get_all_objects(
+        self, after: str | UUID | None = None, as_dataframe: bool = False, **kwargs
+    ) -> list[dict[str, Any]] | pd.DataFrame:
+        """Get all objects from weaviate.
+
+        If ``after`` is provided, it will be used as the starting point for the listing.
+
+        :param after: uuid of the object to start listing from
+        :param as_dataframe: if True, returns a pandas dataframe
+        :param kwargs: parameters to be passed to weaviate_client.data_object.get()
+        """
+        all_objects = []
+        after = kwargs.pop("after", after)
+        while True:
+            results = self.get_object(after=after, **kwargs) or {}
+            if not results.get("objects"):
+                break
+            all_objects.extend(results["objects"])
+            after = results["objects"][-1]["id"]
+        if as_dataframe:
+            import pandas
+
+            return pandas.DataFrame(all_objects)
+        return all_objects
+
+    def delete_object(self, uuid: UUID | str, **kwargs) -> None:
+        """Delete an object from weaviate.
+
+        :param uuid: uuid of the object to be deleted
+        :param kwargs: Optional parameters to be passed to weaviate_client.data_object.delete()
+        """
+        client = self.conn
+        client.data_object.delete(uuid, **kwargs)
+
+    def update_object(self, data_object: dict | str, class_name: str, uuid: UUID | str, **kwargs) -> None:
+        """Update an object in weaviate.
+
+        :param data_object: The object states the fields that should be updated. Fields not specified in the
+            'data_object' remain unchanged. Fields that are None will not be changed.
+            If type is str it should be either a URL or a file.
+        :param class_name: Class name associated with the object given.
+        :param uuid: uuid of the object to be updated
+        :param kwargs: Optional parameters to be passed to weaviate_client.data_object.update()
+        """
+        client = self.conn
+        client.data_object.update(data_object, class_name, uuid, **kwargs)
+
+    def replace_object(self, data_object: dict | str, class_name: str, uuid: UUID | str, **kwargs) -> None:
+        """Replace an object in weaviate.
+
+        :param data_object: The object states the fields that should be updated. Fields not specified in the
+            'data_object' will be set to None. If type is str it should be either a URL or a file.
+        :param class_name: Class name associated with the object given.
+        :param uuid: uuid of the object to be replaced
+        :param kwargs: Optional parameters to be passed to weaviate_client.data_object.replace()
+        """
+        client = self.conn
+        client.data_object.replace(data_object, class_name, uuid, **kwargs)
+
+    def validate_object(self, data_object: dict | str, class_name: str, **kwargs):
+        """Validate an object in weaviate.
+
+        :param data_object: The object to be validated. If type is str it should be either a URL or a file.
+        :param class_name: Class name associated with the object given.
+        :param kwargs: Optional parameters to be passed to weaviate_client.data_object.validate()
+        """
+        client = self.conn
+        client.data_object.validate(data_object, class_name, **kwargs)
+
+    def object_exists(self, uuid: str | UUID, **kwargs) -> bool:
+        """Check if an object exists in weaviate.
+
+        :param uuid: The UUID of the object that may or may not exist within Weaviate.
+        :param kwargs: Optional parameters to be passed to weaviate_client.data_object.exists()
+        """
+        client = self.conn
+        return client.data_object.exists(uuid, **kwargs)
diff --git a/airflow/providers/weaviate/provider.yaml b/airflow/providers/weaviate/provider.yaml
index 015fa66c6f..86d60ba708 100644
--- a/airflow/providers/weaviate/provider.yaml
+++ b/airflow/providers/weaviate/provider.yaml
@@ -39,6 +39,7 @@ integrations:
 dependencies:
   - apache-airflow>=2.6.0
   - weaviate-client>=3.24.2
+  - pandas>=0.17.1
 
 hooks:
   - integration-name: Weaviate
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index c47736d9ac..24906c94cf 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -949,6 +949,7 @@
   "weaviate": {
     "deps": [
       "apache-airflow>=2.6.0",
+      "pandas>=0.17.1",
       "weaviate-client>=3.24.2"
     ],
     "cross-providers-deps": [],
diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py
index 0274004fc0..6f4be77429 100644
--- a/tests/providers/weaviate/hooks/test_weaviate.py
+++ b/tests/providers/weaviate/hooks/test_weaviate.py
@@ -20,6 +20,7 @@ from unittest import mock
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
+from weaviate import ObjectAlreadyExistsException
 
 from airflow.models import Connection
 from airflow.providers.weaviate.hooks.weaviate import WeaviateHook
@@ -179,6 +180,183 @@ class TestWeaviateHook:
             additional_headers={},
         )
 
+    @mock.patch("airflow.providers.weaviate.hooks.weaviate.generate_uuid5")
+    def test_create_object(self, mock_gen_uuid, weaviate_hook):
+        """
+        Test the create_object method of WeaviateHook.
+        """
+        mock_client = MagicMock()
+        weaviate_hook.get_conn = MagicMock(return_value=mock_client)
+        return_value = weaviate_hook.create_object({"name": "Test"}, "TestClass")
+        mock_gen_uuid.assert_called_once()
+        mock_client.data_object.create.assert_called_once_with(
+            {"name": "Test"}, "TestClass", uuid=mock_gen_uuid.return_value
+        )
+        assert return_value
+
+    def test_create_object_already_exists_return_none(self, weaviate_hook):
+        """
+        Test that create_object returns None when the object already exists.
+        """
+        mock_client = MagicMock()
+        weaviate_hook.get_conn = MagicMock(return_value=mock_client)
+        mock_client.data_object.create.side_effect = ObjectAlreadyExistsException
+        return_value = weaviate_hook.create_object({"name": "Test"}, "TestClass")
+        assert return_value is None
+
+    def test_get_object(self, weaviate_hook):
+        """
+        Test the get_object method of WeaviateHook.
+        """
+        mock_client = MagicMock()
+        weaviate_hook.get_conn = MagicMock(return_value=mock_client)
+        weaviate_hook.get_object(class_name="TestClass", uuid="uuid")
+        mock_client.data_object.get.assert_called_once_with(class_name="TestClass", uuid="uuid")
+
+    def test_get_of_get_or_create_object(self, weaviate_hook):
+        """
+        Test the get part of get_or_create_object method of WeaviateHook.
+        """
+        mock_client = MagicMock()
+        weaviate_hook.get_conn = MagicMock(return_value=mock_client)
+        weaviate_hook.get_or_create_object(data_object={"name": "Test"}, class_name="TestClass")
+        mock_client.data_object.get.assert_called_once_with(
+            class_name="TestClass",
+            consistency_level=None,
+            tenant=None,
+        )
+
+    @mock.patch("airflow.providers.weaviate.hooks.weaviate.generate_uuid5")
+    def test_create_of_get_or_create_object(self, mock_gen_uuid, weaviate_hook):
+        """
+        Test the create part of get_or_create_object method of WeaviateHook.
+        """
+        mock_client = MagicMock()
+        weaviate_hook.get_conn = MagicMock(return_value=mock_client)
+        weaviate_hook.get_object = MagicMock(return_value=None)
+        mock_create_object = MagicMock()
+        weaviate_hook.create_object = mock_create_object
+        weaviate_hook.get_or_create_object(data_object={"name": "Test"}, class_name="TestClass")
+        mock_create_object.assert_called_once_with(
+            {"name": "Test"},
+            "TestClass",
+            uuid=mock_gen_uuid.return_value,
+            consistency_level=None,
+            tenant=None,
+            vector=None,
+        )
+
+    def test_create_of_get_or_create_object_raises_valueerror(self, weaviate_hook):
+        """
+        Test that if data_object is None or class_name is None, ValueError is raised.
+        """
+        mock_client = MagicMock()
+        weaviate_hook.get_conn = MagicMock(return_value=mock_client)
+        weaviate_hook.get_object = MagicMock(return_value=None)
+        mock_create_object = MagicMock()
+        weaviate_hook.create_object = mock_create_object
+        with pytest.raises(ValueError):
+            weaviate_hook.get_or_create_object(data_object=None, class_name="TestClass")
+        with pytest.raises(ValueError):
+            weaviate_hook.get_or_create_object(data_object={"name": "Test"}, class_name=None)
+
+    def test_get_all_objects(self, weaviate_hook):
+        """
+        Test the get_all_objects method of WeaviateHook.
+        """
+        mock_client = MagicMock()
+        weaviate_hook.get_conn = MagicMock(return_value=mock_client)
+        objects = [
+            {"deprecations": None, "objects": [{"name": "Test1", "id": 2}, {"name": "Test2", "id": 3}]},
+            {"deprecations": None, "objects": []},
+        ]
+        mock_get_object = MagicMock()
+        weaviate_hook.get_object = mock_get_object
+        mock_get_object.side_effect = objects
+
+        return_value = weaviate_hook.get_all_objects(class_name="TestClass")
+        assert weaviate_hook.get_object.call_args_list == [
+            mock.call(after=None, class_name="TestClass"),
+            mock.call(after=3, class_name="TestClass"),
+        ]
+        assert return_value == [{"name": "Test1", "id": 2}, {"name": "Test2", "id": 3}]
+
+    def test_get_all_objects_returns_dataframe(self, weaviate_hook):
+        """
+        Test the get_all_objects method of WeaviateHook can return a dataframe.
+        """
+        mock_client = MagicMock()
+        weaviate_hook.get_conn = MagicMock(return_value=mock_client)
+        objects = [
+            {"deprecations": None, "objects": [{"name": "Test1", "id": 2}, {"name": "Test2", "id": 3}]},
+            {"deprecations": None, "objects": []},
+        ]
+        mock_get_object = MagicMock()
+        weaviate_hook.get_object = mock_get_object
+        mock_get_object.side_effect = objects
+
+        return_value = weaviate_hook.get_all_objects(class_name="TestClass", as_dataframe=True)
+        assert weaviate_hook.get_object.call_args_list == [
+            mock.call(after=None, class_name="TestClass"),
+            mock.call(after=3, class_name="TestClass"),
+        ]
+        import pandas
+
+        assert isinstance(return_value, pandas.DataFrame)
+
+    def test_delete_object(self, weaviate_hook):
+        """
+        Test the delete_object method of WeaviateHook.
+        """
+        mock_client = MagicMock()
+        weaviate_hook.get_conn = MagicMock(return_value=mock_client)
+        weaviate_hook.delete_object(uuid="uuid", class_name="TestClass")
+        mock_client.data_object.delete.assert_called_once_with("uuid", class_name="TestClass")
+
+    def test_update_object(self, weaviate_hook):
+        """
+        Test the update_object method of WeaviateHook.
+        """
+        mock_client = MagicMock()
+        weaviate_hook.get_conn = MagicMock(return_value=mock_client)
+        weaviate_hook.update_object(
+            uuid="uuid", class_name="TestClass", data_object={"name": "Test"}, tenant="2d"
+        )
+        mock_client.data_object.update.assert_called_once_with(
+            {"name": "Test"}, "TestClass", "uuid", tenant="2d"
+        )
+
+    def test_validate_object(self, weaviate_hook):
+        """
+        Test the validate_object method of WeaviateHook.
+        """
+        mock_client = MagicMock()
+        weaviate_hook.get_conn = MagicMock(return_value=mock_client)
+        weaviate_hook.validate_object(class_name="TestClass", data_object={"name": "Test"}, uuid="2d")
+        mock_client.data_object.validate.assert_called_once_with({"name": "Test"}, "TestClass", uuid="2d")
+
+    def test_replace_object(self, weaviate_hook):
+        """
+        Test the replace_object method of WeaviateHook.
+        """
+        mock_client = MagicMock()
+        weaviate_hook.get_conn = MagicMock(return_value=mock_client)
+        weaviate_hook.replace_object(
+            uuid="uuid", class_name="TestClass", data_object={"name": "Test"}, tenant="2d"
+        )
+        mock_client.data_object.replace.assert_called_once_with(
+            {"name": "Test"}, "TestClass", "uuid", tenant="2d"
+        )
+
+    def test_object_exists(self, weaviate_hook):
+        """
+        Test the object_exists method of WeaviateHook.
+        """
+        mock_client = MagicMock()
+        weaviate_hook.get_conn = MagicMock(return_value=mock_client)
+        weaviate_hook.object_exists(class_name="TestClass", uuid="2d")
+        mock_client.data_object.exists.assert_called_once_with("2d", class_name="TestClass")
+
 
 def test_create_class(weaviate_hook):
     """