feat(api): reranking backend integrated in with rag #1090

Merged
64 commits
b63f085
Initial reranking setup
CollectiveUnicorn Sep 20, 2024
09fd6c3
Naive reranking implemented with query
CollectiveUnicorn Sep 20, 2024
f14e22d
Adds endpoint to check current rag configuration
CollectiveUnicorn Sep 20, 2024
5a9bb34
Fixes typo
CollectiveUnicorn Sep 20, 2024
10163ac
Adds route to fast api router
CollectiveUnicorn Sep 23, 2024
a82b596
Fixes issue in endpoint configs
CollectiveUnicorn Sep 23, 2024
6b6eb82
Ensures that the class level variable has a default value
CollectiveUnicorn Sep 23, 2024
7337762
Makes the enable_reranking var a classvar
CollectiveUnicorn Sep 23, 2024
88cdbca
Creates separate response type
CollectiveUnicorn Sep 23, 2024
75c42f8
Additional comments
CollectiveUnicorn Sep 24, 2024
329a296
Cleans up comments and uses correct class for post requests
CollectiveUnicorn Sep 24, 2024
d19be30
Adds the model config to the search endpoint so that it can be passed…
CollectiveUnicorn Sep 24, 2024
be75d01
Adds output to evaluate reranking, refactors class
CollectiveUnicorn Sep 24, 2024
ebe977c
Adds more output
CollectiveUnicorn Sep 24, 2024
2e85b5e
More logging
CollectiveUnicorn Sep 24, 2024
40a8061
Refactors logging and adds additional outputs
CollectiveUnicorn Sep 24, 2024
d74ba5f
Updates the similarity measure after reranking
CollectiveUnicorn Sep 24, 2024
8ad5216
Adds additional logging
CollectiveUnicorn Sep 24, 2024
4634ed0
Improves readability of logging
CollectiveUnicorn Sep 25, 2024
5e4837a
Simply debug output for further readability
CollectiveUnicorn Sep 25, 2024
9e051d7
Change user prompt to system prompt
CollectiveUnicorn Sep 25, 2024
e8316c1
Replaces custom reranker with library and llm with FlashRank
CollectiveUnicorn Sep 25, 2024
b355e86
Fixes invalid dictionary index
CollectiveUnicorn Sep 25, 2024
77b7249
Adds more ranking models and configuration for ranking models
CollectiveUnicorn Sep 25, 2024
75e70fd
Adds score and rank to search item response
CollectiveUnicorn Sep 25, 2024
2db8263
Ensures that the configured model is used when ranking
CollectiveUnicorn Sep 25, 2024
f46b599
Removes transformers options from rerankers
CollectiveUnicorn Sep 25, 2024
811861e
Replaces duplicate global class with singleton class
CollectiveUnicorn Sep 25, 2024
c3e86c8
Returns transformers and removes rankllm
CollectiveUnicorn Sep 25, 2024
f9273a3
Switch to flashrank exclusively, separate retrieval vs reranking topk
CollectiveUnicorn Sep 25, 2024
8d4d829
Adds disclaimer to support ranking model list
CollectiveUnicorn Sep 25, 2024
8e435be
Removes todo
CollectiveUnicorn Sep 25, 2024
e61fda9
Adds configure endpoint API test
CollectiveUnicorn Sep 25, 2024
0a429a7
Changes comment formatting
CollectiveUnicorn Sep 25, 2024
a42a94a
Swap method order in test
CollectiveUnicorn Sep 25, 2024
6591d48
Adds RAG + reranking e2e test
CollectiveUnicorn Sep 25, 2024
c245c94
Refactors test_routes api test
CollectiveUnicorn Sep 26, 2024
8259220
Refactors logging, var names, and removes unneeded logs
CollectiveUnicorn Sep 26, 2024
3083301
Merge branch 'main' into 1089-feat-reranking-backend-integrated-in-wi…
CollectiveUnicorn Sep 26, 2024
61cad5e
Adds dev flag for the configure endpoint
CollectiveUnicorn Sep 26, 2024
ac8f58e
Adds dev_only decorator for fast_api dev endpoints
CollectiveUnicorn Sep 26, 2024
713e7e8
Removes unnecessary function
CollectiveUnicorn Sep 26, 2024
1433299
Removes unneeded imports and functions params
CollectiveUnicorn Sep 26, 2024
bb3b926
Bumps eval deps versions and adds check for empty rag results
CollectiveUnicorn Sep 27, 2024
70d864e
Removes dev_only flag in favor of a simpler solution
CollectiveUnicorn Sep 27, 2024
93e7d9d
Updates documentation and simplifies setting and getting configuratio…
CollectiveUnicorn Sep 27, 2024
27f6306
Merge branch 'main' into 1089-feat-reranking-backend-integrated-in-wi…
CollectiveUnicorn Sep 27, 2024
fd54aa2
Fixes get request for config
CollectiveUnicorn Sep 27, 2024
d2db2c7
Ensure that the singleton gets updated
CollectiveUnicorn Sep 27, 2024
fd4bea1
Adds zarf configs for dev mode
CollectiveUnicorn Sep 27, 2024
0872a5e
Make deep copy to prevent issues with variables overwriting
CollectiveUnicorn Sep 27, 2024
0fb97b8
Fixes update logic
CollectiveUnicorn Sep 27, 2024
b86fdf3
Moves the update function to the payload class
CollectiveUnicorn Sep 27, 2024
825df2a
Prevents default overwriting
CollectiveUnicorn Sep 27, 2024
cec2a26
Prevents default overwriting
CollectiveUnicorn Sep 27, 2024
588b3ec
Adds to configuration test
CollectiveUnicorn Sep 27, 2024
e55e68f
Merge branch 'main' into 1089-feat-reranking-backend-integrated-in-wi…
CollectiveUnicorn Sep 30, 2024
f8e6d20
Update test_rag_files.py
CollectiveUnicorn Sep 30, 2024
5a23a51
Adds small fixes
CollectiveUnicorn Sep 30, 2024
0fb0353
Merge branch 'main' into 1089-feat-reranking-backend-integrated-in-wi…
CollectiveUnicorn Sep 30, 2024
0db333e
Ruff linting
CollectiveUnicorn Oct 1, 2024
d87e247
Merge branch 'main' into 1089-feat-reranking-backend-integrated-in-wi…
CollectiveUnicorn Oct 1, 2024
2cbe4bd
Fix unnecessary environment variables
CollectiveUnicorn Oct 1, 2024
01e77df
Swaps env out with new helper function
CollectiveUnicorn Oct 1, 2024
1 change: 1 addition & 0 deletions .github/workflows/pytest.yaml
@@ -64,6 +64,7 @@ jobs:
run: make test-api-unit
env:
LFAI_RUN_REPEATER_TESTS: true
DEV: true

integration:
runs-on: ai-ubuntu-big-boy-8-core
2 changes: 2 additions & 0 deletions packages/api/chart/values.yaml
@@ -25,6 +25,8 @@ api:
value: "*.toml"
- name: DEFAULT_EMBEDDINGS_MODEL
value: "text-embeddings"
- name: DEV
value: "false"
- name: PORT
value: "8080"
- name: SUPABASE_URL
2 changes: 2 additions & 0 deletions packages/api/values/registry1-values.yaml
@@ -16,6 +16,8 @@ api:
value: "*.toml"
- name: DEFAULT_EMBEDDINGS_MODEL
value: "###ZARF_VAR_DEFAULT_EMBEDDINGS_MODEL###"
- name: DEV
value: "###ZARF_VAR_DEV###"
- name: PORT
value: "8080"
- name: SUPABASE_URL
2 changes: 2 additions & 0 deletions packages/api/values/upstream-values.yaml
@@ -14,6 +14,8 @@ api:
value: "*.toml"
- name: DEFAULT_EMBEDDINGS_MODEL
value: "###ZARF_VAR_DEFAULT_EMBEDDINGS_MODEL###"
- name: DEV
value: "###ZARF_VAR_DEV###"
- name: PORT
value: "8080"
- name: SUPABASE_URL
3 changes: 3 additions & 0 deletions packages/api/zarf.yaml
@@ -16,6 +16,9 @@ variables:
description: "Flag to expose the OpenAPI schema for debugging."
- name: DEFAULT_EMBEDDINGS_MODEL
default: "text-embeddings"
- name: DEV
default: "false"
description: "Flag to enable development endpoints."

components:
- name: leapfrogai-api
69 changes: 69 additions & 0 deletions src/leapfrogai_api/README.md
@@ -56,3 +56,72 @@ See the ["Access" section of the DEVELOPMENT.md](../../docs/DEVELOPMENT.md#acces
### Tests

See the [tests directory documentation](../../tests/README.md) for more details.

### Reranking Configuration

The LeapfrogAI API includes a Retrieval Augmented Generation (RAG) pipeline for enhanced question answering. This section details how to configure its reranking options. All RAG configuration is managed through the `/leapfrogai/v1/rag/configure` API endpoint, which is only registered when the `DEV` environment variable is set.

#### 1. Enabling/Disabling Reranking

Reranking re-scores the results retrieved from the vector database against the query so that the most relevant ones are returned first. You can enable or disable it using the `enable_reranking` parameter:

* **Enable Reranking:** Send a PATCH request to `/leapfrogai/v1/rag/configure` with the following JSON payload:

```json
{
  "enable_reranking": true
}
```

* **Disable Reranking:** Send a PATCH request with:

```json
{
  "enable_reranking": false
}
```
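
As a rough illustration, the request above could be built in Python with only the standard library. This is a sketch, not an official client: `BASE_URL`, `API_KEY`, the bearer-token header, and the helper name `build_configure_request` are assumptions made for the example, so substitute whatever your deployment actually uses.

```python
import json
import urllib.request

# Assumed values for illustration only; replace with your deployment's URL and key.
BASE_URL = "http://localhost:8080"
API_KEY = "my-api-key"


def build_configure_request(payload: dict) -> urllib.request.Request:
    """Build (but do not send) a PATCH request for the RAG configure endpoint."""
    return urllib.request.Request(
        url=f"{BASE_URL}/leapfrogai/v1/rag/configure",
        data=json.dumps(payload).encode("utf-8"),
        method="PATCH",
        headers={
            "Content-Type": "application/json",
            # Assumption: the API expects a bearer token.
            "Authorization": f"Bearer {API_KEY}",
        },
    )


request = build_configure_request({"enable_reranking": True})
# To actually send it against a running API:
# urllib.request.urlopen(request)
```

Only the fields present in the payload are meant to change, so the same helper can be reused with `{"enable_reranking": False}` to turn reranking off again.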

#### 2. Selecting a Reranking Model

Multiple reranking models are supported, each offering different performance characteristics. Choose your preferred model using the `ranking_model` parameter. The API installs only the `flashrank` extra by default (`rerankers[flashrank]`), so ensure you've installed any additional Python dependencies your chosen model needs (see the [rerankers library documentation](https://github.com/AnswerDotAI/rerankers) on dependencies).

* **Supported Models:** The system supports several models, including (but not limited to) `flashrank`, `rankllm`, `cross-encoder`, and `colbert`. Refer to the [rerankers library documentation](https://github.com/AnswerDotAI/rerankers) for a complete list and details on their capabilities.

* **Model Selection:** Use a PATCH request to `/leapfrogai/v1/rag/configure` with the desired model. Reranking must be enabled, and `ranking_model` can be any supported model:

```json
{
  "enable_reranking": true,
  "ranking_model": "rankllm"
}
```

#### 3. Adjusting the Number of Results Before Reranking (`rag_top_k_when_reranking`)

This parameter sets the number of top results retrieved from the vector database *before* the reranking process begins. A higher value gives the reranker more candidates to consider but also increases processing time. A lower value can cause relevant results to be missed because they never reach the reranker. After reranking, the list is cut down to the number of results requested in the search call. This setting is only relevant when reranking is enabled.

* **Configuration:** Use a PATCH request to `/leapfrogai/v1/rag/configure` to set this value, adjusting it as needed:

```json
{
  "enable_reranking": true,
  "ranking_model": "flashrank",
  "rag_top_k_when_reranking": 150
}
```
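
The interaction between the caller's `k` and `rag_top_k_when_reranking` can be sketched in plain Python. This is an illustrative model only: the `Candidate` class, its `rerank_score` field, and `retrieve_then_rerank` are hypothetical names, and the scores stand in for what the vector search and the reranker would produce.

```python
from dataclasses import dataclass


@dataclass
class Candidate:
    doc_id: str
    similarity: float    # score from the vector search
    rerank_score: float  # hypothetical score a reranker would assign


def retrieve_then_rerank(candidates, k, rag_top_k_when_reranking, enable_reranking=True):
    """Retrieve a wide pool, rerank it, then keep only the k results requested."""
    # Stage 1: the vector search returns the pool, sized by the reranking top-k
    # when reranking is on and by the caller's k otherwise.
    pool_size = rag_top_k_when_reranking if enable_reranking else k
    pool = sorted(candidates, key=lambda c: c.similarity, reverse=True)[:pool_size]
    if not enable_reranking:
        return pool
    # Stage 2: reorder the pool by reranker score and truncate to k.
    return sorted(pool, key=lambda c: c.rerank_score, reverse=True)[:k]


docs = [
    Candidate("a", similarity=0.9, rerank_score=0.10),
    Candidate("b", similarity=0.8, rerank_score=0.20),
    Candidate("c", similarity=0.7, rerank_score=0.95),
]
wide = retrieve_then_rerank(docs, k=1, rag_top_k_when_reranking=3)
narrow = retrieve_then_rerank(docs, k=1, rag_top_k_when_reranking=2)
```

With a pool of 3 the reranker can promote `c`; with a pool of 2, `c` never reaches the reranker and `b` is returned instead, which is the "missing relevant results" risk described above.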

#### 4. Retrieving the Current RAG Configuration

To check the current RAG configuration (including reranking status, model, and `rag_top_k_when_reranking`), send a GET request to `/leapfrogai/v1/rag/configure`. The response will be a JSON object containing all the current settings.

#### 5. Example Configuration Flow

1. **Initial Setup:** Start with reranking enabled using the default `flashrank` model and a `rag_top_k_when_reranking` value of 100.

2. **Experiment with Models:** Test different reranking models (`rankllm`, `colbert`, etc.) by changing the `ranking_model` parameter and observing the impact on response quality. Adjust `rag_top_k_when_reranking` as needed to find the optimal balance between diversity and performance.

3. **Fine-tuning:** Once you identify a suitable model, fine-tune the `rag_top_k_when_reranking` parameter for optimal performance. Monitor response times and quality to determine the best setting.

4. **Disabling Reranking:** If needed, disable reranking by setting `"enable_reranking": false`.

Remember to always consult the [rerankers library documentation](https://github.com/AnswerDotAI/rerankers) for information on supported models and their specific requirements. The API documentation provides further details on request formats and potential error responses.
74 changes: 70 additions & 4 deletions src/leapfrogai_api/backend/rag/query.py
@@ -1,11 +1,15 @@
"""Service for querying the RAG model."""

from rerankers.results import RankedResults
from supabase import AClient as AsyncClient
from langchain_core.embeddings import Embeddings
from leapfrogai_api.backend.rag.leapfrogai_embeddings import LeapfrogAIEmbeddings
from leapfrogai_api.data.crud_vector_content import CRUDVectorContent
from leapfrogai_api.typedef.vectorstores.search_types import SearchResponse
from leapfrogai_api.typedef.rag.rag_types import ConfigurationSingleton
from leapfrogai_api.typedef.vectorstores.search_types import SearchResponse, SearchItem
from leapfrogai_api.backend.constants import TOP_K
from leapfrogai_api.utils.logging_tools import logger
from rerankers import Reranker

# Allows for overwriting type of embeddings that will be instantiated
embeddings_type: type[Embeddings] | type[LeapfrogAIEmbeddings] | None = (
@@ -22,7 +26,10 @@ def __init__(self, db: AsyncClient) -> None:
self.embeddings = embeddings_type()

async def query_rag(
self, query: str, vector_store_id: str, k: int = TOP_K
self,
query: str,
vector_store_id: str,
k: int = TOP_K,
) -> SearchResponse:
"""
Query the Vector Store.
@@ -36,11 +43,70 @@ async def query_rag(
SearchResponse: The search response from the vector store.
"""

logger.debug("Beginning RAG query...")

# 1. Embed query
vector = await self.embeddings.aembed_query(query)

# 2. Perform similarity search
_k: int = k
if ConfigurationSingleton.get_instance().enable_reranking:
"""Use the user specified top-k value unless reranking.
When reranking, use the reranking top-k value to get the initial results.
Then filter the list down later to just the k that the user has requested after reranking."""
_k = ConfigurationSingleton.get_instance().rag_top_k_when_reranking

crud_vector_content = CRUDVectorContent(db=self.db)
return await crud_vector_content.similarity_search(
query=vector, vector_store_id=vector_store_id, k=k
results = await crud_vector_content.similarity_search(
query=vector, vector_store_id=vector_store_id, k=_k
)

# 3. Rerank results
if (
ConfigurationSingleton.get_instance().enable_reranking
and len(results.data) > 0
):
ranker = Reranker(ConfigurationSingleton.get_instance().ranking_model)
ranked_results: RankedResults = ranker.rank(
query=query,
docs=[result.content for result in results.data],
doc_ids=[result.id for result in results.data],
)
results = rerank_search_response(results, ranked_results)
# Narrow down the results to the top-k value specified by the user
results.data = results.data[0:k]

logger.debug("Ending RAG query...")

return results


def rerank_search_response(
    original_response: SearchResponse, ranked_results: RankedResults
) -> SearchResponse:
    """
    Reorder the SearchResponse based on reranked results.

    Args:
        original_response (SearchResponse): The original search response.
        ranked_results (RankedResults): The reranker output, ordered by rank.

    Returns:
        SearchResponse: A new SearchResponse with reordered items.
    """
    # Create a mapping of id to original SearchItem
    content_to_item = {item.id: item for item in original_response.data}

    # Collect the original SearchItems in reranked order, attaching rank and score
    ranked_items = []
    for content in ranked_results.results:
        if content.document.doc_id in content_to_item:
            item: SearchItem = content_to_item[content.document.doc_id]
            item.rank = content.rank
            item.score = content.score
            ranked_items.append(item)

    # Create a new SearchResponse with reranked items
    ranked_response = SearchResponse(data=ranked_items)

    return ranked_response
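
The reordering step in `rerank_search_response` can be tried out without the `rerankers` or Pydantic dependencies. The following is a minimal sketch under that assumption: `Item` and `reorder` are hypothetical stand-ins for `SearchItem` and the function above, and the ranking is passed as plain `(doc_id, score)` pairs rather than a `RankedResults` object.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Item:
    id: str
    content: str
    rank: Optional[int] = None
    score: Optional[float] = None


def reorder(items, ranking):
    """Reorder items to follow ranking, a best-first list of (doc_id, score) pairs."""
    by_id = {item.id: item for item in items}
    ordered = []
    for rank, (doc_id, score) in enumerate(ranking, start=1):
        item = by_id.get(doc_id)
        if item is None:
            # Ids the original response does not contain are skipped.
            continue
        item.rank = rank
        item.score = score
        ordered.append(item)
    return ordered


items = [Item("a", "alpha"), Item("b", "beta"), Item("c", "gamma")]
result = reorder(items, [("c", 0.9), ("x", 0.5), ("a", 0.2)])
```

Note that, as in the function above, an item that the ranking does not mention (here `b`) is dropped from the output rather than appended at the end.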
3 changes: 3 additions & 0 deletions src/leapfrogai_api/main.py
@@ -14,6 +14,7 @@
from leapfrogai_api.routers.leapfrogai import models as lfai_models
from leapfrogai_api.routers.leapfrogai import vector_stores as lfai_vector_stores
from leapfrogai_api.routers.leapfrogai import count as lfai_token_count
from leapfrogai_api.routers.leapfrogai import rag as lfai_rag
from leapfrogai_api.routers.openai import (
assistants,
audio,
@@ -81,6 +82,8 @@ async def validation_exception_handler(request, exc):
app.include_router(messages.router)
app.include_router(runs_steps.router)
app.include_router(lfai_vector_stores.router)
if os.environ.get("DEV"):
app.include_router(lfai_rag.router)
app.include_router(lfai_token_count.router)
app.include_router(lfai_models.router)
# This should be at the bottom to prevent it preempting more specific runs endpoints
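
One thing worth checking in the `DEV` gate as shown in this hunk: `os.environ.get("DEV")` is truthy for any non-empty string, and the chart default is the string `"false"`. A later commit in this PR mentions swapping the environment lookup for a helper function, so the merged code may already handle this. The sketch below shows one possible shape for such a helper; the name `env_flag` and the accepted spellings are assumptions, not the project's actual implementation.

```python
import os

# Spellings treated as "enabled"; this set is an assumption for the sketch.
_TRUE_VALUES = {"1", "true", "yes", "on"}


def env_flag(name: str, default: bool = False) -> bool:
    """Interpret an environment variable as a boolean flag."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUE_VALUES


# Hypothetical usage in main.py:
# if env_flag("DEV"):
#     app.include_router(lfai_rag.router)
```

With a helper like this, `DEV=false` keeps the development router disabled, whereas a bare truthiness check on the raw string would enable it.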
1 change: 1 addition & 0 deletions src/leapfrogai_api/pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
"postgrest==0.16.11", # required by supabase, bug when using previous versions
"openpyxl == 3.1.5",
"psutil == 6.0.0",
"rerankers[flashrank] == 0.5.3"
]
requires-python = "~=3.11"

56 changes: 56 additions & 0 deletions src/leapfrogai_api/routers/leapfrogai/rag.py
@@ -0,0 +1,56 @@
"""LeapfrogAI endpoints for RAG."""

from fastapi import APIRouter
from leapfrogai_api.typedef.rag.rag_types import (
    ConfigurationSingleton,
    ConfigurationPayload,
)
from leapfrogai_api.routers.supabase_session import Session
from leapfrogai_api.utils.logging_tools import logger

router = APIRouter(prefix="/leapfrogai/v1/rag", tags=["leapfrogai/rag"])


@router.patch("/configure")
async def configure(session: Session, configuration: ConfigurationPayload) -> None:
    """
    Configures the RAG settings at runtime.

    Args:
        session (Session): The database session.
        configuration (ConfigurationPayload): The configuration to update.
    """

    # We set the class variable to update the configuration globally
    ConfigurationSingleton._instance = ConfigurationSingleton.get_instance().copy(
        update=configuration.dict(exclude_none=True)
    )


@router.get("/configure")
async def get_configuration(session: Session) -> ConfigurationPayload:
    """
    Retrieves the current RAG configuration.

    Args:
        session (Session): The database session.

    Returns:
        ConfigurationPayload: The current RAG configuration.
    """

    instance = ConfigurationSingleton.get_instance()

    # Create a new dictionary with only the relevant attributes
    config_dict = {
        key: value
        for key, value in instance.__dict__.items()
        if not key.startswith("_")  # Exclude private attributes
    }

    # Create a new ConfigurationPayload instance with the filtered dictionary
    new_configuration = ConfigurationPayload(**config_dict)

    logger.info(f"The current configuration has been set to {new_configuration}")

    return new_configuration
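
The PATCH handler above merges only the fields the caller actually sent (`exclude_none=True`), leaving the rest of the stored configuration untouched. The sketch below mirrors that merge with standard-library dataclasses instead of Pydantic, purely so it runs without dependencies; `RagConfig` and `apply_patch` are hypothetical names for the illustration.

```python
from dataclasses import asdict, dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class RagConfig:
    enable_reranking: Optional[bool] = None
    ranking_model: Optional[str] = None
    rag_top_k_when_reranking: Optional[int] = None


def apply_patch(current: RagConfig, patch: RagConfig) -> RagConfig:
    """Return a copy of current with every non-None field of patch applied."""
    updates = {key: value for key, value in asdict(patch).items() if value is not None}
    return replace(current, **updates)


current = RagConfig(enable_reranking=True, ranking_model="flashrank", rag_top_k_when_reranking=100)
patched = apply_patch(current, RagConfig(rag_top_k_when_reranking=150))
disabled = apply_patch(patched, RagConfig(enable_reranking=False))
```

Filtering on `is not None` rather than on truthiness matters here: `enable_reranking=False` is a real update and must not be discarded.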
4 changes: 1 addition & 3 deletions src/leapfrogai_api/routers/leapfrogai/vector_stores.py
@@ -33,9 +33,7 @@ async def search(
"""
query_service = QueryService(db=session)
return await query_service.query_rag(
query=query,
vector_store_id=vector_store_id,
k=k,
query=query, vector_store_id=vector_store_id, k=k
)


3 changes: 3 additions & 0 deletions src/leapfrogai_api/typedef/rag/__init__.py
@@ -0,0 +1,3 @@
from .rag_types import (
ConfigurationSingleton as ConfigurationSingleton,
)
40 changes: 40 additions & 0 deletions src/leapfrogai_api/typedef/rag/rag_types.py
@@ -0,0 +1,40 @@
from typing import Optional

from pydantic import BaseModel, Field


class ConfigurationSingleton:
    """Singleton manager for ConfigurationPayload."""

    _instance = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = ConfigurationPayload()
            cls._instance.enable_reranking = True
            cls._instance.rag_top_k_when_reranking = 100
            cls._instance.ranking_model = "flashrank"
        return cls._instance


class ConfigurationPayload(BaseModel):
    """Response for RAG configuration."""

    enable_reranking: Optional[bool] = Field(
        default=None,
        examples=[True, False],
        description="Enables reranking for RAG queries",
    )
    # More model info can be found here:
    # https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file
    # https://pypi.org/project/rerankers/
    ranking_model: Optional[str] = Field(
        default=None,
        description="What model to use for reranking. Some options may require additional python dependencies.",
        examples=["flashrank", "rankllm", "cross-encoder", "colbert"],
    )
    rag_top_k_when_reranking: Optional[int] = Field(
        default=None,
        description="The top-k results returned from the RAG call before reranking",
    )
10 changes: 10 additions & 0 deletions src/leapfrogai_api/typedef/vectorstores/search_types.py
@@ -1,3 +1,5 @@
from typing import Optional

from pydantic import BaseModel, Field


@@ -25,6 +27,14 @@ class SearchItem(BaseModel):
similarity: float = Field(
..., description="Similarity score of this item to the query."
)
rank: Optional[int] = Field(
default=None,
description="The rank of this search item after ranking has occurred.",
)
score: Optional[float] = Field(
default=None,
description="The score of this search item after ranking has occurred.",
)


class SearchResponse(BaseModel):
12 changes: 12 additions & 0 deletions src/leapfrogai_api/utils/logging_tools.py
@@ -0,0 +1,12 @@
import os
import logging
from dotenv import load_dotenv

load_dotenv()

logging.basicConfig(
    level=os.getenv("LFAI_LOG_LEVEL", logging.INFO),
    format="%(name)s: %(asctime)s | %(levelname)s | %(filename)s:%(lineno)s >>> %(message)s",
)

logger = logging.getLogger(__name__)
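
Because `LFAI_LOG_LEVEL` is passed straight to `logging.basicConfig`, it has to be a level name that `logging` recognizes (for example `DEBUG`); a lowercase or misspelled value raises a `ValueError` at import time. If a more forgiving lookup is wanted, one option is sketched below. The helper name `resolve_log_level` and the fall-back-to-default behavior are assumptions for illustration, not part of this PR.

```python
import logging
import os


def resolve_log_level(default: int = logging.INFO) -> int:
    """Read LFAI_LOG_LEVEL case-insensitively, falling back to default if unknown."""
    raw = os.getenv("LFAI_LOG_LEVEL")
    if raw is None:
        return default
    # getLevelName maps a known level name to its number; for an unknown
    # name it returns a string, which is treated here as "not recognized".
    level = logging.getLevelName(raw.strip().upper())
    return level if isinstance(level, int) else default


# Hypothetical usage:
# logging.basicConfig(level=resolve_log_level(), format="...")
```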