You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ut...@apache.org on 2023/12/08 11:29:45 UTC

(airflow) branch main updated: Add retry mechanism and dataframe support for WeaviateIngestOperator (#36085)

This is an automated email from the ASF dual-hosted git repository.

utkarsharma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a8333b778a Add retry mechanism and dataframe support for WeaviateIngestOperator (#36085)
a8333b778a is described below

commit a8333b778ac2ec905d6f51ab408e807d1294bd5a
Author: Utkarsh Sharma <ut...@gmail.com>
AuthorDate: Fri Dec 8 16:59:38 2023 +0530

    Add retry mechanism and dataframe support for WeaviateIngestOperator (#36085)
    
    * Add retry and dataframe support
    
    Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
---
 airflow/providers/weaviate/hooks/weaviate.py       | 60 ++++++++++++++++++----
 airflow/providers/weaviate/operators/weaviate.py   | 44 ++++++++++++----
 .../operators/weaviate.rst                         |  2 +-
 tests/providers/weaviate/hooks/test_weaviate.py    | 31 +++++++++--
 .../providers/weaviate/operators/test_weaviate.py  |  8 +--
 5 files changed, 119 insertions(+), 26 deletions(-)

diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py
index ae32727ffb..8a68b1c3f0 100644
--- a/airflow/providers/weaviate/hooks/weaviate.py
+++ b/airflow/providers/weaviate/hooks/weaviate.py
@@ -17,10 +17,14 @@
 
 from __future__ import annotations
 
+import contextlib
+import json
 import warnings
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Dict, List, cast
 
+import requests
+from tenacity import Retrying, retry_if_exception, stop_after_attempt
 from weaviate import Client as WeaviateClient
 from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword
 from weaviate.exceptions import ObjectAlreadyExistsException
@@ -30,7 +34,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
 
 if TYPE_CHECKING:
-    from typing import Any
+    from typing import Sequence
 
     import pandas as pd
     from weaviate import ConsistencyLevel
@@ -144,22 +148,60 @@ class WeaviateHook(BaseHook):
         client = self.conn
         client.schema.create(schema_json)
 
+    @staticmethod
+    def check_http_error_should_retry(exc: BaseException):
+        return isinstance(exc, requests.HTTPError) and not exc.response.ok
+
+    @staticmethod
+    def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame) -> list[dict[str, Any]]:
+        """Helper function to convert dataframe to list of dicts.
+
+        If Pandas isn't installed and the data is passed as a list of dictionaries, importing Pandas
+        fails; that must not be treated as an error, so the ImportError is suppressed here.
+        """
+        with contextlib.suppress(ImportError):
+            import pandas
+
+            if isinstance(data, pandas.DataFrame):
+                data = json.loads(data.to_json(orient="records"))
+        return cast(List[Dict[str, Any]], data)
+
     def batch_data(
-        self, class_name: str, data: list[dict[str, Any]], batch_config_params: dict[str, Any] | None = None
+        self,
+        class_name: str,
+        data: list[dict[str, Any]] | pd.DataFrame,
+        batch_config_params: dict[str, Any] | None = None,
+        vector_col: str = "Vector",
+        retry_attempts_per_object: int = 5,
     ) -> None:
+        """
+        Add multiple objects or object references at once into weaviate.
+
+        :param class_name: The name of the class that the objects belong to.
+        :param data: list or dataframe of objects we want to add.
+        :param batch_config_params: dict of batch configuration options.
+            .. seealso:: `batch_config_params options <https://weaviate-python-client.readthedocs.io/en/v3.25.3/weaviate.batch.html#weaviate.batch.Batch.configure>`__
+        :param vector_col: name of the column containing the vector.
+        :param retry_attempts_per_object: number of times to try in case of failure before giving up.
+        """
         client = self.conn
         if not batch_config_params:
             batch_config_params = {}
         client.batch.configure(**batch_config_params)
+        data = self._convert_dataframe_to_list(data)
         with client.batch as batch:
             # Batch import all data
             for index, data_obj in enumerate(data):
-                self.log.debug("importing data: %s", index + 1)
-                vector = data_obj.pop("Vector", None)
-                if vector is not None:
-                    batch.add_data_object(data_obj, class_name, vector=vector)
-                else:
-                    batch.add_data_object(data_obj, class_name)
+                for attempt in Retrying(
+                    stop=stop_after_attempt(retry_attempts_per_object),
+                    retry=retry_if_exception(self.check_http_error_should_retry),
+                ):
+                    with attempt:
+                        self.log.debug(
+                            "Attempt %s of importing data: %s", attempt.retry_state.attempt_number, index + 1
+                        )
+                        vector = data_obj.pop(vector_col, None)
+                        batch.add_data_object(data_obj, class_name, vector=vector)
 
     def delete_class(self, class_name: str) -> None:
         """Delete an existing class."""
diff --git a/airflow/providers/weaviate/operators/weaviate.py b/airflow/providers/weaviate/operators/weaviate.py
index fdf2f37f16..4e07a59edb 100644
--- a/airflow/providers/weaviate/operators/weaviate.py
+++ b/airflow/providers/weaviate/operators/weaviate.py
@@ -17,13 +17,17 @@
 
 from __future__ import annotations
 
+import warnings
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.weaviate.hooks.weaviate import WeaviateHook
 
 if TYPE_CHECKING:
+    import pandas as pd
+
     from airflow.utils.context import Context
 
 
@@ -35,14 +39,16 @@ class WeaviateIngestOperator(BaseOperator):
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:WeaviateIngestOperator`
 
-    Operator that accepts input json to generate embeddings on or accepting provided custom vectors
-    and store them in the Weaviate class.
+    Operator that accepts input json or pandas dataframe to generate embeddings on or accepting provided
+    custom vectors and store them in the Weaviate class.
 
     :param conn_id: The Weaviate connection.
     :param class_name: The Weaviate class to be used for storing the data objects into.
-    :param input_json: The JSON representing Weaviate data objects to generate embeddings on (or provides
-        custom vectors) and store them in the Weaviate class. Either input_json or input_callable should be
-        provided.
+    :param input_data: The list of dicts or pandas dataframe representing Weaviate data objects to generate
+        embeddings on (or provides custom vectors) and store them in the Weaviate class.
+    :param input_json: (Deprecated) The JSON representing Weaviate data objects to generate embeddings on (or provides
+        custom vectors) and store them in the Weaviate class.
+    :param vector_col: key/column name in which the vectors are stored.
     """
 
     template_fields: Sequence[str] = ("input_json",)
@@ -51,15 +57,30 @@ class WeaviateIngestOperator(BaseOperator):
         self,
         conn_id: str,
         class_name: str,
-        input_json: list[dict[str, Any]],
+        input_json: list[dict[str, Any]] | pd.DataFrame | None = None,
+        input_data: list[dict[str, Any]] | pd.DataFrame | None = None,
+        vector_col: str = "Vector",
         **kwargs: Any,
     ) -> None:
         self.batch_params = kwargs.pop("batch_params", {})
         self.hook_params = kwargs.pop("hook_params", {})
+
         super().__init__(**kwargs)
         self.class_name = class_name
         self.conn_id = conn_id
-        self.input_json = input_json
+        self.vector_col = vector_col
+
+        if input_data is not None:
+            self.input_data = input_data
+        elif input_json is not None:
+            warnings.warn(
+                "Passing 'input_json' to WeaviateIngestOperator is deprecated and"
+                " you should use 'input_data' instead",
+                AirflowProviderDeprecationWarning,
+            )
+            self.input_data = input_json
+        else:
+            raise TypeError("Either input_json or input_data is required")
 
     @cached_property
     def hook(self) -> WeaviateHook:
@@ -67,5 +88,10 @@ class WeaviateIngestOperator(BaseOperator):
         return WeaviateHook(conn_id=self.conn_id, **self.hook_params)
 
     def execute(self, context: Context) -> None:
-        self.log.debug("Input json: %s", self.input_json)
-        self.hook.batch_data(self.class_name, self.input_json, **self.batch_params)
+        self.log.debug("Input data: %s", self.input_data)
+        self.hook.batch_data(
+            self.class_name,
+            self.input_data,
+            **self.batch_params,
+            vector_col=self.vector_col,
+        )
diff --git a/docs/apache-airflow-providers-weaviate/operators/weaviate.rst b/docs/apache-airflow-providers-weaviate/operators/weaviate.rst
index 05063376f4..5ec262ab7a 100644
--- a/docs/apache-airflow-providers-weaviate/operators/weaviate.rst
+++ b/docs/apache-airflow-providers-weaviate/operators/weaviate.rst
@@ -28,7 +28,7 @@ into the database.
 Using the Operator
 ^^^^^^^^^^^^^^^^^^
 
-The WeaviateIngestOperator requires the ``input_text`` as an input to the operator. Use the ``conn_id`` parameter to specify the Weaviate connection to use to
+The WeaviateIngestOperator requires the ``input_data`` as an input to the operator. Use the ``conn_id`` parameter to specify the Weaviate connection to use to
 connect to your account.
 
 An example using the operator to ingest data with custom vectors retrieved from XCOM:
diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py
index 6f4be77429..acda7e9c2e 100644
--- a/tests/providers/weaviate/hooks/test_weaviate.py
+++ b/tests/providers/weaviate/hooks/test_weaviate.py
@@ -19,7 +19,9 @@ from __future__ import annotations
 from unittest import mock
 from unittest.mock import MagicMock, Mock, patch
 
+import pandas as pd
 import pytest
+import requests
 from weaviate import ObjectAlreadyExistsException
 
 from airflow.models import Connection
@@ -404,7 +406,15 @@ def test_create_schema(weaviate_hook):
     mock_client.schema.create.assert_called_once_with(test_schema_json)
 
 
-def test_batch_data(weaviate_hook):
+@pytest.mark.parametrize(
+    argnames=["data", "expected_length"],
+    argvalues=[
+        ([{"name": "John"}, {"name": "Jane"}], 2),
+        (pd.DataFrame.from_dict({"name": ["John", "Jane"]}), 2),
+    ],
+    ids=("data as list of dicts", "data as dataframe"),
+)
+def test_batch_data(data, expected_length, weaviate_hook):
     """
     Test the batch_data method of WeaviateHook.
     """
@@ -414,12 +424,25 @@ def test_batch_data(weaviate_hook):
 
     # Define test data
     test_class_name = "TestClass"
-    test_data = [{"name": "John"}, {"name": "Jane"}]
 
     # Test the batch_data method
-    weaviate_hook.batch_data(test_class_name, test_data)
+    weaviate_hook.batch_data(test_class_name, data)
 
     # Assert that the batch_data method was called with the correct arguments
     mock_client.batch.configure.assert_called_once()
     mock_batch_context = mock_client.batch.__enter__.return_value
-    assert mock_batch_context.add_data_object.call_count == len(test_data)
+    assert mock_batch_context.add_data_object.call_count == expected_length
+
+
+@patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn")
+def test_batch_data_retry(get_conn, weaviate_hook):
+    """Test to ensure retrying works as expected."""
+    data = [{"name": "chandler"}, {"name": "joey"}, {"name": "ross"}]
+    response = requests.Response()
+    response.status_code = 429
+    error = requests.exceptions.HTTPError()
+    error.response = response
+    side_effect = [None, error, None, error, None]
+    get_conn.return_value.batch.__enter__.return_value.add_data_object.side_effect = side_effect
+    weaviate_hook.batch_data("TestClass", data)
+    assert get_conn.return_value.batch.__enter__.return_value.add_data_object.call_count == len(side_effect)
diff --git a/tests/providers/weaviate/operators/test_weaviate.py b/tests/providers/weaviate/operators/test_weaviate.py
index 5099e05766..7490b64dc6 100644
--- a/tests/providers/weaviate/operators/test_weaviate.py
+++ b/tests/providers/weaviate/operators/test_weaviate.py
@@ -36,7 +36,7 @@ class TestWeaviateIngestOperator:
     def test_constructor(self, operator):
         assert operator.conn_id == "weaviate_conn"
         assert operator.class_name == "my_class"
-        assert operator.input_json == {"data": "sample_data"}
+        assert operator.input_data == {"data": "sample_data"}
         assert operator.batch_params == {}
         assert operator.hook_params == {}
 
@@ -46,5 +46,7 @@ class TestWeaviateIngestOperator:
 
         operator.execute(context=None)
 
-        operator.hook.batch_data.assert_called_once_with("my_class", {"data": "sample_data"}, **{})
-        mock_log.debug.assert_called_once_with("Input json: %s", {"data": "sample_data"})
+        operator.hook.batch_data.assert_called_once_with(
+            "my_class", {"data": "sample_data"}, vector_col="Vector", **{}
+        )
+        mock_log.debug.assert_called_once_with("Input data: %s", {"data": "sample_data"})