You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ut...@apache.org on 2023/12/08 11:29:45 UTC
(airflow) branch main updated: Add retry mechanism and dataframe support for WeaviateIngestOperator (#36085)
This is an automated email from the ASF dual-hosted git repository.
utkarsharma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a8333b778a Add retry mechanism and dataframe support for WeaviateIngestOperator (#36085)
a8333b778a is described below
commit a8333b778ac2ec905d6f51ab408e807d1294bd5a
Author: Utkarsh Sharma <ut...@gmail.com>
AuthorDate: Fri Dec 8 16:59:38 2023 +0530
Add retry mechanism and dataframe support for WeaviateIngestOperator (#36085)
* Add retry and dataframe support
Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
---
airflow/providers/weaviate/hooks/weaviate.py | 60 ++++++++++++++++++----
airflow/providers/weaviate/operators/weaviate.py | 44 ++++++++++++----
.../operators/weaviate.rst | 2 +-
tests/providers/weaviate/hooks/test_weaviate.py | 31 +++++++++--
.../providers/weaviate/operators/test_weaviate.py | 8 +--
5 files changed, 119 insertions(+), 26 deletions(-)
diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py
index ae32727ffb..8a68b1c3f0 100644
--- a/airflow/providers/weaviate/hooks/weaviate.py
+++ b/airflow/providers/weaviate/hooks/weaviate.py
@@ -17,10 +17,14 @@
from __future__ import annotations
+import contextlib
+import json
import warnings
from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Dict, List, cast
+import requests
+from tenacity import Retrying, retry_if_exception, stop_after_attempt
from weaviate import Client as WeaviateClient
from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword
from weaviate.exceptions import ObjectAlreadyExistsException
@@ -30,7 +34,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
if TYPE_CHECKING:
- from typing import Any
+ from typing import Sequence
import pandas as pd
from weaviate import ConsistencyLevel
@@ -144,22 +148,60 @@ class WeaviateHook(BaseHook):
client = self.conn
client.schema.create(schema_json)
+ @staticmethod
+ def check_http_error_should_retry(exc: BaseException):
+ return isinstance(exc, requests.HTTPError) and not exc.response.ok
+
+ @staticmethod
+ def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame) -> list[dict[str, Any]]:
+ """Helper function to convert dataframe to list of dicts.
+
+ When Pandas isn't installed and data is passed as a list of dictionaries, importing Pandas
+ fails. That import failure is suppressed here so that list input is still handled.
+ """
+ with contextlib.suppress(ImportError):
+ import pandas
+
+ if isinstance(data, pandas.DataFrame):
+ data = json.loads(data.to_json(orient="records"))
+ return cast(List[Dict[str, Any]], data)
+
def batch_data(
- self, class_name: str, data: list[dict[str, Any]], batch_config_params: dict[str, Any] | None = None
+ self,
+ class_name: str,
+ data: list[dict[str, Any]] | pd.DataFrame,
+ batch_config_params: dict[str, Any] | None = None,
+ vector_col: str = "Vector",
+ retry_attempts_per_object: int = 5,
) -> None:
+ """
+ Add multiple objects or object references at once into weaviate.
+
+ :param class_name: The name of the class that the objects belong to.
+ :param data: list or dataframe of objects we want to add.
+ :param batch_config_params: dict of batch configuration options.
+ .. seealso:: `batch_config_params options <https://weaviate-python-client.readthedocs.io/en/v3.25.3/weaviate.batch.html#weaviate.batch.Batch.configure>`__
+ :param vector_col: name of the column containing the vector.
+ :param retry_attempts_per_object: number of times to try in case of failure before giving up.
+ """
client = self.conn
if not batch_config_params:
batch_config_params = {}
client.batch.configure(**batch_config_params)
+ data = self._convert_dataframe_to_list(data)
with client.batch as batch:
# Batch import all data
for index, data_obj in enumerate(data):
- self.log.debug("importing data: %s", index + 1)
- vector = data_obj.pop("Vector", None)
- if vector is not None:
- batch.add_data_object(data_obj, class_name, vector=vector)
- else:
- batch.add_data_object(data_obj, class_name)
+ for attempt in Retrying(
+ stop=stop_after_attempt(retry_attempts_per_object),
+ retry=retry_if_exception(self.check_http_error_should_retry),
+ ):
+ with attempt:
+ self.log.debug(
+ "Attempt %s of importing data: %s", attempt.retry_state.attempt_number, index + 1
+ )
+ vector = data_obj.pop(vector_col, None)
+ batch.add_data_object(data_obj, class_name, vector=vector)
def delete_class(self, class_name: str) -> None:
"""Delete an existing class."""
diff --git a/airflow/providers/weaviate/operators/weaviate.py b/airflow/providers/weaviate/operators/weaviate.py
index fdf2f37f16..4e07a59edb 100644
--- a/airflow/providers/weaviate/operators/weaviate.py
+++ b/airflow/providers/weaviate/operators/weaviate.py
@@ -17,13 +17,17 @@
from __future__ import annotations
+import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.weaviate.hooks.weaviate import WeaviateHook
if TYPE_CHECKING:
+ import pandas as pd
+
from airflow.utils.context import Context
@@ -35,14 +39,16 @@ class WeaviateIngestOperator(BaseOperator):
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:WeaviateIngestOperator`
- Operator that accepts input json to generate embeddings on or accepting provided custom vectors
- and store them in the Weaviate class.
+ Operator that accepts input json or pandas dataframe to generate embeddings on or accepting provided
+ custom vectors and store them in the Weaviate class.
:param conn_id: The Weaviate connection.
:param class_name: The Weaviate class to be used for storing the data objects into.
- :param input_json: The JSON representing Weaviate data objects to generate embeddings on (or provides
- custom vectors) and store them in the Weaviate class. Either input_json or input_callable should be
- provided.
+ :param input_data: The list of dicts or pandas dataframe representing Weaviate data objects to generate
+ embeddings on (or provides custom vectors) and store them in the Weaviate class.
+ :param input_json: (Deprecated) The JSON representing Weaviate data objects to generate embeddings on (or provides
+ custom vectors) and store them in the Weaviate class.
+ :param vector_col: key/column name in which the vectors are stored.
"""
template_fields: Sequence[str] = ("input_json",)
@@ -51,15 +57,30 @@ class WeaviateIngestOperator(BaseOperator):
self,
conn_id: str,
class_name: str,
- input_json: list[dict[str, Any]],
+ input_json: list[dict[str, Any]] | pd.DataFrame | None = None,
+ input_data: list[dict[str, Any]] | pd.DataFrame | None = None,
+ vector_col: str = "Vector",
**kwargs: Any,
) -> None:
self.batch_params = kwargs.pop("batch_params", {})
self.hook_params = kwargs.pop("hook_params", {})
+
super().__init__(**kwargs)
self.class_name = class_name
self.conn_id = conn_id
- self.input_json = input_json
+ self.vector_col = vector_col
+
+ if input_data is not None:
+ self.input_data = input_data
+ elif input_json is not None:
+ warnings.warn(
+ "Passing 'input_json' to WeaviateIngestOperator is deprecated and"
+ " you should use 'input_data' instead",
+ AirflowProviderDeprecationWarning,
+ )
+ self.input_data = input_json
+ else:
+ raise TypeError("Either input_json or input_data is required")
@cached_property
def hook(self) -> WeaviateHook:
@@ -67,5 +88,10 @@ class WeaviateIngestOperator(BaseOperator):
return WeaviateHook(conn_id=self.conn_id, **self.hook_params)
def execute(self, context: Context) -> None:
- self.log.debug("Input json: %s", self.input_json)
- self.hook.batch_data(self.class_name, self.input_json, **self.batch_params)
+ self.log.debug("Input data: %s", self.input_data)
+ self.hook.batch_data(
+ self.class_name,
+ self.input_data,
+ **self.batch_params,
+ vector_col=self.vector_col,
+ )
diff --git a/docs/apache-airflow-providers-weaviate/operators/weaviate.rst b/docs/apache-airflow-providers-weaviate/operators/weaviate.rst
index 05063376f4..5ec262ab7a 100644
--- a/docs/apache-airflow-providers-weaviate/operators/weaviate.rst
+++ b/docs/apache-airflow-providers-weaviate/operators/weaviate.rst
@@ -28,7 +28,7 @@ into the database.
Using the Operator
^^^^^^^^^^^^^^^^^^
-The WeaviateIngestOperator requires the ``input_text`` as an input to the operator. Use the ``conn_id`` parameter to specify the Weaviate connection to use to
+The WeaviateIngestOperator requires the ``input_data`` as an input to the operator. Use the ``conn_id`` parameter to specify the Weaviate connection to use to
connect to your account.
An example using the operator to ingest data with custom vectors retrieved from XCOM:
diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py
index 6f4be77429..acda7e9c2e 100644
--- a/tests/providers/weaviate/hooks/test_weaviate.py
+++ b/tests/providers/weaviate/hooks/test_weaviate.py
@@ -19,7 +19,9 @@ from __future__ import annotations
from unittest import mock
from unittest.mock import MagicMock, Mock, patch
+import pandas as pd
import pytest
+import requests
from weaviate import ObjectAlreadyExistsException
from airflow.models import Connection
@@ -404,7 +406,15 @@ def test_create_schema(weaviate_hook):
mock_client.schema.create.assert_called_once_with(test_schema_json)
-def test_batch_data(weaviate_hook):
+@pytest.mark.parametrize(
+ argnames=["data", "expected_length"],
+ argvalues=[
+ ([{"name": "John"}, {"name": "Jane"}], 2),
+ (pd.DataFrame.from_dict({"name": ["John", "Jane"]}), 2),
+ ],
+ ids=("data as list of dicts", "data as dataframe"),
+)
+def test_batch_data(data, expected_length, weaviate_hook):
"""
Test the batch_data method of WeaviateHook.
"""
@@ -414,12 +424,25 @@ def test_batch_data(weaviate_hook):
# Define test data
test_class_name = "TestClass"
- test_data = [{"name": "John"}, {"name": "Jane"}]
# Test the batch_data method
- weaviate_hook.batch_data(test_class_name, test_data)
+ weaviate_hook.batch_data(test_class_name, data)
# Assert that the batch_data method was called with the correct arguments
mock_client.batch.configure.assert_called_once()
mock_batch_context = mock_client.batch.__enter__.return_value
- assert mock_batch_context.add_data_object.call_count == len(test_data)
+ assert mock_batch_context.add_data_object.call_count == expected_length
+
+
+@patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn")
+def test_batch_data_retry(get_conn, weaviate_hook):
+ """Test to ensure retrying works as expected."""
+ data = [{"name": "chandler"}, {"name": "joey"}, {"name": "ross"}]
+ response = requests.Response()
+ response.status_code = 429
+ error = requests.exceptions.HTTPError()
+ error.response = response
+ side_effect = [None, error, None, error, None]
+ get_conn.return_value.batch.__enter__.return_value.add_data_object.side_effect = side_effect
+ weaviate_hook.batch_data("TestClass", data)
+ assert get_conn.return_value.batch.__enter__.return_value.add_data_object.call_count == len(side_effect)
diff --git a/tests/providers/weaviate/operators/test_weaviate.py b/tests/providers/weaviate/operators/test_weaviate.py
index 5099e05766..7490b64dc6 100644
--- a/tests/providers/weaviate/operators/test_weaviate.py
+++ b/tests/providers/weaviate/operators/test_weaviate.py
@@ -36,7 +36,7 @@ class TestWeaviateIngestOperator:
def test_constructor(self, operator):
assert operator.conn_id == "weaviate_conn"
assert operator.class_name == "my_class"
- assert operator.input_json == {"data": "sample_data"}
+ assert operator.input_data == {"data": "sample_data"}
assert operator.batch_params == {}
assert operator.hook_params == {}
@@ -46,5 +46,7 @@ class TestWeaviateIngestOperator:
operator.execute(context=None)
- operator.hook.batch_data.assert_called_once_with("my_class", {"data": "sample_data"}, **{})
- mock_log.debug.assert_called_once_with("Input json: %s", {"data": "sample_data"})
+ operator.hook.batch_data.assert_called_once_with(
+ "my_class", {"data": "sample_data"}, vector_col="Vector", **{}
+ )
+ mock_log.debug.assert_called_once_with("Input data: %s", {"data": "sample_data"})