You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ut...@apache.org on 2023/12/23 06:15:42 UTC
(airflow) branch main updated: Add create_or_replace_document_objects method to weaviate provider (#36177)
This is an automated email from the ASF dual-hosted git repository.
utkarsharma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 75d74b1f3a Add create_or_replace_document_objects method to weaviate provider (#36177)
75d74b1f3a is described below
commit 75d74b1f3a535fdc3624077bde3a34d1abcf641e
Author: Utkarsh Sharma <ut...@gmail.com>
AuthorDate: Sat Dec 23 11:45:35 2023 +0530
Add create_or_replace_document_objects method to weaviate provider (#36177)
* Add CRUD operations around schema and class objects
* Handle cases when the properties are not in the same order
* Change the methods name and docstring
* Resolve conflicts
* Make sure the retrying is working as expected
* Address PR comments
* Remove retry logic from
* Remove vector_col params and dataframe support
* Remove unwanted retry logic
* Address PR comments
* Resolve ruff-lint issues
* Remove unwanted changes
* Remove unwanted changes
* Change the exception to retry on
* [WIP] Add ingest methods
* Fix static checks
* Optimize for error case
* Fix test docstring and handle missing case
* Remove duplicate rows from dataframe
* Updated the logic for object creation based on documents
* Update docstring for 'create_or_replace_document_objects' method
* refactored code
* Update airflow/providers/weaviate/hooks/weaviate.py
Co-authored-by: Josh Fell <48...@users.noreply.github.com>
* Update airflow/providers/weaviate/hooks/weaviate.py
Co-authored-by: Josh Fell <48...@users.noreply.github.com>
* Remove typo
* Add better names
* Add testcases
* Addressed PR comments
* Fix testcases
* Fix docstring
* Fix docstring
* Address static code issue
* Fix typing issue
* Fix typing issue
---------
Co-authored-by: Josh Fell <48...@users.noreply.github.com>
---
airflow/providers/weaviate/hooks/weaviate.py | 505 +++++++++++++++++++++++-
tests/providers/weaviate/hooks/test_weaviate.py | 205 +++++++++-
2 files changed, 699 insertions(+), 11 deletions(-)
diff --git a/airflow/providers/weaviate/hooks/weaviate.py b/airflow/providers/weaviate/hooks/weaviate.py
index 6c9b1b6787..d0b8db37cb 100644
--- a/airflow/providers/weaviate/hooks/weaviate.py
+++ b/airflow/providers/weaviate/hooks/weaviate.py
@@ -24,9 +24,11 @@ from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, List, cast
import requests
+import weaviate.exceptions
from tenacity import Retrying, retry, retry_if_exception, retry_if_exception_type, stop_after_attempt
from weaviate import Client as WeaviateClient
from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword
+from weaviate.data.replication import ConsistencyLevel
from weaviate.exceptions import ObjectAlreadyExistsException
from weaviate.util import generate_uuid5
@@ -34,10 +36,9 @@ from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
if TYPE_CHECKING:
- from typing import Literal, Sequence
+ from typing import Callable, Collection, Literal, Sequence
import pandas as pd
- from weaviate import ConsistencyLevel
from weaviate.types import UUID
ExitingSchemaOptions = Literal["replace", "fail", "ignore"]
@@ -62,7 +63,7 @@ def check_http_error_is_retryable(exc: BaseException):
class WeaviateHook(BaseHook):
"""
- Interact with Weaviate database to store vectors. This hook uses the `conn_id`.
+ Interact with Weaviate database to store vectors. This hook uses the 'conn_id'.
:param conn_id: The connection id to use when connecting to Weaviate. <howto/connection:weaviate>
"""
@@ -366,6 +367,7 @@ class WeaviateHook(BaseHook):
"""
# When the class properties are not in same order or not the same length. We convert them to dicts
# with property `name` as the key. This way we ensure, the subset is checked.
+
classes_objects = self._convert_properties_to_dict(classes_objects)
exiting_classes_list = self._convert_properties_to_dict(self.get_schema()["classes"])
@@ -383,25 +385,82 @@ class WeaviateHook(BaseHook):
self,
class_name: str,
data: list[dict[str, Any]] | pd.DataFrame,
+ insertion_errors: list,
batch_config_params: dict[str, Any] | None = None,
vector_col: str = "Vector",
+ uuid_col: str = "id",
retry_attempts_per_object: int = 5,
- ) -> None:
+ tenant: str | None = None,
+ ) -> list:
"""
Add multiple objects or object references at once into weaviate.
:param class_name: The name of the class that objects belongs to.
:param data: list or dataframe of objects we want to add.
+ :param insertion_errors: list to hold errors while inserting.
:param batch_config_params: dict of batch configuration option.
.. seealso:: `batch_config_params options <https://weaviate-python-client.readthedocs.io/en/v3.25.3/weaviate.batch.html#weaviate.batch.Batch.configure>`__
:param vector_col: name of the column containing the vector.
+ :param uuid_col: Name of the column containing the UUID.
:param retry_attempts_per_object: number of time to try in case of failure before giving up.
+ :param tenant: The tenant to which the object will be added.
"""
+ data = self._convert_dataframe_to_list(data)
+ total_results = 0
+ error_results = 0
+
+ def _process_batch_errors(
+ results: list,
+ verbose: bool = True,
+ ) -> None:
+ """
+ Helper function to processes the results from insert or delete batch operation and collects any errors.
+
+ :param results: Results from the batch operation.
+ :param verbose: Flag to enable verbose logging.
+ """
+ nonlocal total_results
+ nonlocal error_results
+ total_batch_results = len(results)
+ error_batch_results = 0
+ for item in results:
+ if "errors" in item["result"]:
+ error_batch_results = error_batch_results + 1
+ item_error = {"uuid": item["id"], "errors": item["result"]["errors"]}
+ if verbose:
+ self.log.info(
+ "Error occurred in batch process for %s with error %s",
+ item["id"],
+ item["result"]["errors"],
+ )
+ insertion_errors.append(item_error)
+ if verbose:
+ total_results = total_results + (total_batch_results - error_batch_results)
+ error_results = error_results + error_batch_results
+
+ self.log.info(
+ "Total Objects %s / Objects %s successfully inserted and Objects %s had errors.",
+ len(data),
+ total_results,
+ error_results,
+ )
+
client = self.conn
if not batch_config_params:
batch_config_params = {}
+
+ # configuration for context manager for __exit__ method to callback on errors for weaviate
+ # batch ingestion.
+ if not batch_config_params.get("callback"):
+ batch_config_params["callback"] = _process_batch_errors
+
+ if not batch_config_params.get("timeout_retries"):
+ batch_config_params["timeout_retries"] = 5
+
+ if not batch_config_params.get("connection_error_retries"):
+ batch_config_params["connection_error_retries"] = 5
+
client.batch.configure(**batch_config_params)
- data = self._convert_dataframe_to_list(data)
with client.batch as batch:
# Batch import all data
for index, data_obj in enumerate(data):
@@ -413,11 +472,22 @@ class WeaviateHook(BaseHook):
),
):
with attempt:
+ vector = data_obj.pop(vector_col, None)
+ uuid = data_obj.pop(uuid_col, None)
self.log.debug(
- "Attempt %s of importing data: %s", attempt.retry_state.attempt_number, index + 1
+ "Attempt %s of inserting object with uuid: %s",
+ attempt.retry_state.attempt_number,
+ uuid,
)
- vector = data_obj.pop(vector_col, None)
- batch.add_data_object(data_obj, class_name, vector=vector)
+ batch.add_data_object(
+ data_object=data_obj,
+ class_name=class_name,
+ vector=vector,
+ uuid=uuid,
+ tenant=tenant,
+ )
+ self.log.debug("Inserted object with uuid: %s into batch", uuid)
+ return insertion_errors
def query_with_vector(
self,
@@ -499,7 +569,7 @@ class WeaviateHook(BaseHook):
:param class_name: Class name associated with the object given. This is required to create a new object.
:param vector: Vector associated with the object given. This argument is only used when creating object.
:param consistency_level: Consistency level to be used. Applies to both create and get operations.
- :tenant: Tenant to be used. Applies to both create and get operations.
+ :param tenant: Tenant to be used. Applies to both create and get operations.
:param kwargs: Additional parameters to be passed to weaviate_client.data_object.create() and
weaviate_client.data_object.get()
"""
@@ -606,3 +676,420 @@ class WeaviateHook(BaseHook):
"""
client = self.conn
return client.data_object.exists(uuid, **kwargs)
+
+ def _delete_objects(self, uuids: Collection, class_name: str, retry_attempts_per_object: int = 5):
+ """
+ Helper function for `create_or_replace_objects()` to delete multiple objects.
+
+ :param uuids: Collection of uuids.
+ :param class_name: Name of the class in Weaviate schema where data is to be ingested.
+ :param retry_attempts_per_object: number of times to try in case of failure before giving up.
+ """
+ for uuid in uuids:
+ for attempt in Retrying(
+ stop=stop_after_attempt(retry_attempts_per_object),
+ retry=(
+ retry_if_exception(lambda exc: check_http_error_is_retryable(exc))
+ | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES)
+ ),
+ ):
+ with attempt:
+ try:
+ self.delete_object(uuid=uuid, class_name=class_name)
+ self.log.debug("Deleted object with uuid %s", uuid)
+ except weaviate.exceptions.UnexpectedStatusCodeException as e:
+ if e.status_code == 404:
+ self.log.debug("Tried to delete a non existent object with uuid %s", uuid)
+ else:
+ self.log.debug("Error occurred while trying to delete object with uuid %s", uuid)
+ raise e
+
+ self.log.info("Deleted %s objects.", len(uuids))
+
+ def _generate_uuids(
+ self,
+ df: pd.DataFrame,
+ class_name: str,
+ unique_columns: list[str],
+ vector_column: str | None = None,
+ uuid_column: str | None = None,
+ ) -> tuple[pd.DataFrame, str]:
+ """
+ Adds UUIDs to a DataFrame, useful for replace operations where UUIDs must be known before ingestion.
+
+ By default, UUIDs are generated using a custom function if 'uuid_column' is not specified.
+ The function can potentially ingest the same data multiple times with different UUIDs.
+
+ :param df: A dataframe with data to generate a UUID from.
+ :param class_name: The name of the class use as part of the uuid namespace.
+ :param uuid_column: Name of the column to create. Default is 'id'.
+ :param unique_columns: A list of columns to use for UUID generation. By default, all columns except
+ vector_column will be used.
+ :param vector_column: Name of the column containing the vector data. If specified the vector will be
+ removed prior to generating the uuid.
+ """
+ column_names = df.columns.to_list()
+
+ difference_columns = set(unique_columns).difference(set(df.columns.to_list()))
+ if difference_columns:
+ raise ValueError(f"Columns {', '.join(difference_columns)} don't exist in dataframe")
+
+ if uuid_column is None:
+ self.log.info("No uuid_column provided. Generating UUIDs as column name `id`.")
+ if "id" in column_names:
+ raise ValueError(
+ "Property 'id' already in dataset. Consider renaming or specify 'uuid_column'."
+ )
+ else:
+ uuid_column = "id"
+
+ if uuid_column in column_names:
+ raise ValueError(
+ f"Property {uuid_column} already in dataset. Consider renaming or specify a different"
+ f" 'uuid_column'."
+ )
+
+ df[uuid_column] = (
+ df[unique_columns]
+ .drop(columns=[vector_column], inplace=False, errors="ignore")
+ .apply(lambda row: generate_uuid5(identifier=row.to_dict(), namespace=class_name), axis=1)
+ )
+
+ return df, uuid_column
+
+ def _get_documents_to_uuid_map(
+ self,
+ data: pd.DataFrame,
+ document_column: str,
+ uuid_column: str,
+ class_name: str,
+ offset: int = 0,
+ limit: int = 2000,
+ ) -> dict[str, set]:
+ """Helper function to get the document to uuid map of existing objects in db.
+
+ :param data: A single pandas DataFrame.
+ :param document_column: The name of the property to query.
+ :param class_name: The name of the class to query.
+ :param uuid_column: The name of the column containing the UUID.
+ :param offset: pagination parameter to indicate the which object to start fetching data.
+ :param limit: pagination param to indicate the number of records to fetch from start object.
+ """
+ documents_to_uuid: dict = {}
+ document_keys = set(data[document_column])
+ while True:
+ data_objects = (
+ self.conn.query.get(properties=[document_column], class_name=class_name)
+ .with_additional([uuid_column])
+ .with_where(
+ {
+ "operator": "Or",
+ "operands": [
+ {"valueText": key, "path": document_column, "operator": "Equal"}
+ for key in document_keys
+ ],
+ }
+ )
+ .with_offset(offset)
+ .with_limit(limit)
+ .do()["data"]["Get"][class_name]
+ )
+ if len(data_objects) == 0:
+ break
+ offset = offset + limit
+ documents_to_uuid.update(
+ self._prepare_document_to_uuid_map(
+ data=data_objects,
+ group_key=document_column,
+ get_value=lambda x: x["_additional"][uuid_column],
+ )
+ )
+ return documents_to_uuid
+
+ @staticmethod
+ def _prepare_document_to_uuid_map(
+ data: list[dict], group_key: str, get_value: Callable[[dict], str]
+ ) -> dict[str, set]:
+ """Helper function to prepare the map of grouped_key to set."""
+ grouped_key_to_set: dict = {}
+ for item in data:
+ document_url = item[group_key]
+
+ if document_url not in grouped_key_to_set:
+ grouped_key_to_set[document_url] = set()
+
+ grouped_key_to_set[document_url].add(get_value(item))
+ return grouped_key_to_set
+
+ def _get_segregated_documents(
+ self, data: pd.DataFrame, document_column: str, class_name: str, uuid_column: str
+ ) -> tuple[dict[str, set], set, set, set]:
+ """
+ Segregate documents into changed, unchanged and new document, when compared to Weaviate db.
+
+ :param data: A single pandas DataFrame.
+ :param document_column: The name of the property to query.
+ :param class_name: The name of the class to query.
+ :param uuid_column: The name of the column containing the UUID.
+ """
+ changed_documents = set()
+ unchanged_docs = set()
+ new_documents = set()
+ existing_documents_to_uuid = self._get_documents_to_uuid_map(
+ data=data, uuid_column=uuid_column, document_column=document_column, class_name=class_name
+ )
+
+ input_documents_to_uuid = self._prepare_document_to_uuid_map(
+ data=data.to_dict("records"),
+ group_key=document_column,
+ get_value=lambda x: x[uuid_column],
+ )
+
+ # segregate documents into changed, unchanged and non-existing documents.
+ for doc_url, doc_set in input_documents_to_uuid.items():
+ if doc_url in existing_documents_to_uuid:
+ if existing_documents_to_uuid[doc_url] != doc_set:
+ changed_documents.add(doc_url)
+ else:
+ unchanged_docs.add(doc_url)
+ else:
+ new_documents.add(doc_url)
+
+ return existing_documents_to_uuid, changed_documents, unchanged_docs, new_documents
+
+ def _delete_all_documents_objects(
+ self,
+ document_keys: list[str],
+ document_column: str,
+ class_name: str,
+ total_objects_count: int = 1,
+ batch_delete_error: list | None = None,
+ tenant: str | None = None,
+ batch_config_params: dict[str, Any] | None = None,
+ verbose: bool = False,
+ ):
+ """Delete all object that belong to list of documents.
+
+ :param document_keys: list of unique documents identifiers.
+ :param document_column: Column in DataFrame that identifying source document.
+ :param class_name: Name of the class in Weaviate schema where data is to be ingested.
+ :param total_objects_count: total number of objects to delete, needed as max limit on one delete
+ query is 10,000, if we have more objects to delete we need to run query multiple times.
+ :param batch_delete_error: list to hold errors while inserting.
+ :param tenant: The tenant to which the object will be added.
+ :param batch_config_params: Additional parameters for Weaviate batch configuration.
+ :param verbose: Flag to enable verbose output during the ingestion process.
+ """
+ batch_delete_error = batch_delete_error or []
+
+ if not batch_config_params:
+ batch_config_params = {}
+
+ # This limit is imposed by Weavaite database
+ MAX_LIMIT_ON_TOTAL_DELETABLE_OBJECTS = 10000
+
+ self.conn.batch.configure(**batch_config_params)
+ with self.conn.batch as batch:
+ # ConsistencyLevel.ALL is essential here to guarantee complete deletion of objects
+ # across all nodes. Maintaining this level ensures data integrity, preventing
+ # irrelevant objects from providing misleading context for LLM models.
+ batch.consistency_level = ConsistencyLevel.ALL
+ while total_objects_count > 0:
+ document_objects = batch.delete_objects(
+ class_name=class_name,
+ where={
+ "operator": "Or",
+ "operands": [
+ {
+ "path": [document_column],
+ "operator": "Equal",
+ "valueText": key,
+ }
+ for key in document_keys
+ ],
+ },
+ output="verbose",
+ dry_run=False,
+ tenant=tenant,
+ )
+ total_objects_count = total_objects_count - MAX_LIMIT_ON_TOTAL_DELETABLE_OBJECTS
+ matched_objects = document_objects["results"]["matches"]
+ batch_delete_error = [
+ {"uuid": obj["id"]}
+ for obj in document_objects["results"]["objects"]
+ if "error" in obj["status"]
+ ]
+ if verbose:
+ self.log.info("Deleted %s Objects", matched_objects)
+
+ return batch_delete_error
+
+ def create_or_replace_document_objects(
+ self,
+ data: pd.DataFrame | list[dict[str, Any]] | list[pd.DataFrame],
+ class_name: str,
+ document_column: str,
+ existing: str = "skip",
+ uuid_column: str | None = None,
+ vector_column: str = "Vector",
+ batch_config_params: dict | None = None,
+ tenant: str | None = None,
+ verbose: bool = False,
+ ):
+ """
+ create or replace objects belonging to documents.
+
+ In real-world scenarios, information sources like Airflow docs, Stack Overflow, or other issues
+ are considered 'documents' here. It's crucial to keep the database objects in sync with these sources.
+ If any changes occur in these documents, this function aims to reflect those changes in the database.
+
+ .. note::
+
+ This function assumes responsibility for identifying changes in documents, dropping relevant
+ database objects, and recreating them based on updated information. It's crucial to handle this
+ process with care, ensuring backups and validation are in place to prevent data loss or
+ inconsistencies.
+
+ Provides users with multiple ways of dealing with existing values.
+ replace: replace the existing objects with new objects. This option requires to identify the
+ objects belonging to a document. which by default is done by using document_column field.
+ skip: skip the existing objects and only add the missing objects of a document.
+ error: raise an error if an object belonging to a existing document is tried to be created.
+
+ :param data: A single pandas DataFrame or a list of dicts to be ingested.
+ :param class_name: Name of the class in Weaviate schema where data is to be ingested.
+ :param existing: Strategy for handling existing data: 'skip', or 'replace'. Default is 'skip'.
+ :param document_column: Column in DataFrame that identifying source document.
+ :param uuid_column: Column with pre-generated UUIDs. If not provided, UUIDs will be generated.
+ :param vector_column: Column with embedding vectors for pre-embedded data.
+ :param batch_config_params: Additional parameters for Weaviate batch configuration.
+ :param tenant: The tenant to which the object will be added.
+ :param verbose: Flag to enable verbose output during the ingestion process.
+ :return: list of UUID which failed to create
+ """
+ import pandas as pd
+
+ if existing not in ["skip", "replace", "error"]:
+ raise ValueError("Invalid parameter for 'existing'. Choices are 'skip', 'replace', 'error'.")
+
+ if len(data) == 0:
+ return []
+
+ if isinstance(data, list) and isinstance(data[0], dict):
+ # This is done to narrow the type to List[Dict[str, Any].
+ data = pd.json_normalize(cast(List[Dict[str, Any]], data))
+ elif isinstance(data, list) and isinstance(data[0], pd.DataFrame):
+ # This is done to narrow the type to List[pd.DataFrame].
+ data = pd.concat(cast(List[pd.DataFrame], data), ignore_index=True)
+ else:
+ data = cast(pd.DataFrame, data)
+
+ unique_columns = sorted(data.columns.to_list())
+
+ if verbose:
+ self.log.info("%s objects came in for insertion.", data.shape[0])
+
+ if uuid_column is None or uuid_column not in data.columns:
+ (
+ data,
+ uuid_column,
+ ) = self._generate_uuids(
+ df=data,
+ class_name=class_name,
+ unique_columns=unique_columns,
+ vector_column=vector_column,
+ uuid_column=uuid_column,
+ )
+
+ # drop duplicate rows, using uuid_column and unique_columns. Removed `None` as it can be added to
+ # set when `uuid_column` is None.
+ data = data.drop_duplicates(subset=[document_column, uuid_column], keep="first")
+ if verbose:
+ self.log.info("%s objects remain after deduplication.", data.shape[0])
+
+ batch_delete_error: list = []
+ (
+ documents_to_uuid_map,
+ changed_documents,
+ unchanged_documents,
+ new_documents,
+ ) = self._get_segregated_documents(
+ data=data,
+ document_column=document_column,
+ uuid_column=uuid_column,
+ class_name=class_name,
+ )
+ if verbose:
+ self.log.info(
+ "Found %s changed documents, %s unchanged documents and %s non-existing documents",
+ len(changed_documents),
+ len(unchanged_documents),
+ len(new_documents),
+ )
+ for document in changed_documents:
+ self.log.info(
+ "Changed document: %s has %s objects.", document, len(documents_to_uuid_map[document])
+ )
+
+ self.log.info("Non-existing document: %s", ", ".join(new_documents))
+
+ if existing == "error" and len(changed_documents):
+ raise ValueError(
+ f"Documents {', '.join(changed_documents)} already exists. You can either skip or replace"
+ f" them by passing 'existing=skip' or 'existing=replace' respectively."
+ )
+ elif existing == "skip":
+ data = data[data[document_column].isin(new_documents)]
+ if verbose:
+ self.log.info(
+ "Since existing=skip, ingesting only non-existing document's object %s", data.shape[0]
+ )
+ elif existing == "replace":
+ total_objects_count = sum([len(documents_to_uuid_map[doc]) for doc in changed_documents])
+ if verbose:
+ self.log.info(
+ "Since existing='replace', deleting %s objects belonging changed documents %s",
+ total_objects_count,
+ changed_documents,
+ )
+ batch_delete_error = self._delete_all_documents_objects(
+ document_keys=list(changed_documents),
+ total_objects_count=total_objects_count,
+ document_column=document_column,
+ class_name=class_name,
+ batch_delete_error=batch_delete_error,
+ tenant=tenant,
+ batch_config_params=batch_config_params,
+ verbose=verbose,
+ )
+ data = data[data[document_column].isin(new_documents.union(changed_documents))]
+ self.log.info("Batch inserting %s objects for non-existing and changed documents.", data.shape[0])
+
+ insertion_errors: list = []
+ if data.shape[0]:
+ insertion_errors = self.batch_data(
+ class_name=class_name,
+ data=data,
+ insertion_errors=insertion_errors,
+ batch_config_params=batch_config_params,
+ vector_col=vector_column,
+ uuid_col=uuid_column,
+ tenant=tenant,
+ )
+ if insertion_errors or batch_delete_error:
+ if insertion_errors:
+ self.log.info("Failed to insert %s objects.", len(insertion_errors))
+ if batch_delete_error:
+ self.log.info("Failed to delete %s objects.", len(insertion_errors))
+ # Rollback object that were not created properly
+ self._delete_objects(
+ [item["uuid"] for item in insertion_errors + batch_delete_error], class_name=class_name
+ )
+
+ if verbose:
+ self.log.info(
+ "Total objects in class %s : %s ",
+ class_name,
+ self.conn.query.aggregate(class_name).with_meta_count().do(),
+ )
+ return insertion_errors, batch_delete_error
diff --git a/tests/providers/weaviate/hooks/test_weaviate.py b/tests/providers/weaviate/hooks/test_weaviate.py
index fc6d7db6ae..5b8d5ce65d 100644
--- a/tests/providers/weaviate/hooks/test_weaviate.py
+++ b/tests/providers/weaviate/hooks/test_weaviate.py
@@ -428,7 +428,7 @@ def test_batch_data(data, expected_length, weaviate_hook):
test_class_name = "TestClass"
# Test the batch_data method
- weaviate_hook.batch_data(test_class_name, data)
+ weaviate_hook.batch_data(test_class_name, data, insertion_errors=[])
# Assert that the batch_data method was called with the correct arguments
mock_client.batch.configure.assert_called_once()
@@ -446,7 +446,7 @@ def test_batch_data_retry(get_conn, weaviate_hook):
error.response = response
side_effect = [None, error, None, error, None]
get_conn.return_value.batch.__enter__.return_value.add_data_object.side_effect = side_effect
- weaviate_hook.batch_data("TestClass", data)
+ weaviate_hook.batch_data("TestClass", data, insertion_errors=[])
assert get_conn.return_value.batch.__enter__.return_value.add_data_object.call_count == len(side_effect)
@@ -670,3 +670,204 @@ def test_contains_schema(get_schema, classes_to_test, expected_result, weaviate_
]
}
assert weaviate_hook.check_subset_of_schema(classes_to_test) == expected_result
+
+
+@mock.patch("weaviate.util.generate_uuid5")
+def test___generate_uuids(generate_uuid5, weaviate_hook):
+ df = pd.DataFrame.from_dict({"name": ["ross", "bob"], "age": ["12", "22"], "gender": ["m", "m"]})
+ with pytest.raises(ValueError, match=r"Columns last_name don't exist in dataframe"):
+ weaviate_hook._generate_uuids(
+ df=df, class_name="test", unique_columns=["name", "age", "gender", "last_name"]
+ )
+
+ df = pd.DataFrame.from_dict(
+ {"id": [1, 2], "name": ["ross", "bob"], "age": ["12", "22"], "gender": ["m", "m"]}
+ )
+ with pytest.raises(
+ ValueError, match=r"Property 'id' already in dataset. Consider renaming or specify" r" 'uuid_column'"
+ ):
+ weaviate_hook._generate_uuids(df=df, class_name="test", unique_columns=["name", "age", "gender"])
+
+ with pytest.raises(
+ ValueError,
+ match=r"Property age already in dataset. Consider renaming or specify" r" a different 'uuid_column'.",
+ ):
+ weaviate_hook._generate_uuids(
+ df=df, uuid_column="age", class_name="test", unique_columns=["name", "age", "gender"]
+ )
+
+
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.delete_object")
+def test__delete_objects(delete_object, weaviate_hook):
+ resp = requests.Response()
+ resp.status_code = 429
+ requests.exceptions.HTTPError(response=resp)
+ http_429_exception = requests.exceptions.HTTPError(response=resp)
+
+ resp = requests.Response()
+ resp.status_code = 404
+ not_found_exception = weaviate.exceptions.UnexpectedStatusCodeException(
+ message="object not found", response=resp
+ )
+
+ delete_object.side_effect = [not_found_exception, None, http_429_exception, http_429_exception, None]
+ weaviate_hook._delete_objects(uuids=["1", "2", "3"], class_name="test")
+ assert delete_object.call_count == 5
+
+
+def test__prepare_document_to_uuid_map(weaviate_hook):
+ input_data = [
+ {"id": "1", "name": "ross", "age": "12", "gender": "m"},
+ {"id": "2", "name": "bob", "age": "22", "gender": "m"},
+ {"id": "3", "name": "joy", "age": "15", "gender": "f"},
+ ]
+ grouped_data = weaviate_hook._prepare_document_to_uuid_map(
+ data=input_data, group_key="gender", get_value=lambda x: x["name"]
+ )
+ assert grouped_data == {"m": {"ross", "bob"}, "f": {"joy"}}
+
+
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._prepare_document_to_uuid_map")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._get_documents_to_uuid_map")
+def test___get_segregated_documents(_get_documents_to_uuid_map, _prepare_document_to_uuid_map, weaviate_hook):
+ _get_documents_to_uuid_map.return_value = {
+ "abc.doc": {"uuid1", "uuid2", "uuid2"},
+ "xyz.doc": {"uuid4", "uuid5"},
+ "dfg.doc": {"uuid8", "uuid0", "uuid12"},
+ }
+ _prepare_document_to_uuid_map.return_value = {
+ "abc.doc": {"uuid1", "uuid56", "uuid2"},
+ "xyz.doc": {"uuid4", "uuid5"},
+ "hjk.doc": {"uuid8", "uuid0", "uuid12"},
+ }
+ (
+ _,
+ changed_documents,
+ unchanged_docs,
+ new_documents,
+ ) = weaviate_hook._get_segregated_documents(
+ data=pd.DataFrame(),
+ document_column="doc_key",
+ uuid_column="id",
+ class_name="doc",
+ )
+ assert changed_documents == {"abc.doc"}
+ assert unchanged_docs == {"xyz.doc"}
+ assert new_documents == {"hjk.doc"}
+
+
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._get_segregated_documents")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._generate_uuids")
+def test_error_option_of_create_or_replace_document_objects(
+ _generate_uuids, _get_segregated_documents, weaviate_hook
+):
+ df = pd.DataFrame.from_dict(
+ {
+ "id": ["1", "2", "3"],
+ "name": ["ross", "bob", "joy"],
+ "age": ["12", "22", "15"],
+ "gender": ["m", "m", "f"],
+ "doc": ["abc.xml", "zyx.html", "zyx.html"],
+ }
+ )
+
+ _get_segregated_documents.return_value = ({}, {"abc.xml"}, {}, {"zyx.html"})
+ _generate_uuids.return_value = (df, "id")
+ with pytest.raises(
+ ValueError, match="Documents abc.xml already exists. You can either" " skip or replace"
+ ):
+ weaviate_hook.create_or_replace_document_objects(
+ data=df, document_column="doc", class_name="test", existing="error"
+ )
+
+
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._delete_objects")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.batch_data")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._get_segregated_documents")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._generate_uuids")
+def test_skip_option_of_create_or_replace_document_objects(
+ _generate_uuids, _get_segregated_documents, batch_data, _delete_objects, weaviate_hook
+):
+ df = pd.DataFrame.from_dict(
+ {
+ "id": ["1", "2", "3"],
+ "name": ["ross", "bob", "joy"],
+ "age": ["12", "22", "15"],
+ "gender": ["m", "m", "f"],
+ "doc": ["abc.xml", "zyx.html", "zyx.html"],
+ }
+ )
+
+ class_name = "test"
+ documents_to_uuid_map, changed_documents, unchanged_documents, new_documents = (
+ {},
+ {"abc.xml"},
+ {},
+ {"zyx.html"},
+ )
+ _get_segregated_documents.return_value = (
+ documents_to_uuid_map,
+ changed_documents,
+ unchanged_documents,
+ new_documents,
+ )
+ _generate_uuids.return_value = (df, "id")
+
+ weaviate_hook.create_or_replace_document_objects(
+ data=df, class_name=class_name, existing="skip", document_column="doc"
+ )
+
+ pd.testing.assert_frame_equal(
+ batch_data.call_args_list[0].kwargs["data"], df[df["doc"].isin(new_documents)]
+ )
+
+
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._delete_all_documents_objects")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.batch_data")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._get_segregated_documents")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook._generate_uuids")
+def test_replace_option_of_create_or_replace_document_objects(
+ _generate_uuids, _get_segregated_documents, batch_data, _delete_all_documents_objects, weaviate_hook
+):
+ df = pd.DataFrame.from_dict(
+ {
+ "id": ["1", "2", "3"],
+ "name": ["ross", "bob", "joy"],
+ "age": ["12", "22", "15"],
+ "gender": ["m", "m", "f"],
+ "doc": ["abc.xml", "zyx.html", "zyx.html"],
+ }
+ )
+
+ class_name = "test"
+ documents_to_uuid_map, changed_documents, unchanged_documents, new_documents = (
+ {"abc.xml": {"uuid"}},
+ {"abc.xml"},
+ {},
+ {"zyx.html"},
+ )
+ batch_data.return_value = []
+ _get_segregated_documents.return_value = (
+ documents_to_uuid_map,
+ changed_documents,
+ unchanged_documents,
+ new_documents,
+ )
+ _generate_uuids.return_value = (df, "id")
+ weaviate_hook.create_or_replace_document_objects(
+ data=df, class_name=class_name, existing="replace", document_column="doc"
+ )
+ _delete_all_documents_objects.assert_called_with(
+ document_keys=list(changed_documents),
+ total_objects_count=1,
+ document_column="doc",
+ class_name="test",
+ batch_delete_error=[],
+ tenant=None,
+ batch_config_params=None,
+ verbose=False,
+ )
+ pd.testing.assert_frame_equal(
+ batch_data.call_args_list[0].kwargs["data"],
+ df[df["doc"].isin(changed_documents.union(new_documents))],
+ )