import asyncio
from collections.abc import Sequence
from typing import Optional
from pharia_data_sdk.connectors.document_index.document_index import (
AsyncDocumentIndexClient,
CollectionPath,
DocumentIndexClient,
DocumentPath,
DocumentTextPosition,
Filters,
SearchQuery,
)
from pharia_data_sdk.connectors.retrievers.base_retriever import (
AsyncBaseRetriever,
BaseRetriever,
Document,
DocumentChunk,
SearchResult,
)
class DocumentIndexRetriever(BaseRetriever[DocumentPath]):
    """Search through documents within collections in the `DocumentIndexClient`.

    This retriever lets you search for relevant documents in the given Document Index collection.

    Example:
    >>> import os
    >>> from pharia_data_sdk.connectors import DocumentIndexClient, DocumentIndexRetriever
    >>> document_index = DocumentIndexClient(os.getenv("AA_TOKEN"))
    >>> retriever = DocumentIndexRetriever(document_index, "asymmetric", "aleph-alpha", "wikipedia-de", 3)
    >>> documents = retriever.get_relevant_documents_with_scores("Who invented the airplane?")
    """

    # Separator placed between a document's content items by `get_full_document`.
    # Chunk offsets are computed against the same joined text, so both places must
    # use this one constant.
    _ITEM_SEPARATOR = "\n"

    def __init__(
        self,
        document_index: DocumentIndexClient,
        index_name: str,
        namespace: str,
        collection: str,
        k: int = 1,
        threshold: float = 0.0,
    ) -> None:
        """Initialize the DocumentIndexRetriever.

        Args:
            document_index: The Document Index client.
            index_name: The name of the Document Index index to use.
            namespace: The Document Index namespace.
            collection: The Document Index collection to use. This is the search context for the retriever.
            k: The number of most-relevant documents to return when searching. Defaults to 1.
            threshold: The minimum score for search results. For semantic indexes, this is the cosine
                similarity between the query and the document chunk. For hybrid indexes, this corresponds
                to fusion rank. Defaults to 0.0.
        """
        self._document_index = document_index
        self._index_name = index_name
        self._collection_path = CollectionPath(
            namespace=namespace, collection=collection
        )
        self._k = k
        self._threshold = threshold

    def _get_absolute_position(
        self, id: DocumentPath, document_text_position: DocumentTextPosition
    ) -> dict[str, int]:
        """Translate an item-relative chunk position into document-level offsets.

        Args:
            id: Path of the document that contains the chunk.
            document_text_position: Index of the content item the chunk lies in, plus
                the chunk's start/end character positions within that item.

        Returns:
            A dict with keys ``"start"`` and ``"end"``. The values are offsets into the
            text returned by `get_full_document` for the same document.
        """
        # NOTE: fetches the whole document once per call, i.e. once per search result.
        doc = self._document_index.document(id)
        preceding_items = doc.contents[: document_text_position.item]
        # Each preceding item occupies its own length plus one separator in the
        # joined full-document text produced by `get_full_document`.
        offset = sum(
            len(text) + len(self._ITEM_SEPARATOR) for text in preceding_items
        )
        start = offset + document_text_position.start_position
        end = offset + document_text_position.end_position
        return {"start": start, "end": end}

    def get_relevant_documents_with_scores(
        self, query: str, filters: Optional[list[Filters]] = None
    ) -> Sequence[SearchResult[DocumentPath]]:
        """Search the configured collection and index for chunks matching `query`.

        The search asks for at most `k` results with a minimum score of `threshold`.

        Args:
            query: The search query text.
            filters: Optional filters forwarded to the Document Index search.

        Returns:
            One `SearchResult` per search hit. Each holds the hit's document path and
            score, plus a `DocumentChunk` with the matched section text and its offsets
            into the full document text.
        """
        search_query = SearchQuery(
            query=query, max_results=self._k, min_score=self._threshold, filters=filters
        )
        response = self._document_index.search(
            self._collection_path, self._index_name, search_query
        )
        return [
            SearchResult(
                id=result.document_path,
                score=result.score,
                document_chunk=DocumentChunk(
                    text=result.section,
                    **self._get_absolute_position(
                        id=result.document_path,
                        document_text_position=result.chunk_position,
                    ),
                ),
            )
            for result in response
        ]

    def get_full_document(self, id: DocumentPath) -> Document:
        """Fetch a document and return all its content items as a single text.

        Args:
            id: Path of the document to fetch.

        Returns:
            A `Document` whose text is the content items joined by `_ITEM_SEPARATOR`,
            together with the document's metadata.
        """
        contents = self._document_index.document(id)
        return Document(
            text=self._ITEM_SEPARATOR.join(contents.contents),
            metadata=contents.metadata,
        )
class AsyncDocumentIndexRetriever(AsyncBaseRetriever[DocumentPath]):
    """Asynchronously search through documents in an `AsyncDocumentIndexClient` collection.

    This is the async counterpart of `DocumentIndexRetriever`: it searches one Document
    Index collection with a given index and resolves chunk offsets within full documents.
    """

    # Separator placed between a document's content items by `get_full_document`.
    # Chunk offsets are computed against the same joined text, so both places must
    # use this one constant.
    _ITEM_SEPARATOR = "\n"

    def __init__(
        self,
        document_index: AsyncDocumentIndexClient,
        index_name: str,
        namespace: str,
        collection: str,
        k: int = 1,
        threshold: float = 0.0,
    ) -> None:
        """Initialize the AsyncDocumentIndexRetriever.

        Args:
            document_index: The async Document Index client.
            index_name: The name of the Document Index index to use.
            namespace: The Document Index namespace.
            collection: The Document Index collection to use. This is the search context for the retriever.
            k: The number of most-relevant documents to return when searching. Defaults to 1.
            threshold: The minimum score for search results, passed to the search as
                `min_score`. Defaults to 0.0.
        """
        self._document_index = document_index
        self._index_name = index_name
        self._collection_path = CollectionPath(
            namespace=namespace, collection=collection
        )
        self._k = k
        self._threshold = threshold

    async def _get_absolute_position(
        self, id: DocumentPath, document_text_position: DocumentTextPosition
    ) -> dict[str, int]:
        """Translate an item-relative chunk position into document-level offsets.

        Args:
            id: Path of the document that contains the chunk.
            document_text_position: Index of the content item the chunk lies in, plus
                the chunk's start/end character positions within that item.

        Returns:
            A dict with keys ``"start"`` and ``"end"``. The values are offsets into the
            text returned by `get_full_document` for the same document.
        """
        # NOTE: fetches the whole document once per call, i.e. once per search result.
        doc = await self._document_index.document(id)
        preceding_items = doc.contents[: document_text_position.item]
        # Each preceding item occupies its own length plus one separator in the
        # joined full-document text produced by `get_full_document`.
        offset = sum(
            len(text) + len(self._ITEM_SEPARATOR) for text in preceding_items
        )
        start = offset + document_text_position.start_position
        end = offset + document_text_position.end_position
        return {"start": start, "end": end}

    async def get_relevant_documents_with_scores(
        self, query: str, filters: Optional[list[Filters]] = None
    ) -> Sequence[SearchResult[DocumentPath]]:
        """Search the configured collection and index for chunks matching `query`.

        The search asks for at most `k` results with a minimum score of `threshold`.
        Chunk offsets for all hits are resolved concurrently.

        Args:
            query: The search query text.
            filters: Optional filters forwarded to the Document Index search.

        Returns:
            One `SearchResult` per search hit, in search order. Each holds the hit's
            document path and score, plus a `DocumentChunk` with the matched section
            text and its offsets into the full document text.
        """
        search_query = SearchQuery(
            query=query, max_results=self._k, min_score=self._threshold, filters=filters
        )
        # Materialise once: the results are iterated twice below (to schedule the
        # position lookups and to build the output), which a one-shot iterable
        # would not survive.
        results = list(
            await self._document_index.search(
                self._collection_path, self._index_name, search_query
            )
        )
        # asyncio.gather preserves argument order, so positions line up with results.
        positions = await asyncio.gather(
            *(
                self._get_absolute_position(result.document_path, result.chunk_position)
                for result in results
            )
        )
        return [
            SearchResult(
                id=result.document_path,
                score=result.score,
                document_chunk=DocumentChunk(text=result.section, **position),
            )
            for result, position in zip(results, positions, strict=True)
        ]

    async def get_full_document(self, id: DocumentPath) -> Document:
        """Fetch a document and return all its content items as a single text.

        Args:
            id: Path of the document to fetch.

        Returns:
            A `Document` whose text is the content items joined by `_ITEM_SEPARATOR`,
            together with the document's metadata.
        """
        contents = await self._document_index.document(id)
        return Document(
            text=self._ITEM_SEPARATOR.join(contents.contents),
            metadata=contents.metadata,
        )