From b63f08526444fc4043495aa268f36447374a125f Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 20 Sep 2024 14:52:56 -0700 Subject: [PATCH 01/59] Initial reranking setup --- src/leapfrogai_api/routers/leapfrogai/rag.py | 21 ++++++++++++++++++++ src/leapfrogai_api/typedef/rag/__init__.py | 3 +++ src/leapfrogai_api/typedef/rag/rag_types.py | 10 ++++++++++ 3 files changed, 34 insertions(+) create mode 100644 src/leapfrogai_api/routers/leapfrogai/rag.py create mode 100644 src/leapfrogai_api/typedef/rag/__init__.py create mode 100644 src/leapfrogai_api/typedef/rag/rag_types.py diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py new file mode 100644 index 000000000..7df678c0c --- /dev/null +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -0,0 +1,21 @@ +"""LeapfrogAI endpoints for RAG.""" + +from fastapi import APIRouter +from leapfrogai_api.typedef.rag.rag_types import Configuration +from leapfrogai_api.routers.supabase_session import Session + +router = APIRouter(prefix="/leapfrogai/v1/rag", tags=["leapfrogai/rag"]) + + +@router.patch("/configure") +async def configure(session: Session, configuration: Configuration): + """ + Configures the RAG settings at runtime. + + Args: + session (Session): The database session. + configuration (Configuration): The configuration to update. + """ + + # We set the class variable to update the configuration globally + Configuration.enable_reranking = configuration.enable_reranking diff --git a/src/leapfrogai_api/typedef/rag/__init__.py b/src/leapfrogai_api/typedef/rag/__init__.py new file mode 100644 index 000000000..003db7f80 --- /dev/null +++ b/src/leapfrogai_api/typedef/rag/__init__.py @@ -0,0 +1,3 @@ +from .rag_types import ( + Configuration as Configuration, +) diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py new file mode 100644 index 000000000..7284805a3 --- /dev/null +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel, Field + + +class Configuration(BaseModel): + """Configuration for RAG.""" + + enable_reranking: bool = Field( + default=False, + description="Whether to enable reranking", + ) From 09fd6c3b483004aaaf51bc58207cdf27a06ced31 Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 20 Sep 2024 15:19:56 -0700 Subject: [PATCH 02/59] Naive reranking implemented with query --- src/leapfrogai_api/backend/rag/query.py | 53 +++++++++++++++++++- src/leapfrogai_api/backend/rag/reranker.py | 56 ++++++++++++++++++++++ 2 files changed, 107 insertions(+), 2 deletions(-) create mode 100644 src/leapfrogai_api/backend/rag/reranker.py diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index e5e0decce..e0f62250f 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -1,11 +1,17 @@ """Service for querying the RAG model.""" +from typing import Annotated, List +from fastapi import Depends from supabase import AClient as AsyncClient from langchain_core.embeddings import Embeddings from leapfrogai_api.backend.rag.leapfrogai_embeddings import LeapfrogAIEmbeddings +from leapfrogai_api.backend.rag.reranker import Reranker from leapfrogai_api.data.crud_vector_content import CRUDVectorContent +from leapfrogai_api.typedef.rag.rag_types import Configuration from leapfrogai_api.typedef.vectorstores.search_types import SearchResponse from leapfrogai_api.backend.constants import TOP_K +from leapfrogai_api.utils import get_model_config +from leapfrogai_api.utils.config import Config # Allows for overwriting type of embeddings that will be instantiated embeddings_type: type[Embeddings] | type[LeapfrogAIEmbeddings] | None = ( @@ -22,7 +28,11 @@ def __init__(self, db: AsyncClient) -> None: self.embeddings = embeddings_type() async def query_rag( - self, query: str, vector_store_id: str, k: int = TOP_K + self, + query: str, + vector_store_id: str, + model_config: Annotated[Config, Depends(get_model_config)], + k: int = TOP_K, ) -> SearchResponse: """ Query the Vector Store. @@ -36,11 +46,50 @@ async def query_rag( SearchResponse: The search response from the vector store. """ + results = SearchResponse(data=[]) + # 1. Embed query vector = await self.embeddings.aembed_query(query) # 2. Perform similarity search crud_vector_content = CRUDVectorContent(db=self.db) - return await crud_vector_content.similarity_search( + results = await crud_vector_content.similarity_search( query=vector, vector_store_id=vector_store_id, k=k ) + + # 3. Rerank results + if Configuration.enable_reranking: + reranker = Reranker(model_config=model_config) + reranked_results: list[str] = await reranker.rerank( + query, [result.content for result in results.data] + ) + results = rerank_search_response(results, reranked_results) + + return results + + +def rerank_search_response( + original_response: SearchResponse, reranked_results: List[str] +) -> SearchResponse: + """ + Reorder the SearchResponse based on reranked results. + + Args: + original_response (SearchResponse): The original search response. + reranked_results (List[str]): List of reranked content strings. + + Returns: + SearchResponse: A new SearchResponse with reordered items. + """ + # Create a mapping of content to original SearchItem + content_to_item = {item.content: item for item in original_response.data} + + # Create new SearchItems based on reranked results + reranked_items = [] + for content in reranked_results: + if content in content_to_item: + item = content_to_item[content] + reranked_items.append(item) + + # Create a new SearchResponse with reranked items + return SearchResponse(data=reranked_items) diff --git a/src/leapfrogai_api/backend/rag/reranker.py b/src/leapfrogai_api/backend/rag/reranker.py new file mode 100644 index 000000000..76ad8705e --- /dev/null +++ b/src/leapfrogai_api/backend/rag/reranker.py @@ -0,0 +1,56 @@ +from typing import List +import leapfrogai_sdk as lfai +from leapfrogai_api.backend.grpc_client import chat_completion +from leapfrogai_api.backend.helpers import grpc_chat_role +from leapfrogai_api.utils.config import Config + + +class Reranker: + def __init__( + self, + model_config: Config, + model: str = "llama-cpp-python", + temperature: float = 0.2, + max_tokens: int = 500, + ): + self.model_config = model_config + self.model = model + self.temperature = temperature + self.max_tokens = max_tokens + + async def rerank(self, query: str, documents: List[str]) -> List[str]: + prompt = self._create_rerank_prompt(query, documents) + + chat_items = [lfai.ChatItem(role=grpc_chat_role("user"), content=prompt)] + request = lfai.ChatCompletionRequest( + chat_items=chat_items, + max_new_tokens=self.max_tokens, + temperature=self.temperature, + ) + + model_backend = self.model_config.get_model_backend(self.model) + if model_backend is None: + raise ValueError(f"Model {self.model} not found in configuration") + + response = await chat_completion(model_backend, request) + + reranked_indices = self._parse_rerank_response( + str(response.choices[0].message.content_as_str()), documents + ) + return [documents[i] for i in reranked_indices] + + def _create_rerank_prompt(self, query: str, documents: List[str]) -> str: + # Create a prompt for reranking + doc_list = "\n".join([f"{i+1}. {doc}..." for i, doc in enumerate(documents)]) + return f"Given the query: '{query}', rank the following documents in order of relevance. Return only the numbers of the documents in order of relevance, separated by commas.\n\n{doc_list}" + + def _parse_rerank_response( + self, response: str | None, documents: List[str] + ) -> List[int]: + # Parse the response to get the reranked indices + try: + if response is None: + return list(range(len(documents))) + return [int(i.strip()) - 1 for i in response.split(",")] + except ValueError: + return list(range(len(documents))) # Return original order if parsing fails From f14e22d915c6dc64cd8924ca98c7b913be44cbd7 Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 20 Sep 2024 15:33:17 -0700 Subject: [PATCH 03/59] Adds endpoint to check current rag configuraiton --- src/leapfrogai_api/routers/leapfrogai/rag.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index 7df678c0c..a1f399763 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -1,8 +1,8 @@ """LeapfrogAI endpoints for RAG.""" -from fastapi import APIRouter +from fastapi import APIRouter, Depends from leapfrogai_api.typedef.rag.rag_types import Configuration -from leapfrogai_api.routers.supabase_session import Session +from leapfrogai_api.routers.supabase_session import Session, get_session router = APIRouter(prefix="/leapfrogai/v1/rag", tags=["leapfrogai/rag"]) @@ -19,3 +19,17 @@ async def configure(session: Session, configuration: Configuration): # We set the class variable to update the configuration globally Configuration.enable_reranking = configuration.enable_reranking + + +@router.get("/configuration") +async def get_configuration(session: Session = Depends(get_session)): + """ + Retrieves the current RAG configuration. + + Args: + session (Session): The database session. + + Returns: + Configuration: The current RAG configuration. + """ + return Configuration(enable_reranking=Configuration.enable_reranking) From 5a9bb34533bbd32278146db69d7d4fe42c3af398 Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 20 Sep 2024 15:49:02 -0700 Subject: [PATCH 04/59] Fixes typo --- src/leapfrogai_api/routers/leapfrogai/rag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index a1f399763..48fe76dde 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -21,7 +21,7 @@ async def configure(session: Session, configuration: Configuration): Configuration.enable_reranking = configuration.enable_reranking -@router.get("/configuration") +@router.get("/configure") async def get_configuration(session: Session = Depends(get_session)): """ Retrieves the current RAG configuration. From 10163aca197d3537f4aa6c106053d78a54ad2189 Mon Sep 17 00:00:00 2001 From: gharvey Date: Mon, 23 Sep 2024 11:47:27 -0700 Subject: [PATCH 05/59] Adds route to fast api router --- src/leapfrogai_api/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/leapfrogai_api/main.py b/src/leapfrogai_api/main.py index 85822f7f3..ad2f039f5 100644 --- a/src/leapfrogai_api/main.py +++ b/src/leapfrogai_api/main.py @@ -14,6 +14,7 @@ from leapfrogai_api.routers.leapfrogai import models as lfai_models from leapfrogai_api.routers.leapfrogai import vector_stores as lfai_vector_stores from leapfrogai_api.routers.leapfrogai import count as lfai_token_count +from leapfrogai_api.routers.leapfrogai import rag as lfai_rag from leapfrogai_api.routers.openai import ( assistants, audio, @@ -81,6 +82,7 @@ async def validation_exception_handler(request, exc): app.include_router(messages.router) app.include_router(runs_steps.router) app.include_router(lfai_vector_stores.router) +app.include_router(lfai_rag.router) app.include_router(lfai_token_count.router) app.include_router(lfai_models.router) # This should be at the bottom to prevent it preempting more specific runs endpoints From a82b596700c9d067fa9b4bdb4deee1c7ce37a422 Mon Sep 17 00:00:00 2001 From: gharvey Date: Mon, 23 Sep 2024 11:55:49 -0700 Subject: [PATCH 06/59] Fixes issue in endpoint configs --- src/leapfrogai_api/routers/leapfrogai/rag.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index 48fe76dde..9ba477a7f 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -1,8 +1,8 @@ """LeapfrogAI endpoints for RAG.""" -from fastapi import APIRouter, Depends +from fastapi import APIRouter from leapfrogai_api.typedef.rag.rag_types import Configuration -from leapfrogai_api.routers.supabase_session import Session, get_session +from leapfrogai_api.routers.supabase_session import Session router = APIRouter(prefix="/leapfrogai/v1/rag", tags=["leapfrogai/rag"]) @@ -22,7 +22,7 @@ async def configure(session: Session, configuration: Configuration): @router.get("/configure") -async def get_configuration(session: Session = Depends(get_session)): +async def get_configuration(session: Session): """ Retrieves the current RAG configuration. From 6b6eb825c677b2461e9989cadb601185c1b7af64 Mon Sep 17 00:00:00 2001 From: gharvey Date: Mon, 23 Sep 2024 12:11:38 -0700 Subject: [PATCH 07/59] Ensures that the class level variable has a default value --- src/leapfrogai_api/routers/leapfrogai/rag.py | 1 + src/leapfrogai_api/typedef/rag/rag_types.py | 7 ++----- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index 9ba477a7f..05c2561b7 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -32,4 +32,5 @@ async def get_configuration(session: Session): Returns: Configuration: The current RAG configuration. """ + return Configuration(enable_reranking=Configuration.enable_reranking) diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index 7284805a3..e5d3bd31a 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -1,10 +1,7 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel class Configuration(BaseModel): """Configuration for RAG.""" - enable_reranking: bool = Field( - default=False, - description="Whether to enable reranking", - ) + enable_reranking: bool = False From 73377625dfe018a13d7a5e58513086eaa6fc2ced Mon Sep 17 00:00:00 2001 From: gharvey Date: Mon, 23 Sep 2024 12:22:55 -0700 Subject: [PATCH 08/59] Makes the enable_reranking var a classvar --- src/leapfrogai_api/routers/leapfrogai/rag.py | 2 +- src/leapfrogai_api/typedef/rag/rag_types.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index 05c2561b7..50d17ea30 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -33,4 +33,4 @@ async def get_configuration(session: Session): Configuration: The current RAG configuration. """ - return Configuration(enable_reranking=Configuration.enable_reranking) + return Configuration() diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index e5d3bd31a..7e2df25fb 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -1,7 +1,9 @@ +from typing import ClassVar + from pydantic import BaseModel class Configuration(BaseModel): """Configuration for RAG.""" - enable_reranking: bool = False + enable_reranking: ClassVar[bool] = False From 88cdbca35cca409e69d44e2b7f4ce7ed363f0c65 Mon Sep 17 00:00:00 2001 From: gharvey Date: Mon, 23 Sep 2024 12:59:51 -0700 Subject: [PATCH 09/59] Creates separate response type --- src/leapfrogai_api/routers/leapfrogai/rag.py | 8 ++++---- src/leapfrogai_api/typedef/rag/rag_types.py | 6 ++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index 50d17ea30..3daea6811 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -1,14 +1,14 @@ """LeapfrogAI endpoints for RAG.""" from fastapi import APIRouter -from leapfrogai_api.typedef.rag.rag_types import Configuration +from leapfrogai_api.typedef.rag.rag_types import Configuration, ConfigurationResponse from leapfrogai_api.routers.supabase_session import Session router = APIRouter(prefix="/leapfrogai/v1/rag", tags=["leapfrogai/rag"]) @router.patch("/configure") -async def configure(session: Session, configuration: Configuration): +async def configure(session: Session, configuration: Configuration) -> None: """ Configures the RAG settings at runtime. @@ -22,7 +22,7 @@ async def configure(session: Session, configuration: Configuration): @router.get("/configure") -async def get_configuration(session: Session): +async def get_configuration(session: Session) -> ConfigurationResponse: """ Retrieves the current RAG configuration. @@ -33,4 +33,4 @@ async def get_configuration(session: Session): Configuration: The current RAG configuration. """ - return Configuration() + return ConfigurationResponse(enable_reranking=Configuration.enable_reranking) diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index 7e2df25fb..e7ad49270 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -7,3 +7,9 @@ class Configuration(BaseModel): """Configuration for RAG.""" enable_reranking: ClassVar[bool] = False + + +class ConfigurationResponse(BaseModel): + """Response for RAG configuration.""" + + enable_reranking: bool From 75c42f8bd517ae8b3e85ca660a1ae932888f5121 Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 24 Sep 2024 15:04:18 -0700 Subject: [PATCH 10/59] Additional comments --- src/leapfrogai_api/typedef/rag/rag_types.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index e7ad49270..05b3c05bf 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -6,10 +6,21 @@ class Configuration(BaseModel): """Configuration for RAG.""" + # This is a class variable, shared by all instances of Configuration + # It sets a default value, but doesn't create an instance variable enable_reranking: ClassVar[bool] = False + # Note: Pydantic will not create an instance variable for ClassVar fields + # If you need an instance variable, you should declare it separately + class ConfigurationResponse(BaseModel): """Response for RAG configuration.""" + # This is an instance variable, specific to each ConfigurationResponse object + # It will be included in the JSON output when the model is serialized enable_reranking: bool + + +# The separation of Configuration and ConfigurationResponse allows for +# different behavior in input (Configuration) vs output (ConfigurationResponse) From 329a296a3d7fd5f9f4613f11217adf25629aaf61 Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 24 Sep 2024 15:08:58 -0700 Subject: [PATCH 11/59] Cleans up comments and uses correct class for post requests --- src/leapfrogai_api/routers/leapfrogai/rag.py | 8 ++++---- src/leapfrogai_api/typedef/rag/rag_types.py | 14 ++++++-------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index 3daea6811..bdc8bcf95 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -1,14 +1,14 @@ """LeapfrogAI endpoints for RAG.""" from fastapi import APIRouter -from leapfrogai_api.typedef.rag.rag_types import Configuration, ConfigurationResponse +from leapfrogai_api.typedef.rag.rag_types import Configuration, ConfigurationPayload from leapfrogai_api.routers.supabase_session import Session router = APIRouter(prefix="/leapfrogai/v1/rag", tags=["leapfrogai/rag"]) @router.patch("/configure") -async def configure(session: Session, configuration: Configuration) -> None: +async def configure(session: Session, configuration: ConfigurationPayload) -> None: """ Configures the RAG settings at runtime. @@ -22,7 +22,7 @@ async def configure(session: Session, configuration: Configuration) -> None: @router.get("/configure") -async def get_configuration(session: Session) -> ConfigurationResponse: +async def get_configuration(session: Session) -> ConfigurationPayload: """ Retrieves the current RAG configuration. @@ -33,4 +33,4 @@ async def get_configuration(session: Session) -> ConfigurationResponse: Configuration: The current RAG configuration. """ - return ConfigurationResponse(enable_reranking=Configuration.enable_reranking) + return ConfigurationPayload(enable_reranking=Configuration.enable_reranking) diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index 05b3c05bf..6339ab258 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -1,6 +1,6 @@ from typing import ClassVar -from pydantic import BaseModel +from pydantic import BaseModel, Field class Configuration(BaseModel): @@ -14,13 +14,11 @@ class Configuration(BaseModel): # If you need an instance variable, you should declare it separately -class ConfigurationResponse(BaseModel): +class ConfigurationPayload(BaseModel): """Response for RAG configuration.""" - # This is an instance variable, specific to each ConfigurationResponse object + # This is an instance variable, specific to each ConfigurationPayload object # It will be included in the JSON output when the model is serialized - enable_reranking: bool - - -# The separation of Configuration and ConfigurationResponse allows for -# different behavior in input (Configuration) vs output (ConfigurationResponse) + enable_reranking: bool = Field( + default=False, description="Enables reranking for RAG queries" + ) From d19be308ea20a61b258ac05479f8b24f2c65cf56 Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 24 Sep 2024 15:24:16 -0700 Subject: [PATCH 12/59] Adds the model config to the search endpoint so that it can be passed down --- .../routers/leapfrogai/vector_stores.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/leapfrogai_api/routers/leapfrogai/vector_stores.py b/src/leapfrogai_api/routers/leapfrogai/vector_stores.py index cd2899925..0c8dca4e0 100644 --- a/src/leapfrogai_api/routers/leapfrogai/vector_stores.py +++ b/src/leapfrogai_api/routers/leapfrogai/vector_stores.py @@ -1,10 +1,13 @@ """LeapfrogAI endpoints for RAG.""" -from fastapi import APIRouter +from typing import Annotated + +from fastapi import APIRouter, Depends from leapfrogai_api.backend.rag.query import QueryService from leapfrogai_api.typedef.vectorstores import SearchResponse from leapfrogai_api.routers.supabase_session import Session from leapfrogai_api.backend.constants import TOP_K +from leapfrogai_api.utils import Config, get_model_config router = APIRouter( prefix="/leapfrogai/v1/vector_stores", tags=["leapfrogai/vector_stores"] @@ -14,6 +17,7 @@ @router.post("/search") async def search( session: Session, + model_config: Annotated[Config, Depends(get_model_config)], query: str, vector_store_id: str, k: int = TOP_K, @@ -23,6 +27,7 @@ async def search( Args: session (Session): The database session. + model_config (Config): The current model configuration. query (str): The input query string. vector_store_id (str): The ID of the vector store. k (int, optional): The number of results to retrieve. @@ -32,7 +37,5 @@ async def search( """ query_service = QueryService(db=session) return await query_service.query_rag( - query=query, - vector_store_id=vector_store_id, - k=k, + query=query, vector_store_id=vector_store_id, k=k, model_config=model_config ) From be75d0104c4eefd6e8b7d02a0ae79ec8efacf7dc Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 24 Sep 2024 15:40:08 -0700 Subject: [PATCH 13/59] Adds output to evaluate reranking, refactors class --- src/leapfrogai_api/backend/rag/query.py | 19 +++++++-- src/leapfrogai_api/backend/rag/reranker.py | 49 ++++++++++++++-------- 2 files changed, 47 insertions(+), 21 deletions(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index e0f62250f..08b3fe572 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -12,6 +12,18 @@ from leapfrogai_api.backend.constants import TOP_K from leapfrogai_api.utils import get_model_config from leapfrogai_api.utils.config import Config +import os +import logging +from dotenv import load_dotenv + +load_dotenv() + +logging.basicConfig( + level=os.getenv("LFAI_LOG_LEVEL", logging.INFO), + format="%(name)s: %(asctime)s | %(levelname)s | %(filename)s:%(lineno)s >>> %(message)s", +) + +logger = logging.getLogger(__name__) # Allows for overwriting type of embeddings that will be instantiated embeddings_type: type[Embeddings] | type[LeapfrogAIEmbeddings] | None = ( @@ -29,15 +41,16 @@ def __init__(self, db: AsyncClient) -> None: async def query_rag( self, + model_config: Annotated[Config, Depends(get_model_config)], query: str, vector_store_id: str, - model_config: Annotated[Config, Depends(get_model_config)], k: int = TOP_K, ) -> SearchResponse: """ Query the Vector Store. Args: + model_config (Config): The current model configuration. query (str): The input query string. vector_store_id (str): The ID of the vector store. k (int, optional): The number of results to retrieve. @@ -46,8 +59,6 @@ async def query_rag( SearchResponse: The search response from the vector store. """ - results = SearchResponse(data=[]) - # 1. Embed query vector = await self.embeddings.aembed_query(query) @@ -91,5 +102,7 @@ def rerank_search_response( item = content_to_item[content] reranked_items.append(item) + logging.info(f"Reranked documents {reranked_items}") + # Create a new SearchResponse with reranked items return SearchResponse(data=reranked_items) diff --git a/src/leapfrogai_api/backend/rag/reranker.py b/src/leapfrogai_api/backend/rag/reranker.py index 76ad8705e..094f94008 100644 --- a/src/leapfrogai_api/backend/rag/reranker.py +++ b/src/leapfrogai_api/backend/rag/reranker.py @@ -3,6 +3,35 @@ from leapfrogai_api.backend.grpc_client import chat_completion from leapfrogai_api.backend.helpers import grpc_chat_role from leapfrogai_api.utils.config import Config +import os +import logging +from dotenv import load_dotenv + +load_dotenv() + +logging.basicConfig( + level=os.getenv("LFAI_LOG_LEVEL", logging.INFO), + format="%(name)s: %(asctime)s | %(levelname)s | %(filename)s:%(lineno)s >>> %(message)s", +) + +logger = logging.getLogger(__name__) + + +def _create_rerank_prompt(query: str, documents: List[str]) -> str: + # Create a prompt for reranking + doc_list = "\n".join([f"{i+1}. {doc}..." for i, doc in enumerate(documents)]) + return f"Given the query: '{query}', rank the following documents in order of relevance. Return only the numbers of the documents in order of relevance, separated by commas.\n\n{doc_list}" + + +def _parse_rerank_response(response: str | None, documents: List[str]) -> List[int]: + # Parse the response to get the reranked indices + try: + if response is None: + return list(range(len(documents))) + return [int(i.strip()) - 1 for i in response.split(",")] + except ValueError: + logger.info("Failed to parse the reranked documents") + return list(range(len(documents))) # Return original order if parsing fails class Reranker: @@ -19,7 +48,7 @@ def __init__( self.max_tokens = max_tokens async def rerank(self, query: str, documents: List[str]) -> List[str]: - prompt = self._create_rerank_prompt(query, documents) + prompt = _create_rerank_prompt(query, documents) chat_items = [lfai.ChatItem(role=grpc_chat_role("user"), content=prompt)] request = lfai.ChatCompletionRequest( @@ -34,23 +63,7 @@ async def rerank(self, query: str, documents: List[str]) -> List[str]: response = await chat_completion(model_backend, request) - reranked_indices = self._parse_rerank_response( + reranked_indices = _parse_rerank_response( str(response.choices[0].message.content_as_str()), documents ) return [documents[i] for i in reranked_indices] - - def _create_rerank_prompt(self, query: str, documents: List[str]) -> str: - # Create a prompt for reranking - doc_list = "\n".join([f"{i+1}. {doc}..." for i, doc in enumerate(documents)]) - return f"Given the query: '{query}', rank the following documents in order of relevance. Return only the numbers of the documents in order of relevance, separated by commas.\n\n{doc_list}" - - def _parse_rerank_response( - self, response: str | None, documents: List[str] - ) -> List[int]: - # Parse the response to get the reranked indices - try: - if response is None: - return list(range(len(documents))) - return [int(i.strip()) - 1 for i in response.split(",")] - except ValueError: - return list(range(len(documents))) # Return original order if parsing fails From ebe977cf4d23ef46ecff5ab0b500daaf278b0754 Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 24 Sep 2024 15:56:35 -0700 Subject: [PATCH 14/59] Adds more output --- src/leapfrogai_api/backend/rag/query.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index 08b3fe572..c447bd93c 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -75,6 +75,7 @@ async def query_rag( query, [result.content for result in results.data] ) results = rerank_search_response(results, reranked_results) + logger.info(f"Reranking complete {results}") return results From 2e85b5e2385c38ef2724799391bd6afd6fa3697b Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 24 Sep 2024 16:05:04 -0700 Subject: [PATCH 15/59] More logging --- src/leapfrogai_api/backend/rag/query.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index c447bd93c..39da72331 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -59,6 +59,8 @@ async def query_rag( SearchResponse: The search response from the vector store. """ + logger.info("Beginning RAG query...") + # 1. Embed query vector = await self.embeddings.aembed_query(query) @@ -77,6 +79,8 @@ async def query_rag( results = rerank_search_response(results, reranked_results) logger.info(f"Reranking complete {results}") + logger.info("Ending RAG query...") + return results From 40a80613015e18ba225ead88feb0139100791477 Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 24 Sep 2024 16:33:56 -0700 Subject: [PATCH 16/59] Refactors logging and adds additional outputs --- src/leapfrogai_api/backend/rag/query.py | 15 ++------------- src/leapfrogai_api/backend/rag/reranker.py | 13 +------------ src/leapfrogai_api/routers/leapfrogai/rag.py | 9 ++++++++- src/leapfrogai_api/utils/logging_tools.py | 12 ++++++++++++ 4 files changed, 23 insertions(+), 26 deletions(-) create mode 100644 src/leapfrogai_api/utils/logging_tools.py diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index 39da72331..1fddbfa33 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -12,18 +12,7 @@ from leapfrogai_api.backend.constants import TOP_K from leapfrogai_api.utils import get_model_config from leapfrogai_api.utils.config import Config -import os -import logging -from dotenv import load_dotenv - -load_dotenv() - -logging.basicConfig( - level=os.getenv("LFAI_LOG_LEVEL", logging.INFO), - format="%(name)s: %(asctime)s | %(levelname)s | %(filename)s:%(lineno)s >>> %(message)s", -) - -logger = logging.getLogger(__name__) +from leapfrogai_api.utils.logging_tools import logger # Allows for overwriting type of embeddings that will be instantiated embeddings_type: type[Embeddings] | type[LeapfrogAIEmbeddings] | None = ( @@ -107,7 +96,7 @@ def rerank_search_response( item = content_to_item[content] reranked_items.append(item) - logging.info(f"Reranked documents {reranked_items}") + logger.info(f"Reranked documents {reranked_items}") # Create a new SearchResponse with reranked items return SearchResponse(data=reranked_items) diff --git a/src/leapfrogai_api/backend/rag/reranker.py b/src/leapfrogai_api/backend/rag/reranker.py index 094f94008..2af78c488 100644 --- a/src/leapfrogai_api/backend/rag/reranker.py +++ b/src/leapfrogai_api/backend/rag/reranker.py @@ -3,18 +3,7 @@ from leapfrogai_api.backend.grpc_client import chat_completion from leapfrogai_api.backend.helpers import grpc_chat_role from leapfrogai_api.utils.config import Config -import os -import logging -from dotenv import load_dotenv - -load_dotenv() - -logging.basicConfig( - level=os.getenv("LFAI_LOG_LEVEL", logging.INFO), - format="%(name)s: %(asctime)s | %(levelname)s | %(filename)s:%(lineno)s >>> %(message)s", -) - -logger = logging.getLogger(__name__) +from leapfrogai_api.utils.logging_tools import logger def _create_rerank_prompt(query: str, documents: List[str]) -> str: diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index bdc8bcf95..09a6df8bd 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -3,6 +3,7 @@ from fastapi import APIRouter from leapfrogai_api.typedef.rag.rag_types import Configuration, ConfigurationPayload from leapfrogai_api.routers.supabase_session import Session +from leapfrogai_api.utils.logging_tools import logger router = APIRouter(prefix="/leapfrogai/v1/rag", tags=["leapfrogai/rag"]) @@ -33,4 +34,10 @@ async def get_configuration(session: Session) -> ConfigurationPayload: Configuration: The current RAG configuration. """ - return ConfigurationPayload(enable_reranking=Configuration.enable_reranking) + new_configuration = ConfigurationPayload( + enable_reranking=Configuration.enable_reranking + ) + + logger.info(f"The current configuration has been set to {new_configuration}") + + return new_configuration diff --git a/src/leapfrogai_api/utils/logging_tools.py b/src/leapfrogai_api/utils/logging_tools.py new file mode 100644 index 000000000..aa2448288 --- /dev/null +++ b/src/leapfrogai_api/utils/logging_tools.py @@ -0,0 +1,12 @@ +import os +import logging +from dotenv import load_dotenv + +load_dotenv() + +logging.basicConfig( + level=os.getenv("LFAI_LOG_LEVEL", logging.INFO), + format="%(name)s: %(asctime)s | %(levelname)s | %(filename)s:%(lineno)s >>> %(message)s", +) + +logger = logging.getLogger(__name__) From d74ba5f5ddce2e5ff0ef55ae7f513c2414dc9aee Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 24 Sep 2024 16:48:42 -0700 Subject: [PATCH 17/59] Updates the similarity measure after reranking --- src/leapfrogai_api/backend/rag/query.py | 12 ++++++++---- src/leapfrogai_api/backend/rag/reranker.py | 7 +++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index 1fddbfa33..cff6c5f56 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -8,7 +8,7 @@ from leapfrogai_api.backend.rag.reranker import Reranker from leapfrogai_api.data.crud_vector_content import CRUDVectorContent from leapfrogai_api.typedef.rag.rag_types import Configuration -from leapfrogai_api.typedef.vectorstores.search_types import SearchResponse +from leapfrogai_api.typedef.vectorstores.search_types import SearchResponse, SearchItem from leapfrogai_api.backend.constants import TOP_K from leapfrogai_api.utils import get_model_config from leapfrogai_api.utils.config import Config @@ -91,12 +91,16 @@ def rerank_search_response( # Create new SearchItems based on reranked results reranked_items = [] - for content in reranked_results: + for idx, content in enumerate(reranked_results): if content in content_to_item: - item = content_to_item[content] + item: SearchItem = content_to_item[content] + # Update the similarity to maintain the new order + item.similarity = 1.0 - ((1.0 / len(reranked_results)) * idx) reranked_items.append(item) - logger.info(f"Reranked documents {reranked_items}") + logger.info( + f"Original documents: {original_response}\nReranked documents {reranked_items}" + ) # Create a new SearchResponse with reranked items return SearchResponse(data=reranked_items) diff --git a/src/leapfrogai_api/backend/rag/reranker.py b/src/leapfrogai_api/backend/rag/reranker.py index 2af78c488..328660579 100644 --- a/src/leapfrogai_api/backend/rag/reranker.py +++ b/src/leapfrogai_api/backend/rag/reranker.py @@ -16,10 +16,13 @@ def _parse_rerank_response(response: str | None, documents: List[str]) -> List[i # Parse the response to get the reranked indices try: if response is None: + logger.info("Failed to parse the reranked documents, received no response") return list(range(len(documents))) return [int(i.strip()) - 1 for i in response.split(",")] - except ValueError: - logger.info("Failed to parse the reranked documents") + except ValueError as e: + logger.exception( + f"Failed to parse the reranked documents, an error occurred\n{e}" + ) return list(range(len(documents))) # Return original order if parsing fails From 8ad52167618600924099836f4c55cddd1cf5af86 Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 24 Sep 2024 16:55:38 -0700 Subject: [PATCH 18/59] Adds additional logging --- src/leapfrogai_api/backend/rag/reranker.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/leapfrogai_api/backend/rag/reranker.py b/src/leapfrogai_api/backend/rag/reranker.py index 328660579..2acdf319d 100644 --- a/src/leapfrogai_api/backend/rag/reranker.py +++ b/src/leapfrogai_api/backend/rag/reranker.py @@ -42,6 +42,7 @@ def __init__( async def rerank(self, query: str, documents: List[str]) -> List[str]: prompt = _create_rerank_prompt(query, documents) + # TODO: System prompt needed chat_items = [lfai.ChatItem(role=grpc_chat_role("user"), content=prompt)] request = lfai.ChatCompletionRequest( chat_items=chat_items, @@ -55,6 +56,8 @@ async def rerank(self, query: str, documents: List[str]) -> List[str]: response = await chat_completion(model_backend, request) + logger.info(f"The reranking request has returned {response}") + reranked_indices = _parse_rerank_response( str(response.choices[0].message.content_as_str()), documents ) From 4634ed0123e3ee71ff55b520d044d157232d1adb Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 24 Sep 2024 17:09:04 -0700 Subject: [PATCH 19/59] Improves readability of logging --- src/leapfrogai_api/backend/rag/query.py | 8 +++++--- .../typedef/vectorstores/search_types.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index cff6c5f56..df2ac417d 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -66,7 +66,7 @@ async def query_rag( query, [result.content for result in results.data] ) results = rerank_search_response(results, reranked_results) - logger.info(f"Reranking complete {results}") + logger.info(f"Reranking complete {results.get_response_without_content()}") logger.info("Ending RAG query...") @@ -98,9 +98,11 @@ def rerank_search_response( item.similarity = 1.0 - ((1.0 / len(reranked_results)) * idx) reranked_items.append(item) + reranked_response = SearchResponse(data=reranked_items) + logger.info( - f"Original documents: {original_response}\nReranked documents {reranked_items}" + f"Original documents: {original_response.get_response_without_content()}\nReranked documents {reranked_response.get_response_without_content()}" ) # Create a new SearchResponse with reranked items - return SearchResponse(data=reranked_items) + return reranked_response diff --git a/src/leapfrogai_api/typedef/vectorstores/search_types.py b/src/leapfrogai_api/typedef/vectorstores/search_types.py index 76abb0822..c1deb998f 100644 --- a/src/leapfrogai_api/typedef/vectorstores/search_types.py +++ b/src/leapfrogai_api/typedef/vectorstores/search_types.py @@ -26,3 +26,19 @@ class SearchResponse(BaseModel): description="List of RAG items returned as a result of the query.", min_length=0, ) + + def get_response_without_content(self): + response_without_content: SearchResponse = SearchResponse( + data=[ + SearchItem( + id=item.id, + vector_store_id=item.vector_store_id, + file_id=item.file_id, + content="", + metadata=item.metadata, + similarity=item.similarity, + ) + for item in self.data + ] + ) + return response_without_content From 5e4837add62d8b5c6e86ff71c7a836bf3d552c07 Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 24 Sep 2024 17:20:20 -0700 Subject: [PATCH 20/59] Simply debug output for further readability --- src/leapfrogai_api/backend/rag/query.py | 4 ++-- .../typedef/vectorstores/search_types.py | 19 +++++-------------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index df2ac417d..9c35ad751 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -66,7 +66,7 @@ async def query_rag( query, [result.content for result in results.data] ) results = rerank_search_response(results, reranked_results) - logger.info(f"Reranking complete {results.get_response_without_content()}") + logger.info(f"Reranking complete {results.get_simple_response()}") logger.info("Ending RAG query...") @@ -101,7 +101,7 @@ def rerank_search_response( reranked_response = SearchResponse(data=reranked_items) logger.info( - f"Original documents: {original_response.get_response_without_content()}\nReranked documents {reranked_response.get_response_without_content()}" + f"Original documents: {original_response.get_simple_response()}\nReranked documents {reranked_response.get_simple_response()}" ) # Create a new SearchResponse with reranked items diff --git a/src/leapfrogai_api/typedef/vectorstores/search_types.py b/src/leapfrogai_api/typedef/vectorstores/search_types.py index c1deb998f..9949f01ea 100644 --- a/src/leapfrogai_api/typedef/vectorstores/search_types.py +++ b/src/leapfrogai_api/typedef/vectorstores/search_types.py @@ -27,18 +27,9 @@ class SearchResponse(BaseModel): min_length=0, ) - def get_response_without_content(self): - response_without_content: SearchResponse = SearchResponse( - data=[ - SearchItem( - id=item.id, - vector_store_id=item.vector_store_id, - file_id=item.file_id, - content="", - metadata=item.metadata, - similarity=item.similarity, - ) - for item in self.data - ] - ) + def get_simple_response(self): + response_without_content = [ + {"id": item.id, "similarity": item.similarity} for item in self.data + ] + return response_without_content From 9e051d79f49435bdaeba0945ae144d18b22541af Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 24 Sep 2024 17:32:18 -0700 Subject: [PATCH 21/59] Change user prompt to system prompt --- src/leapfrogai_api/backend/rag/reranker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/leapfrogai_api/backend/rag/reranker.py b/src/leapfrogai_api/backend/rag/reranker.py index 2acdf319d..b880e14ea 100644 --- a/src/leapfrogai_api/backend/rag/reranker.py +++ b/src/leapfrogai_api/backend/rag/reranker.py @@ -42,8 +42,8 @@ def __init__( async def rerank(self, query: str, documents: List[str]) -> List[str]: prompt = _create_rerank_prompt(query, documents) - # TODO: System prompt needed - chat_items = [lfai.ChatItem(role=grpc_chat_role("user"), content=prompt)] + # TODO: Should a system prompt + user prompt be used here? + chat_items = [lfai.ChatItem(role=grpc_chat_role("system"), content=prompt)] request = lfai.ChatCompletionRequest( chat_items=chat_items, max_new_tokens=self.max_tokens, From e8316c1db0566024cabfbcf2319992d17884b9ee Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 11:07:59 -0700 Subject: [PATCH 22/59] Replaces custom reranker with library and llm with FlashRank --- src/leapfrogai_api/backend/rag/query.py | 32 ++++++----- src/leapfrogai_api/backend/rag/reranker.py | 64 ---------------------- src/leapfrogai_api/pyproject.toml | 1 + 3 files changed, 19 insertions(+), 78 deletions(-) delete mode 100644 src/leapfrogai_api/backend/rag/reranker.py diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index 9c35ad751..6326c5154 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -1,11 +1,11 @@ """Service for querying the RAG model.""" -from typing import Annotated, List +from typing import Annotated from fastapi import Depends +from rerankers.results import RankedResults from supabase import AClient as AsyncClient from langchain_core.embeddings import Embeddings from leapfrogai_api.backend.rag.leapfrogai_embeddings import LeapfrogAIEmbeddings -from leapfrogai_api.backend.rag.reranker import Reranker from leapfrogai_api.data.crud_vector_content import CRUDVectorContent from leapfrogai_api.typedef.rag.rag_types import Configuration from leapfrogai_api.typedef.vectorstores.search_types import SearchResponse, SearchItem @@ -13,6 +13,7 @@ from leapfrogai_api.utils import get_model_config from leapfrogai_api.utils.config import Config from leapfrogai_api.utils.logging_tools import logger +from rerankers import Reranker # Allows for overwriting type of embeddings that will be instantiated embeddings_type: type[Embeddings] | type[LeapfrogAIEmbeddings] | None = ( @@ -61,11 +62,13 @@ async def query_rag( # 3. Rerank results if Configuration.enable_reranking: - reranker = Reranker(model_config=model_config) - reranked_results: list[str] = await reranker.rerank( - query, [result.content for result in results.data] + ranker = Reranker("flashrank") + ranked_results: RankedResults = ranker.rank( + query=query, + docs=[result.content for result in results.data], + doc_ids=[result.id for result in results.data], ) - results = rerank_search_response(results, reranked_results) + results = rerank_search_response(results, ranked_results) logger.info(f"Reranking complete {results.get_simple_response()}") logger.info("Ending RAG query...") @@ -74,28 +77,29 @@ async def query_rag( def rerank_search_response( - original_response: SearchResponse, reranked_results: List[str] + original_response: SearchResponse, ranked_results: RankedResults ) -> SearchResponse: """ Reorder the SearchResponse based on reranked results. Args: original_response (SearchResponse): The original search response. - reranked_results (List[str]): List of reranked content strings. + ranked_results (List[str]): List of ranked content strings. Returns: SearchResponse: A new SearchResponse with reordered items. """ - # Create a mapping of content to original SearchItem - content_to_item = {item.content: item for item in original_response.data} + # Create a mapping of id to original SearchItem + content_to_item = {item.id: item for item in original_response.data} # Create new SearchItems based on reranked results reranked_items = [] - for idx, content in enumerate(reranked_results): - if content in content_to_item: + for content in ranked_results.results: + if content.document.doc_id in content_to_item: item: SearchItem = content_to_item[content] - # Update the similarity to maintain the new order - item.similarity = 1.0 - ((1.0 / len(reranked_results)) * idx) + # TODO: Find a better way to handle this + # Update the similarity to instead be the rank + item.similarity = content.rank reranked_items.append(item) reranked_response = SearchResponse(data=reranked_items) diff --git a/src/leapfrogai_api/backend/rag/reranker.py b/src/leapfrogai_api/backend/rag/reranker.py deleted file mode 100644 index b880e14ea..000000000 --- a/src/leapfrogai_api/backend/rag/reranker.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import List -import leapfrogai_sdk as lfai -from leapfrogai_api.backend.grpc_client import chat_completion -from leapfrogai_api.backend.helpers import grpc_chat_role -from leapfrogai_api.utils.config import Config -from leapfrogai_api.utils.logging_tools import logger - - -def _create_rerank_prompt(query: str, documents: List[str]) -> str: - # Create a prompt for reranking - doc_list = "\n".join([f"{i+1}. {doc}..." for i, doc in enumerate(documents)]) - return f"Given the query: '{query}', rank the following documents in order of relevance. Return only the numbers of the documents in order of relevance, separated by commas.\n\n{doc_list}" - - -def _parse_rerank_response(response: str | None, documents: List[str]) -> List[int]: - # Parse the response to get the reranked indices - try: - if response is None: - logger.info("Failed to parse the reranked documents, received no response") - return list(range(len(documents))) - return [int(i.strip()) - 1 for i in response.split(",")] - except ValueError as e: - logger.exception( - f"Failed to parse the reranked documents, an error occurred\n{e}" - ) - return list(range(len(documents))) # Return original order if parsing fails - - -class Reranker: - def __init__( - self, - model_config: Config, - model: str = "llama-cpp-python", - temperature: float = 0.2, - max_tokens: int = 500, - ): - self.model_config = model_config - self.model = model - self.temperature = temperature - self.max_tokens = max_tokens - - async def rerank(self, query: str, documents: List[str]) -> List[str]: - prompt = _create_rerank_prompt(query, documents) - - # TODO: Should a system prompt + user prompt be used here? - chat_items = [lfai.ChatItem(role=grpc_chat_role("system"), content=prompt)] - request = lfai.ChatCompletionRequest( - chat_items=chat_items, - max_new_tokens=self.max_tokens, - temperature=self.temperature, - ) - - model_backend = self.model_config.get_model_backend(self.model) - if model_backend is None: - raise ValueError(f"Model {self.model} not found in configuration") - - response = await chat_completion(model_backend, request) - - logger.info(f"The reranking request has returned {response}") - - reranked_indices = _parse_rerank_response( - str(response.choices[0].message.content_as_str()), documents - ) - return [documents[i] for i in reranked_indices] diff --git a/src/leapfrogai_api/pyproject.toml b/src/leapfrogai_api/pyproject.toml index 1085cd5b2..bdbc8aea0 100644 --- a/src/leapfrogai_api/pyproject.toml +++ b/src/leapfrogai_api/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "postgrest==0.16.8", # required by supabase, bug when using previous versions "openpyxl == 3.1.5", "psutil == 6.0.0", + "rerankers[flashrank] == 0.5.3" ] requires-python = "~=3.11" From b355e86c8508a8f515a6f2c282ade43d62a806a7 Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 11:15:08 -0700 Subject: [PATCH 23/59] Fixes invalid dictionary index --- src/leapfrogai_api/backend/rag/query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index 6326c5154..d49c4d380 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -96,7 +96,7 @@ def rerank_search_response( reranked_items = [] for content in ranked_results.results: if content.document.doc_id in content_to_item: - item: SearchItem = content_to_item[content] + item: SearchItem = content_to_item[content.document.doc_id] # TODO: Find a better way to handle this # Update the similarity to instead be the rank item.similarity = content.rank From 77b7249368e4a089b5155073e19a9c829acff195 Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 11:30:51 -0700 Subject: [PATCH 24/59] Adds more ranking models and configuration for ranking models --- src/leapfrogai_api/pyproject.toml | 2 +- src/leapfrogai_api/routers/leapfrogai/rag.py | 4 +++- src/leapfrogai_api/typedef/rag/rag_types.py | 6 ++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/leapfrogai_api/pyproject.toml b/src/leapfrogai_api/pyproject.toml index bdbc8aea0..59581fa70 100644 --- a/src/leapfrogai_api/pyproject.toml +++ b/src/leapfrogai_api/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "postgrest==0.16.8", # required by supabase, bug when using previous versions "openpyxl == 3.1.5", "psutil == 6.0.0", - "rerankers[flashrank] == 0.5.3" + "rerankers[flashrank,transformers,rankllm] == 0.5.3" ] requires-python = "~=3.11" diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index 09a6df8bd..ae8bb462b 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -20,6 +20,7 @@ async def configure(session: Session, configuration: ConfigurationPayload) -> No # We set the class variable to update the configuration globally Configuration.enable_reranking = configuration.enable_reranking + Configuration.ranking_model = configuration.ranking_model @router.get("/configure") @@ -35,7 +36,8 @@ async def get_configuration(session: Session) -> ConfigurationPayload: """ new_configuration = ConfigurationPayload( - enable_reranking=Configuration.enable_reranking + enable_reranking=Configuration.enable_reranking, + ranking_model=Configuration.ranking_model, ) logger.info(f"The current configuration has been set to {new_configuration}") diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index 6339ab258..c05cb5bfe 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -9,6 +9,7 @@ class Configuration(BaseModel): # This is a class variable, shared by all instances of Configuration # It sets a default value, but doesn't create an instance variable enable_reranking: ClassVar[bool] = False + ranking_model: ClassVar[str] = "flashrank" # Note: Pydantic will not create an instance variable for ClassVar fields # If you need an instance variable, you should declare it separately @@ -22,3 +23,8 @@ class ConfigurationPayload(BaseModel): enable_reranking: bool = Field( default=False, description="Enables reranking for RAG queries" ) + ranking_model: str = Field( + default="flashrank", + description="What model to use for reranking", + examples=["flashrank", "rankllm", "cross-encoder", "colbert"], + ) From 75e70fd273bcdd4ec6670be2b88645686c7e03b0 Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 11:37:48 -0700 Subject: [PATCH 25/59] Adds score and rank to search item response --- src/leapfrogai_api/backend/rag/query.py | 3 ++- .../typedef/vectorstores/search_types.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index d49c4d380..f4372cdca 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -99,7 +99,8 @@ def rerank_search_response( item: SearchItem = content_to_item[content.document.doc_id] # TODO: Find a better way to handle this # Update the similarity to instead be the rank - item.similarity = content.rank + item.rank = content.rank + item.score = content.score reranked_items.append(item) reranked_response = SearchResponse(data=reranked_items) diff --git a/src/leapfrogai_api/typedef/vectorstores/search_types.py b/src/leapfrogai_api/typedef/vectorstores/search_types.py index 9949f01ea..5c48a6830 100644 --- a/src/leapfrogai_api/typedef/vectorstores/search_types.py +++ b/src/leapfrogai_api/typedef/vectorstores/search_types.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import BaseModel, Field @@ -16,6 +18,14 @@ class SearchItem(BaseModel): similarity: float = Field( ..., description="Similarity score of this item to the query." ) + rank: Optional[int] = Field( + default=None, + description="The rank of this search item after ranking has occurred.", + ) + score: Optional[float] = Field( + default=None, + description="The score of this search item after ranking has occurred.", + ) class SearchResponse(BaseModel): From 2db8263e08bb2905c90cca26446695c4e7387b2d Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 12:07:11 -0700 Subject: [PATCH 26/59] Ensures that the configured model is used when ranking --- src/leapfrogai_api/backend/rag/query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index f4372cdca..b511d51a0 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -62,7 +62,7 @@ async def query_rag( # 3. Rerank results if Configuration.enable_reranking: - ranker = Reranker("flashrank") + ranker = Reranker(Configuration.ranking_model) ranked_results: RankedResults = ranker.rank( query=query, docs=[result.content for result in results.data], From f46b599c73c27614bc45e3495a84bda403a59c00 Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 12:15:32 -0700 Subject: [PATCH 27/59] Removes transformers options from rerankers --- src/leapfrogai_api/pyproject.toml | 2 +- src/leapfrogai_api/typedef/rag/rag_types.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/leapfrogai_api/pyproject.toml b/src/leapfrogai_api/pyproject.toml index 59581fa70..17aa9c5e1 100644 --- a/src/leapfrogai_api/pyproject.toml +++ b/src/leapfrogai_api/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "postgrest==0.16.8", # required by supabase, bug when using previous versions "openpyxl == 3.1.5", "psutil == 6.0.0", - "rerankers[flashrank,transformers,rankllm] == 0.5.3" + "rerankers[flashrank,rankllm] == 0.5.3" ] requires-python = "~=3.11" diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index c05cb5bfe..78e7667f7 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -9,6 +9,9 @@ class Configuration(BaseModel): # This is a class variable, shared by all instances of Configuration # It sets a default value, but doesn't create an instance variable enable_reranking: ClassVar[bool] = False + # More model info can be found here: + # https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file + # https://pypi.org/project/rerankers/ ranking_model: ClassVar[str] = "flashrank" # Note: Pydantic will not create an instance variable for ClassVar fields @@ -18,11 +21,17 @@ class Configuration(BaseModel): class ConfigurationPayload(BaseModel): """Response for RAG configuration.""" + # Singleton instance of this class + instance: ClassVar[any] = None + # This is an instance variable, specific to each ConfigurationPayload object # It will be included in the JSON output when the model is serialized enable_reranking: bool = Field( default=False, description="Enables reranking for RAG queries" ) + # More model info can be found here: + # https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file + # https://pypi.org/project/rerankers/ ranking_model: str = Field( default="flashrank", description="What model to use for reranking", From 811861eee1744b19216920a3bb220b262a07089f Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 12:26:10 -0700 Subject: [PATCH 28/59] Replaces duplicate global class with singleton class --- src/leapfrogai_api/backend/rag/query.py | 6 ++--- src/leapfrogai_api/routers/leapfrogai/rag.py | 15 ++++++++---- src/leapfrogai_api/typedef/rag/__init__.py | 2 +- src/leapfrogai_api/typedef/rag/rag_types.py | 24 +++++++------------- 4 files changed, 22 insertions(+), 25 deletions(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index b511d51a0..973bafe46 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -7,7 +7,7 @@ from langchain_core.embeddings import Embeddings from leapfrogai_api.backend.rag.leapfrogai_embeddings import LeapfrogAIEmbeddings from leapfrogai_api.data.crud_vector_content import CRUDVectorContent -from leapfrogai_api.typedef.rag.rag_types import Configuration +from leapfrogai_api.typedef.rag.rag_types import ConfigurationSingleton from leapfrogai_api.typedef.vectorstores.search_types import SearchResponse, SearchItem from leapfrogai_api.backend.constants import TOP_K from leapfrogai_api.utils import get_model_config @@ -61,8 +61,8 @@ async def query_rag( ) # 3. Rerank results - if Configuration.enable_reranking: - ranker = Reranker(Configuration.ranking_model) + if ConfigurationSingleton.get_instance().enable_reranking: + ranker = Reranker(ConfigurationSingleton.get_instance().ranking_model) ranked_results: RankedResults = ranker.rank( query=query, docs=[result.content for result in results.data], diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index ae8bb462b..6844c3714 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -1,7 +1,10 @@ """LeapfrogAI endpoints for RAG.""" from fastapi import APIRouter -from leapfrogai_api.typedef.rag.rag_types import Configuration, ConfigurationPayload +from leapfrogai_api.typedef.rag.rag_types import ( + ConfigurationSingleton, + ConfigurationPayload, +) from leapfrogai_api.routers.supabase_session import Session from leapfrogai_api.utils.logging_tools import logger @@ -19,8 +22,10 @@ async def configure(session: Session, configuration: ConfigurationPayload) -> No """ # We set the class variable to update the configuration globally - Configuration.enable_reranking = configuration.enable_reranking - Configuration.ranking_model = configuration.ranking_model + ConfigurationSingleton.get_instance().enable_reranking = ( + configuration.enable_reranking + ) + ConfigurationSingleton.get_instance().ranking_model = configuration.ranking_model @router.get("/configure") @@ -36,8 +41,8 @@ async def get_configuration(session: Session) -> ConfigurationPayload: """ new_configuration = ConfigurationPayload( - enable_reranking=Configuration.enable_reranking, - ranking_model=Configuration.ranking_model, + enable_reranking=ConfigurationSingleton.get_instance().enable_reranking, + ranking_model=ConfigurationSingleton.get_instance().ranking_model, ) logger.info(f"The current configuration has been set to {new_configuration}") diff --git a/src/leapfrogai_api/typedef/rag/__init__.py b/src/leapfrogai_api/typedef/rag/__init__.py index 003db7f80..65c2e26cd 100644 --- a/src/leapfrogai_api/typedef/rag/__init__.py +++ b/src/leapfrogai_api/typedef/rag/__init__.py @@ -1,3 +1,3 @@ from .rag_types import ( - Configuration as Configuration, + ConfigurationSingleton as ConfigurationSingleton, ) diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index 78e7667f7..d50cf05c2 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -1,29 +1,21 @@ -from typing import ClassVar - from pydantic import BaseModel, Field -class Configuration(BaseModel): - """Configuration for RAG.""" +class ConfigurationSingleton: + """Singleton manager for ConfigurationPayload.""" - # This is a class variable, shared by all instances of Configuration - # It sets a default value, but doesn't create an instance variable - enable_reranking: ClassVar[bool] = False - # More model info can be found here: - # https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file - # https://pypi.org/project/rerankers/ - ranking_model: ClassVar[str] = "flashrank" + _instance = None - # Note: Pydantic will not create an instance variable for ClassVar fields - # If you need an instance variable, you should declare it separately + @classmethod + def get_instance(cls, **kwargs): + if cls._instance is None: + cls._instance = ConfigurationPayload(**kwargs) + return cls._instance class ConfigurationPayload(BaseModel): """Response for RAG configuration.""" - # Singleton instance of this class - instance: ClassVar[any] = None - # This is an instance variable, specific to each ConfigurationPayload object # It will be included in the JSON output when the model is serialized enable_reranking: bool = Field( From c3e86c830217de8989d50be8477c7953617a9bcc Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 12:32:45 -0700 Subject: [PATCH 29/59] Returns transformers and removes rankllm --- src/leapfrogai_api/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/leapfrogai_api/pyproject.toml b/src/leapfrogai_api/pyproject.toml index 17aa9c5e1..8fbcc01fc 100644 --- a/src/leapfrogai_api/pyproject.toml +++ b/src/leapfrogai_api/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "postgrest==0.16.8", # required by supabase, bug when using previous versions "openpyxl == 3.1.5", "psutil == 6.0.0", - "rerankers[flashrank,rankllm] == 0.5.3" + "rerankers[flashrank,transformers] == 0.5.3" ] requires-python = "~=3.11" From f9273a3fe6ef4f72a025fedc69269bbd2f216382 Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 12:48:02 -0700 Subject: [PATCH 30/59] Switch to flashrank exlcusively, separate retrieval vs reranking topk --- src/leapfrogai_api/backend/rag/query.py | 11 ++++++++++- src/leapfrogai_api/pyproject.toml | 2 +- src/leapfrogai_api/routers/leapfrogai/rag.py | 4 ++++ src/leapfrogai_api/typedef/rag/rag_types.py | 4 ++++ 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index 973bafe46..55ca68f14 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -55,9 +55,16 @@ async def query_rag( vector = await self.embeddings.aembed_query(query) # 2. Perform similarity search + _k: int = k + if ConfigurationSingleton.get_instance().enable_reranking: + # Use the user specified top-k value unless reranking + # When reranking, use the reranking top-k value to get the initial results + # Then filter the list down later to just the k that the user has requested + _k = ConfigurationSingleton.get_instance().rag_top_k_when_reranking + crud_vector_content = CRUDVectorContent(db=self.db) results = await crud_vector_content.similarity_search( - query=vector, vector_store_id=vector_store_id, k=k + query=vector, vector_store_id=vector_store_id, k=_k ) # 3. Rerank results @@ -69,6 +76,8 @@ async def query_rag( doc_ids=[result.id for result in results.data], ) results = rerank_search_response(results, ranked_results) + # Narrow down the results to the top-k value specified by the user + results.data = results.data[0:k] logger.info(f"Reranking complete {results.get_simple_response()}") logger.info("Ending RAG query...") diff --git a/src/leapfrogai_api/pyproject.toml b/src/leapfrogai_api/pyproject.toml index 8fbcc01fc..bdbc8aea0 100644 --- a/src/leapfrogai_api/pyproject.toml +++ b/src/leapfrogai_api/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "postgrest==0.16.8", # required by supabase, bug when using previous versions "openpyxl == 3.1.5", "psutil == 6.0.0", - "rerankers[flashrank,transformers] == 0.5.3" + "rerankers[flashrank] == 0.5.3" ] requires-python = "~=3.11" diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index 6844c3714..8c6f6872e 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -26,6 +26,9 @@ async def configure(session: Session, configuration: ConfigurationPayload) -> No configuration.enable_reranking ) ConfigurationSingleton.get_instance().ranking_model = configuration.ranking_model + ConfigurationSingleton.get_instance().rag_top_k_when_reranking = ( + configuration.rag_top_k_when_reranking + ) @router.get("/configure") @@ -43,6 +46,7 @@ async def get_configuration(session: Session) -> ConfigurationPayload: new_configuration = ConfigurationPayload( enable_reranking=ConfigurationSingleton.get_instance().enable_reranking, ranking_model=ConfigurationSingleton.get_instance().ranking_model, + rag_top_k_when_reranking=ConfigurationSingleton.get_instance().rag_top_k_when_reranking, ) logger.info(f"The current configuration has been set to {new_configuration}") diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index d50cf05c2..ea86df87f 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -29,3 +29,7 @@ class ConfigurationPayload(BaseModel): description="What model to use for reranking", examples=["flashrank", "rankllm", "cross-encoder", "colbert"], ) + rag_top_k_when_reranking: int = Field( + default=100, + description="The top-k results returned from the RAG call before reranking", + ) From 8d4d829350e62b66f56674f7739ff819b22175ac Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 12:56:47 -0700 Subject: [PATCH 31/59] Adds disclaimer to support ranking model list --- src/leapfrogai_api/typedef/rag/rag_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index ea86df87f..a55f81a31 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -26,7 +26,7 @@ class ConfigurationPayload(BaseModel): # https://pypi.org/project/rerankers/ ranking_model: str = Field( default="flashrank", - description="What model to use for reranking", + description="What model to use for reranking. Some options may require additional python dependencies.", examples=["flashrank", "rankllm", "cross-encoder", "colbert"], ) rag_top_k_when_reranking: int = Field( From 8e435beaf28930f1c04ddc3b0536ffd266b772e6 Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 13:22:27 -0700 Subject: [PATCH 32/59] Removes todo --- src/leapfrogai_api/backend/rag/query.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index 55ca68f14..5e003d311 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -106,8 +106,6 @@ def rerank_search_response( for content in ranked_results.results: if content.document.doc_id in content_to_item: item: SearchItem = content_to_item[content.document.doc_id] - # TODO: Find a better way to handle this - # Update the similarity to instead be the rank item.rank = content.rank item.score = content.score reranked_items.append(item) From e61fda99a0e0217297c6f3b983fba8f3bcf75679 Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 13:35:30 -0700 Subject: [PATCH 33/59] Adds configure endpoint API test --- tests/pytest/leapfrogai_api/test_api.py | 33 +++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/pytest/leapfrogai_api/test_api.py b/tests/pytest/leapfrogai_api/test_api.py index 10fbe698b..6a1c4dfe3 100644 --- a/tests/pytest/leapfrogai_api/test_api.py +++ b/tests/pytest/leapfrogai_api/test_api.py @@ -31,6 +31,7 @@ ) TEXT_INPUT_LEN = len(TEXT_INPUT) + ######################### ######################### @@ -152,6 +153,7 @@ def test_routes(): "/openai/v1/files": ["POST"], "/openai/v1/assistants": ["POST"], "/leapfrogai/v1/count/tokens": ["POST"], + "/leapfrogai/v1/rag/configure": ["PATCH", "GET"], } openai_routes = [ @@ -540,3 +542,34 @@ def test_token_count(dummy_auth_middleware): assert "token_count" in response_data assert isinstance(response_data["token_count"], int) assert response_data["token_count"] == len(input_text) + + +@pytest.mark.skipif( + os.environ.get("LFAI_RUN_REPEATER_TESTS") != "true", + reason="LFAI_RUN_REPEATER_TESTS envvar was not set to true", +) +def test_configure(dummy_auth_middleware): + """Test the RAG configuration endpoints.""" + with TestClient(app) as client: + token_count_request = { + "enable_reranking": True, + "ranking_model": "rankllm", + "rag_top_k_when_reranking": 50, + } + response = client.patch( + "/leapfrogai/v1/rag/configure", json=token_count_request + ) + assert response.status_code == 200 + + response = client.get("/leapfrogai/v1/rag/configure") + assert response.status_code == 200 + response_data = response.json() + assert "enable_reranking" in response_data + assert "ranking_model" in response_data + assert "rag_top_k_when_reranking" in response_data + assert isinstance(response_data["enable_reranking"], bool) + assert isinstance(response_data["ranking_model"], str) + assert isinstance(response_data["rag_top_k_when_reranking"], int) + assert response_data["enable_reranking"] is True + assert response_data["ranking_model"] == "rankllm" + assert response_data["rag_top_k_when_reranking"] == 50 From 0a429a70a2ff93f9fb201f22558a57f9ea1dc0ac Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 15:05:45 -0700 Subject: [PATCH 34/59] Changes comment formatting --- src/leapfrogai_api/backend/rag/query.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index 5e003d311..5a347c158 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -57,9 +57,9 @@ async def query_rag( # 2. Perform similarity search _k: int = k if ConfigurationSingleton.get_instance().enable_reranking: - # Use the user specified top-k value unless reranking - # When reranking, use the reranking top-k value to get the initial results - # Then filter the list down later to just the k that the user has requested + """Use the user specified top-k value unless reranking. + When reranking, use the reranking top-k value to get the initial results. + Then filter the list down later to just the k that the user has requested after reranking.""" _k = ConfigurationSingleton.get_instance().rag_top_k_when_reranking crud_vector_content = CRUDVectorContent(db=self.db) From a42a94a0af252330592c4ddaadd928357c19992f Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 15:30:18 -0700 Subject: [PATCH 35/59] Swap method order in test --- tests/pytest/leapfrogai_api/test_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytest/leapfrogai_api/test_api.py b/tests/pytest/leapfrogai_api/test_api.py index 6a1c4dfe3..4e49c45b8 100644 --- a/tests/pytest/leapfrogai_api/test_api.py +++ b/tests/pytest/leapfrogai_api/test_api.py @@ -153,7 +153,7 @@ def test_routes(): "/openai/v1/files": ["POST"], "/openai/v1/assistants": ["POST"], "/leapfrogai/v1/count/tokens": ["POST"], - "/leapfrogai/v1/rag/configure": ["PATCH", "GET"], + "/leapfrogai/v1/rag/configure": ["GET", "PATCH"], } openai_routes = [ From 6591d4818574d457dfd0f185e4383c47d8eab81d Mon Sep 17 00:00:00 2001 From: gharvey Date: Wed, 25 Sep 2024 16:15:14 -0700 Subject: [PATCH 36/59] Adds RAG + reranking e2e test --- tests/integration/api/test_rag_files.py | 72 +++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/integration/api/test_rag_files.py b/tests/integration/api/test_rag_files.py index 9ed2ad28c..8888d9ebd 100644 --- a/tests/integration/api/test_rag_files.py +++ b/tests/integration/api/test_rag_files.py @@ -1,8 +1,12 @@ import os from pathlib import Path +from typing import Optional + +import requests from openai.types.beta.threads.text import Text import pytest +from leapfrogai_api.typedef.rag.rag_types import ConfigurationPayload from ...utils.client import client_config_factory @@ -80,3 +84,71 @@ def test_rag_needle_haystack(): for a in message_content.annotations: print(a.text) + + +def configure_rag( + base_url: str, + enable_reranking: bool, + ranking_model: str, + rag_top_k_when_reranking: int, +): + """ + Configures the RAG settings. + + Args: + base_url: The base URL of the API (e.g., "http://localhost:8000"). + enable_reranking: Whether to enable reranking. + ranking_model: The ranking model to use. + rag_top_k_when_reranking: The top-k results to return before reranking. + """ + + url = f"{base_url}/leapfrogai/v1/rag/configure" + configuration = ConfigurationPayload( + enable_reranking=enable_reranking, + ranking_model=ranking_model, + rag_top_k_when_reranking=rag_top_k_when_reranking, + ) + + try: + response = requests.patch(url, json=configuration.model_dump()) + response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx) + print("RAG configuration updated successfully.") + except requests.exceptions.RequestException as e: + print(f"Error configuring RAG: {e}") + + +def get_rag_configuration(base_url: str) -> Optional[ConfigurationPayload]: + """ + Retrieves the current RAG configuration. + + Args: + base_url: The base URL of the API. + + Returns: + The RAG configuration, or None if there was an error. + """ + url = f"{base_url}/leapfrogai/v1/rag/configure" + + try: + response = requests.get(url) + response.raise_for_status() + config = ConfigurationPayload.model_validate_json(response.text) + print(f"Current RAG configuration: {config}") + return config + except requests.exceptions.RequestException as e: + print(f"Error getting RAG configuration: {e}") + return None + + +@pytest.mark.skipif( + os.environ.get("LFAI_RUN_NIAH_TESTS") != "true", + reason="LFAI_RUN_NIAH_TESTS envvar was not set to true", +) +def test_rag_needle_haystack_with_reranking(): + base_url = os.getenv( + "LEAPFROGAI_API_URL", "https://leapfrogai-api.uds.dev/openai/v1" + ) + configure_rag(base_url, True, "flashrank", 100) + config_result = get_rag_configuration(base_url) + assert config_result.enable_reranking is True + test_rag_needle_haystack() From c245c94bd55edb42a5d69e0802cbc13aaf2d2852 Mon Sep 17 00:00:00 2001 From: gharvey Date: Thu, 26 Sep 2024 10:20:26 -0700 Subject: [PATCH 37/59] Refactors test_routes api test --- tests/pytest/leapfrogai_api/test_api.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/pytest/leapfrogai_api/test_api.py b/tests/pytest/leapfrogai_api/test_api.py index 4e49c45b8..532642e8a 100644 --- a/tests/pytest/leapfrogai_api/test_api.py +++ b/tests/pytest/leapfrogai_api/test_api.py @@ -203,10 +203,14 @@ def test_routes(): ] actual_routes = app.routes - for route in actual_routes: - if hasattr(route, "path") and route.path in expected_routes: - assert route.methods == set(expected_routes[route.path]) - del expected_routes[route.path] + for expected_route in expected_routes: + matching_routes = {expected_route: []} + for actual_route in actual_routes: + if hasattr(actual_route, "path") and expected_route == actual_route.path: + matching_routes[actual_route.path].extend(actual_route.methods) + assert set(expected_routes[expected_route]) <= set( + matching_routes[expected_route] + ) for route, name, methods in openai_routes: found = False @@ -221,8 +225,6 @@ def test_routes(): break assert found, f"Missing route: {route}, {name}, {methods}" - assert len(expected_routes) == 0 - def test_healthz(): """Test the healthz endpoint.""" From 8259220b93f15a7834aa80af8300c028e85b1953 Mon Sep 17 00:00:00 2001 From: gharvey Date: Thu, 26 Sep 2024 10:25:34 -0700 Subject: [PATCH 38/59] Refactors logging, var names, and removes unneeded logs --- src/leapfrogai_api/backend/rag/query.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index 5a347c158..4d41f5903 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -49,7 +49,7 @@ async def query_rag( SearchResponse: The search response from the vector store. """ - logger.info("Beginning RAG query...") + logger.debug("Beginning RAG query...") # 1. Embed query vector = await self.embeddings.aembed_query(query) @@ -78,9 +78,8 @@ async def query_rag( results = rerank_search_response(results, ranked_results) # Narrow down the results to the top-k value specified by the user results.data = results.data[0:k] - logger.info(f"Reranking complete {results.get_simple_response()}") - logger.info("Ending RAG query...") + logger.debug("Ending RAG query...") return results @@ -102,19 +101,15 @@ def rerank_search_response( content_to_item = {item.id: item for item in original_response.data} # Create new SearchItems based on reranked results - reranked_items = [] + ranked_items = [] for content in ranked_results.results: if content.document.doc_id in content_to_item: item: SearchItem = content_to_item[content.document.doc_id] item.rank = content.rank item.score = content.score - reranked_items.append(item) + ranked_items.append(item) - reranked_response = SearchResponse(data=reranked_items) - - logger.info( - f"Original documents: {original_response.get_simple_response()}\nReranked documents {reranked_response.get_simple_response()}" - ) + ranked_response = SearchResponse(data=ranked_items) # Create a new SearchResponse with reranked items - return reranked_response + return ranked_response From 61cad5e4b377cdd07ad15458b322ec169419492e Mon Sep 17 00:00:00 2001 From: gharvey Date: Thu, 26 Sep 2024 12:05:46 -0700 Subject: [PATCH 39/59] Adds dev flag for the configure endpoint --- .github/workflows/pytest.yaml | 1 + src/leapfrogai_api/main.py | 4 +++- src/leapfrogai_api/typedef/rag/rag_types.py | 2 +- tests/pytest/leapfrogai_api/test_api.py | 3 ++- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index 93d0f0832..21d2e1985 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -64,6 +64,7 @@ jobs: run: make test-api-unit env: LFAI_RUN_REPEATER_TESTS: true + DEV: true integration: runs-on: ai-ubuntu-big-boy-8-core diff --git a/src/leapfrogai_api/main.py b/src/leapfrogai_api/main.py index ad2f039f5..30e3016fc 100644 --- a/src/leapfrogai_api/main.py +++ b/src/leapfrogai_api/main.py @@ -82,7 +82,9 @@ async def validation_exception_handler(request, exc): app.include_router(messages.router) app.include_router(runs_steps.router) app.include_router(lfai_vector_stores.router) -app.include_router(lfai_rag.router) +# Only enable this in dev mode +if os.environ.get("DEV"): + app.include_router(lfai_rag.router) app.include_router(lfai_token_count.router) app.include_router(lfai_models.router) # This should be at the bottom to prevent it preempting more specific runs endpoints diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index a55f81a31..ad3b46d16 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -19,7 +19,7 @@ class ConfigurationPayload(BaseModel): # This is an instance variable, specific to each ConfigurationPayload object # It will be included in the JSON output when the model is serialized enable_reranking: bool = Field( - default=False, description="Enables reranking for RAG queries" + default=True, description="Enables reranking for RAG queries" ) # More model info can be found here: # https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file diff --git a/tests/pytest/leapfrogai_api/test_api.py b/tests/pytest/leapfrogai_api/test_api.py index 532642e8a..2de15bab1 100644 --- a/tests/pytest/leapfrogai_api/test_api.py +++ b/tests/pytest/leapfrogai_api/test_api.py @@ -547,7 +547,8 @@ def test_token_count(dummy_auth_middleware): @pytest.mark.skipif( - os.environ.get("LFAI_RUN_REPEATER_TESTS") != "true", + os.environ.get("LFAI_RUN_REPEATER_TESTS") != "true" + or os.environ.get("DEV") != "true", reason="LFAI_RUN_REPEATER_TESTS envvar was not set to true", ) def test_configure(dummy_auth_middleware): From ac8f58e5f749036d16d77d55d45e9bd9993bb67e Mon Sep 17 00:00:00 2001 From: gharvey Date: Thu, 26 Sep 2024 12:22:44 -0700 Subject: [PATCH 40/59] Adds dev_only decorator for fast_api dev endpoints --- src/leapfrogai_api/main.py | 4 +--- src/leapfrogai_api/routers/leapfrogai/rag.py | 3 +++ src/leapfrogai_api/utils/decorators.py | 18 ++++++++++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 src/leapfrogai_api/utils/decorators.py diff --git a/src/leapfrogai_api/main.py b/src/leapfrogai_api/main.py index 30e3016fc..ad2f039f5 100644 --- a/src/leapfrogai_api/main.py +++ b/src/leapfrogai_api/main.py @@ -82,9 +82,7 @@ async def validation_exception_handler(request, exc): app.include_router(messages.router) app.include_router(runs_steps.router) app.include_router(lfai_vector_stores.router) -# Only enable this in dev mode -if os.environ.get("DEV"): - app.include_router(lfai_rag.router) +app.include_router(lfai_rag.router) app.include_router(lfai_token_count.router) app.include_router(lfai_models.router) # This should be at the bottom to prevent it preempting more specific runs endpoints diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index 8c6f6872e..30bd02d97 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -6,11 +6,13 @@ ConfigurationPayload, ) from leapfrogai_api.routers.supabase_session import Session +from leapfrogai_api.utils.decorators import dev_only from leapfrogai_api.utils.logging_tools import logger router = APIRouter(prefix="/leapfrogai/v1/rag", tags=["leapfrogai/rag"]) +@dev_only @router.patch("/configure") async def configure(session: Session, configuration: ConfigurationPayload) -> None: """ @@ -31,6 +33,7 @@ async def configure(session: Session, configuration: ConfigurationPayload) -> No ) +@dev_only @router.get("/configure") async def get_configuration(session: Session) -> ConfigurationPayload: """ diff --git a/src/leapfrogai_api/utils/decorators.py b/src/leapfrogai_api/utils/decorators.py new file mode 100644 index 000000000..241389b08 --- /dev/null +++ b/src/leapfrogai_api/utils/decorators.py @@ -0,0 +1,18 @@ +import os +from functools import wraps + +from leapfrogai_api.utils.logging_tools import logger + + +def dev_only(func): + """Decorator to conditionally register a FastAPI route only when the env var 'DEV' is set.""" + + @wraps(func) + def wrapper(*args, **kwargs): + if os.environ.get("DEV") == "true": + return func(*args, **kwargs) + else: + logger.warning(f"Route '{func.__name__}' is only available in dev mode.") + return None + + return wrapper From 713e7e815b2ad4dc086c4ac3abd83ed19d9c2141 Mon Sep 17 00:00:00 2001 From: gharvey Date: Thu, 26 Sep 2024 12:26:30 -0700 Subject: [PATCH 41/59] Removes unnecessary function --- src/leapfrogai_api/typedef/vectorstores/search_types.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/leapfrogai_api/typedef/vectorstores/search_types.py b/src/leapfrogai_api/typedef/vectorstores/search_types.py index 5c48a6830..9ce6f3873 100644 --- a/src/leapfrogai_api/typedef/vectorstores/search_types.py +++ b/src/leapfrogai_api/typedef/vectorstores/search_types.py @@ -36,10 +36,3 @@ class SearchResponse(BaseModel): description="List of RAG items returned as a result of the query.", min_length=0, ) - - def get_simple_response(self): - response_without_content = [ - {"id": item.id, "similarity": item.similarity} for item in self.data - ] - - return response_without_content From 14332999dcac14c2fd1cfc1202ddf38ade361481 Mon Sep 17 00:00:00 2001 From: gharvey Date: Thu, 26 Sep 2024 12:28:43 -0700 Subject: [PATCH 42/59] Removes unneeded imports and functions params --- src/leapfrogai_api/backend/rag/query.py | 6 ------ src/leapfrogai_api/routers/leapfrogai/vector_stores.py | 9 ++------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index 4d41f5903..985285dee 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -1,7 +1,5 @@ """Service for querying the RAG model.""" -from typing import Annotated -from fastapi import Depends from rerankers.results import RankedResults from supabase import AClient as AsyncClient from langchain_core.embeddings import Embeddings @@ -10,8 +8,6 @@ from leapfrogai_api.typedef.rag.rag_types import ConfigurationSingleton from leapfrogai_api.typedef.vectorstores.search_types import SearchResponse, SearchItem from leapfrogai_api.backend.constants import TOP_K -from leapfrogai_api.utils import get_model_config -from leapfrogai_api.utils.config import Config from leapfrogai_api.utils.logging_tools import logger from rerankers import Reranker @@ -31,7 +27,6 @@ def __init__(self, db: AsyncClient) -> None: async def query_rag( self, - model_config: Annotated[Config, Depends(get_model_config)], query: str, vector_store_id: str, k: int = TOP_K, @@ -40,7 +35,6 @@ async def query_rag( Query the Vector Store. Args: - model_config (Config): The current model configuration. query (str): The input query string. vector_store_id (str): The ID of the vector store. k (int, optional): The number of results to retrieve. diff --git a/src/leapfrogai_api/routers/leapfrogai/vector_stores.py b/src/leapfrogai_api/routers/leapfrogai/vector_stores.py index 0c8dca4e0..f3f0e2a89 100644 --- a/src/leapfrogai_api/routers/leapfrogai/vector_stores.py +++ b/src/leapfrogai_api/routers/leapfrogai/vector_stores.py @@ -1,13 +1,10 @@ """LeapfrogAI endpoints for RAG.""" -from typing import Annotated - -from fastapi import APIRouter, Depends +from fastapi import APIRouter from leapfrogai_api.backend.rag.query import QueryService from leapfrogai_api.typedef.vectorstores import SearchResponse from leapfrogai_api.routers.supabase_session import Session from leapfrogai_api.backend.constants import TOP_K -from leapfrogai_api.utils import Config, get_model_config router = APIRouter( prefix="/leapfrogai/v1/vector_stores", tags=["leapfrogai/vector_stores"] @@ -17,7 +14,6 @@ @router.post("/search") async def search( session: Session, - model_config: Annotated[Config, Depends(get_model_config)], query: str, vector_store_id: str, k: int = TOP_K, @@ -27,7 +23,6 @@ async def search( Args: session (Session): The database session. - model_config (Config): The current model configuration. query (str): The input query string. vector_store_id (str): The ID of the vector store. k (int, optional): The number of results to retrieve. @@ -37,5 +32,5 @@ async def search( """ query_service = QueryService(db=session) return await query_service.query_rag( - query=query, vector_store_id=vector_store_id, k=k, model_config=model_config + query=query, vector_store_id=vector_store_id, k=k ) From bb3b926c06154d86ec15bf1e3f5eb3c0c078f100 Mon Sep 17 00:00:00 2001 From: gharvey Date: Thu, 26 Sep 2024 18:25:46 -0700 Subject: [PATCH 43/59] Bumps eval deps versions and adds check for empty rag results --- src/leapfrogai_api/backend/rag/query.py | 5 ++++- src/leapfrogai_api/typedef/rag/rag_types.py | 2 -- src/leapfrogai_evals/pyproject.toml | 5 +++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/leapfrogai_api/backend/rag/query.py b/src/leapfrogai_api/backend/rag/query.py index 985285dee..bd0ae9bf6 100644 --- a/src/leapfrogai_api/backend/rag/query.py +++ b/src/leapfrogai_api/backend/rag/query.py @@ -62,7 +62,10 @@ async def query_rag( ) # 3. Rerank results - if ConfigurationSingleton.get_instance().enable_reranking: + if ( + ConfigurationSingleton.get_instance().enable_reranking + and len(results.data) > 0 + ): ranker = Reranker(ConfigurationSingleton.get_instance().ranking_model) ranked_results: RankedResults = ranker.rank( query=query, diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index ad3b46d16..a0e9de92e 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -16,8 +16,6 @@ def get_instance(cls, **kwargs): class ConfigurationPayload(BaseModel): """Response for RAG configuration.""" - # This is an instance variable, specific to each ConfigurationPayload object - # It will be included in the JSON output when the model is serialized enable_reranking: bool = Field( default=True, description="Enables reranking for RAG queries" ) diff --git a/src/leapfrogai_evals/pyproject.toml b/src/leapfrogai_evals/pyproject.toml index 8d671cafd..ff1a0d675 100644 --- a/src/leapfrogai_evals/pyproject.toml +++ b/src/leapfrogai_evals/pyproject.toml @@ -8,7 +8,7 @@ version = "0.13.0" dependencies = [ "deepeval == 1.3.0", - "openai == 1.42.0", + "openai == 1.45.0", "tqdm == 4.66.5", "python-dotenv == 1.0.1", "seaborn == 0.13.2", @@ -16,7 +16,8 @@ dependencies = [ "huggingface-hub == 0.24.6", "anthropic ==0.34.2", "instructor ==1.4.3", - "pyPDF2 == 3.0.1" + "pyPDF2 == 3.0.1", + "python-dotenv == 1.0.1" ] requires-python = "~=3.11" readme = "README.md" From 70d864e752e8d1099dbe74ce5e109a8b6ae2a69e Mon Sep 17 00:00:00 2001 From: gharvey Date: Thu, 26 Sep 2024 19:10:01 -0700 Subject: [PATCH 44/59] Removes dev_only flag in favor of a simpler solution --- src/leapfrogai_api/main.py | 3 ++- src/leapfrogai_api/routers/leapfrogai/rag.py | 3 --- src/leapfrogai_api/utils/decorators.py | 18 ------------------ 3 files changed, 2 insertions(+), 22 deletions(-) delete mode 100644 src/leapfrogai_api/utils/decorators.py diff --git a/src/leapfrogai_api/main.py b/src/leapfrogai_api/main.py index ad2f039f5..f9b3682d4 100644 --- a/src/leapfrogai_api/main.py +++ b/src/leapfrogai_api/main.py @@ -82,7 +82,8 @@ async def validation_exception_handler(request, exc): app.include_router(messages.router) app.include_router(runs_steps.router) app.include_router(lfai_vector_stores.router) -app.include_router(lfai_rag.router) +if os.environ.get("DEV"): + app.include_router(lfai_rag.router) app.include_router(lfai_token_count.router) app.include_router(lfai_models.router) # This should be at the bottom to prevent it preempting more specific runs endpoints diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index 30bd02d97..8c6f6872e 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -6,13 +6,11 @@ ConfigurationPayload, ) from leapfrogai_api.routers.supabase_session import Session -from leapfrogai_api.utils.decorators import dev_only from leapfrogai_api.utils.logging_tools import logger router = APIRouter(prefix="/leapfrogai/v1/rag", tags=["leapfrogai/rag"]) -@dev_only @router.patch("/configure") async def configure(session: Session, configuration: ConfigurationPayload) -> None: """ @@ -33,7 +31,6 @@ async def configure(session: Session, configuration: ConfigurationPayload) -> No ) -@dev_only @router.get("/configure") async def get_configuration(session: Session) -> ConfigurationPayload: """ diff --git a/src/leapfrogai_api/utils/decorators.py b/src/leapfrogai_api/utils/decorators.py deleted file mode 100644 index 241389b08..000000000 --- a/src/leapfrogai_api/utils/decorators.py +++ /dev/null @@ -1,18 +0,0 @@ -import os -from functools import wraps - -from leapfrogai_api.utils.logging_tools import logger - - -def dev_only(func): - """Decorator to conditionally register a FastAPI route only when the env var 'DEV' is set.""" - - @wraps(func) - def wrapper(*args, **kwargs): - if os.environ.get("DEV") == "true": - return func(*args, **kwargs) - else: - logger.warning(f"Route '{func.__name__}' is only available in dev mode.") - return None - - return wrapper From 93e7d9d0b6600c9fce108f654237d8acdd5f3477 Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 27 Sep 2024 10:46:13 -0700 Subject: [PATCH 45/59] Updates documentation and simplifies setting and getting configuration values --- src/leapfrogai_api/README.md | 69 ++++++++++++++++++++ src/leapfrogai_api/routers/leapfrogai/rag.py | 16 ++--- src/leapfrogai_api/typedef/rag/rag_types.py | 8 ++- 3 files changed, 80 insertions(+), 13 deletions(-) diff --git a/src/leapfrogai_api/README.md b/src/leapfrogai_api/README.md index eec4dd0c6..214c986a9 100644 --- a/src/leapfrogai_api/README.md +++ b/src/leapfrogai_api/README.md @@ -56,3 +56,72 @@ See the ["Access" section of the DEVELOPMENT.md](../../docs/DEVELOPMENT.md#acces ### Tests See the [tests directory documentation](../../tests/README.md) for more details. + +### Reranking Configuration + +The LeapfrogAI API includes a Retrieval Augmented Generation (RAG) pipeline for enhanced question answering. This section details how to configure its reranking options. All RAG configurations are managed through the `/leapfrogai/v1/rag/configure` API endpoint. + +#### 1. Enabling/Disabling Reranking + +Reranking improves the accuracy and relevance of RAG responses. You can enable or disable it using the `enable_reranking` parameter: + +* **Enable Reranking:** Send a PATCH request to `/leapfrogai/v1/rag/configure` with the following JSON payload: + +```json +{ + "enable_reranking": true +} +``` + +* **Disable Reranking:** Send a PATCH request with: + +```json +{ + "enable_reranking": false +} +``` + +#### 2. Selecting a Reranking Model + +Multiple reranking models are supported, each offering different performance characteristics. Choose your preferred model using the `ranking_model` parameter. Ensure you've installed any necessary Python dependencies for your chosen model (see the [rerankers library documentation](https://github.com/AnswerDotAI/rerankers) on dependencies). + +* **Supported Models:** The system supports several models, including (but not limited to) `flashrank`, `rankllm`, `cross-encoder`, and `colbert`. Refer to the [rerankers library documentation](https://github.com/AnswerDotAI/rerankers) for a complete list and details on their capabilities. + +* **Model Selection:** Use a PATCH request to `/leapfrogai/v1/rag/configure` with the desired model: + +```json +{ + "enable_reranking": true, // Reranking must be enabled + "ranking_model": "rankllm" // Or another supported model +} +``` + +#### 3. Adjusting the Number of Results Before Reranking (`rag_top_k_when_reranking`) + +This parameter sets the number of top results retrieved from the vector database *before* the reranking process begins. A higher value increases the diversity of candidates considered for reranking but also increases processing time. A lower value can lead to missing relevant results if not carefully chosen. This setting is only relevant when reranking is enabled. + +* **Configuration:** Use a PATCH request to `/leapfrogai/v1/rag/configure` to set this value: + +```json +{ + "enable_reranking": true, + "ranking_model": "flashrank", + "rag_top_k_when_reranking": 150 // Adjust this value as needed +} +``` + +#### 4. Retrieving the Current RAG Configuration + +To check the current RAG configuration (including reranking status, model, and `rag_top_k_when_reranking`), send a GET request to `/leapfrogai/v1/rag/configure`. The response will be a JSON object containing all the current settings. + +#### 5. Example Configuration Flow + +1. **Initial Setup:** Start with reranking enabled using the default `flashrank` model and a `rag_top_k_when_reranking` value of 100. + +2. **Experiment with Models:** Test different reranking models (`rankllm`, `colbert`, etc.) by changing the `ranking_model` parameter and observing the impact on response quality. Adjust `rag_top_k_when_reranking` as needed to find the optimal balance between diversity and performance. + +3. **Fine-tuning:** Once you identify a suitable model, fine-tune the `rag_top_k_when_reranking` parameter for optimal performance. Monitor response times and quality to determine the best setting. + +4. **Disabling Reranking:** If needed, disable reranking by setting `"enable_reranking": false`. + +Remember to always consult the [rerankers library documentation](https://github.com/AnswerDotAI/rerankers) for information on supported models and their specific requirements. The API documentation provides further details on request formats and potential error responses. diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index 8c6f6872e..ecca74242 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -21,13 +21,11 @@ async def configure(session: Session, configuration: ConfigurationPayload) -> No configuration (Configuration): The configuration to update. """ + current_configuration = ConfigurationSingleton.get_instance() + # We set the class variable to update the configuration globally - ConfigurationSingleton.get_instance().enable_reranking = ( - configuration.enable_reranking - ) - ConfigurationSingleton.get_instance().ranking_model = configuration.ranking_model - ConfigurationSingleton.get_instance().rag_top_k_when_reranking = ( - configuration.rag_top_k_when_reranking + current_configuration._instance = current_configuration.copy( + update=configuration.__dict__ ) @@ -43,10 +41,8 @@ async def get_configuration(session: Session) -> ConfigurationPayload: Configuration: The current RAG configuration. """ - new_configuration = ConfigurationPayload( - enable_reranking=ConfigurationSingleton.get_instance().enable_reranking, - ranking_model=ConfigurationSingleton.get_instance().ranking_model, - rag_top_k_when_reranking=ConfigurationSingleton.get_instance().rag_top_k_when_reranking, + new_configuration = ConfigurationPayload.copy( + update=ConfigurationSingleton.get_instance().__dict__ ) logger.info(f"The current configuration has been set to {new_configuration}") diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index a0e9de92e..11ed90e70 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import BaseModel, Field @@ -16,18 +18,18 @@ def get_instance(cls, **kwargs): class ConfigurationPayload(BaseModel): """Response for RAG configuration.""" - enable_reranking: bool = Field( + enable_reranking: Optional[bool] = Field( default=True, description="Enables reranking for RAG queries" ) # More model info can be found here: # https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file # https://pypi.org/project/rerankers/ - ranking_model: str = Field( + ranking_model: Optional[str] = Field( default="flashrank", description="What model to use for reranking. Some options may require additional python dependencies.", examples=["flashrank", "rankllm", "cross-encoder", "colbert"], ) - rag_top_k_when_reranking: int = Field( + rag_top_k_when_reranking: Optional[int] = Field( default=100, description="The top-k results returned from the RAG call before reranking", ) From fd54aa253408de2af5ef2d8ce4190184f10e1de3 Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 27 Sep 2024 10:54:33 -0700 Subject: [PATCH 46/59] Fixes get request for config --- src/leapfrogai_api/routers/leapfrogai/rag.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index ecca74242..e610e5f0f 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -41,9 +41,17 @@ async def get_configuration(session: Session) -> ConfigurationPayload: Configuration: The current RAG configuration. """ - new_configuration = ConfigurationPayload.copy( - update=ConfigurationSingleton.get_instance().__dict__ - ) + instance = ConfigurationSingleton.get_instance() + + # Create a new dictionary with only the relevant attributes + config_dict = { + key: value + for key, value in instance.__dict__.items() + if not key.startswith("_") # Exclude private attributes + } + + # Create a new ConfigurationPayload instance with the filtered dictionary + new_configuration = ConfigurationPayload(**config_dict) logger.info(f"The current configuration has been set to {new_configuration}") From d2db2c72333557e6c5d19396e8680ac067478a45 Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 27 Sep 2024 11:04:29 -0700 Subject: [PATCH 47/59] Ensure that the singleton gets updated --- src/leapfrogai_api/routers/leapfrogai/rag.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index e610e5f0f..f0000f7f1 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -21,10 +21,8 @@ async def configure(session: Session, configuration: ConfigurationPayload) -> No configuration (Configuration): The configuration to update. """ - current_configuration = ConfigurationSingleton.get_instance() - # We set the class variable to update the configuration globally - current_configuration._instance = current_configuration.copy( + ConfigurationSingleton._instance = ConfigurationSingleton.get_instance().copy( update=configuration.__dict__ ) From fd4bea15c00956aeb83cbaef8ed805e2029131b7 Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 27 Sep 2024 11:10:28 -0700 Subject: [PATCH 48/59] Adds zarf configs for dev mode --- packages/api/chart/values.yaml | 2 ++ packages/api/values/registry1-values.yaml | 2 ++ packages/api/values/upstream-values.yaml | 2 ++ packages/api/zarf.yaml | 3 +++ 4 files changed, 9 insertions(+) diff --git a/packages/api/chart/values.yaml b/packages/api/chart/values.yaml index 65b397e46..4c217ba8a 100644 --- a/packages/api/chart/values.yaml +++ b/packages/api/chart/values.yaml @@ -25,6 +25,8 @@ api: value: "*.toml" - name: DEFAULT_EMBEDDINGS_MODEL value: "text-embeddings" + - name: DEV + value: "false" - name: PORT value: "8080" - name: SUPABASE_URL diff --git a/packages/api/values/registry1-values.yaml b/packages/api/values/registry1-values.yaml index d269c6415..4bd35ee39 100644 --- a/packages/api/values/registry1-values.yaml +++ b/packages/api/values/registry1-values.yaml @@ -16,6 +16,8 @@ api: value: "*.toml" - name: DEFAULT_EMBEDDINGS_MODEL value: "###ZARF_VAR_DEFAULT_EMBEDDINGS_MODEL###" + - name: DEV + value: "###ZARF_VAR_DEV###" - name: PORT value: "8080" - name: SUPABASE_URL diff --git a/packages/api/values/upstream-values.yaml b/packages/api/values/upstream-values.yaml index 6d867260e..ef2dcdad9 100644 --- a/packages/api/values/upstream-values.yaml +++ b/packages/api/values/upstream-values.yaml @@ -14,6 +14,8 @@ api: value: "*.toml" - name: DEFAULT_EMBEDDINGS_MODEL value: "###ZARF_VAR_DEFAULT_EMBEDDINGS_MODEL###" + - name: DEV + value: "###ZARF_VAR_DEV###" - name: PORT value: "8080" - name: SUPABASE_URL diff --git a/packages/api/zarf.yaml b/packages/api/zarf.yaml index 4fa6c59f2..92b3c8123 100644 --- a/packages/api/zarf.yaml +++ b/packages/api/zarf.yaml @@ -16,6 +16,9 @@ variables: description: "Flag to expose the OpenAPI schema for debugging." - name: DEFAULT_EMBEDDINGS_MODEL default: "text-embeddings" + - name: DEV + default: "false" + description: "Flag to enable development endpoints." components: - name: leapfrogai-api From 0872a5ee6aa1e3d8e0647e5538264464936697bc Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 27 Sep 2024 11:23:36 -0700 Subject: [PATCH 49/59] Make deep copy to prevent issues with variables overwriting --- src/leapfrogai_api/routers/leapfrogai/rag.py | 4 +--- src/leapfrogai_api/typedef/rag/rag_types.py | 6 ++++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index f0000f7f1..ebf9ec0bf 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -22,9 +22,7 @@ async def configure(session: Session, configuration: ConfigurationPayload) -> No """ # We set the class variable to update the configuration globally - ConfigurationSingleton._instance = ConfigurationSingleton.get_instance().copy( - update=configuration.__dict__ - ) + ConfigurationSingleton.update_instance(configuration) @router.get("/configure") diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index 11ed90e70..6243d1713 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -1,3 +1,4 @@ +import copy from typing import Optional from pydantic import BaseModel, Field @@ -14,6 +15,11 @@ def get_instance(cls, **kwargs): cls._instance = ConfigurationPayload(**kwargs) return cls._instance + @classmethod + def update_instance(cls, configuration): + cls._instance = copy.deepcopy(cls.get_instance()) + cls._instance.model_validate(configuration.__dict__) + class ConfigurationPayload(BaseModel): """Response for RAG configuration.""" From 0fb97b82f656e1518726d8e76003ae02b82a94e8 Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 27 Sep 2024 11:32:30 -0700 Subject: [PATCH 50/59] Fixes update logic --- src/leapfrogai_api/routers/leapfrogai/rag.py | 2 +- src/leapfrogai_api/typedef/rag/rag_types.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index ebf9ec0bf..791dd1f2d 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -22,7 +22,7 @@ async def configure(session: Session, configuration: ConfigurationPayload) -> No """ # We set the class variable to update the configuration globally - ConfigurationSingleton.update_instance(configuration) + ConfigurationSingleton.get_instance().update_instance(configuration) @router.get("/configure") diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index 6243d1713..c73674dbb 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -1,4 +1,3 @@ -import copy from typing import Optional from pydantic import BaseModel, Field @@ -17,8 +16,8 @@ def get_instance(cls, **kwargs): @classmethod def update_instance(cls, configuration): - cls._instance = copy.deepcopy(cls.get_instance()) - cls._instance.model_validate(configuration.__dict__) + for key, value in configuration.items(): + setattr(cls._instance, key, value) class ConfigurationPayload(BaseModel): From b86fdf339febb9f1dae31db191fce049a28a7c3a Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 27 Sep 2024 11:38:52 -0700 Subject: [PATCH 51/59] Moves the update function to the payload class --- src/leapfrogai_api/typedef/rag/rag_types.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index c73674dbb..63ad25c6a 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -14,11 +14,6 @@ def get_instance(cls, **kwargs): cls._instance = ConfigurationPayload(**kwargs) return cls._instance - @classmethod - def update_instance(cls, configuration): - for key, value in configuration.items(): - setattr(cls._instance, key, value) - class ConfigurationPayload(BaseModel): """Response for RAG configuration.""" @@ -38,3 +33,7 @@ class ConfigurationPayload(BaseModel): default=100, description="The top-k results returned from the RAG call before reranking", ) + + def update_instance(self, configuration): + for key, value in configuration.items(): + setattr(self._instance, key, value) From 825df2a7e5c23e43272c6fd1af3d21bdadb17e46 Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 27 Sep 2024 11:58:05 -0700 Subject: [PATCH 52/59] Prevents default overwriting --- src/leapfrogai_api/typedef/rag/rag_types.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index 63ad25c6a..3f6442a08 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -19,21 +19,24 @@ class ConfigurationPayload(BaseModel): """Response for RAG configuration.""" enable_reranking: Optional[bool] = Field( - default=True, description="Enables reranking for RAG queries" + default=None, + examples=[True, False], + description="Enables reranking for RAG queries", ) # More model info can be found here: # https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file # https://pypi.org/project/rerankers/ ranking_model: Optional[str] = Field( - default="flashrank", + default=None, description="What model to use for reranking. Some options may require additional python dependencies.", examples=["flashrank", "rankllm", "cross-encoder", "colbert"], ) rag_top_k_when_reranking: Optional[int] = Field( - default=100, + default=None, description="The top-k results returned from the RAG call before reranking", ) - def update_instance(self, configuration): - for key, value in configuration.items(): - setattr(self._instance, key, value) + def update_instance(self, configuration: BaseModel): + for key, value in configuration.model_dump().items(): + if value is not None: + setattr(self, key, value) From cec2a26c389e4b4c97b669e85d5366e4513bee92 Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 27 Sep 2024 12:15:02 -0700 Subject: [PATCH 53/59] Prevents default overwriting --- src/leapfrogai_api/routers/leapfrogai/rag.py | 4 +++- src/leapfrogai_api/typedef/rag/rag_types.py | 12 +++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/leapfrogai_api/routers/leapfrogai/rag.py b/src/leapfrogai_api/routers/leapfrogai/rag.py index 791dd1f2d..3b61b616e 100644 --- a/src/leapfrogai_api/routers/leapfrogai/rag.py +++ b/src/leapfrogai_api/routers/leapfrogai/rag.py @@ -22,7 +22,9 @@ async def configure(session: Session, configuration: ConfigurationPayload) -> No """ # We set the class variable to update the configuration globally - ConfigurationSingleton.get_instance().update_instance(configuration) + ConfigurationSingleton._instance = ConfigurationSingleton.get_instance().copy( + update=configuration.dict(exclude_none=True) + ) @router.get("/configure") diff --git a/src/leapfrogai_api/typedef/rag/rag_types.py b/src/leapfrogai_api/typedef/rag/rag_types.py index 3f6442a08..17fe6601c 100644 --- a/src/leapfrogai_api/typedef/rag/rag_types.py +++ b/src/leapfrogai_api/typedef/rag/rag_types.py @@ -9,9 +9,12 @@ class ConfigurationSingleton: _instance = None @classmethod - def get_instance(cls, **kwargs): + def get_instance(cls): if cls._instance is None: - cls._instance = ConfigurationPayload(**kwargs) + cls._instance = ConfigurationPayload() + cls._instance.enable_reranking = True + cls._instance.rag_top_k_when_reranking = 100 + cls._instance.ranking_model = "flashrank" return cls._instance @@ -35,8 +38,3 @@ class ConfigurationPayload(BaseModel): default=None, description="The top-k results returned from the RAG call before reranking", ) - - def update_instance(self, configuration: BaseModel): - for key, value in configuration.model_dump().items(): - if value is not None: - setattr(self, key, value) From 588b3eca49f6b3e685aee109224d603c43e2b5e3 Mon Sep 17 00:00:00 2001 From: gharvey Date: Fri, 27 Sep 2024 12:27:32 -0700 Subject: [PATCH 54/59] Adds to configuration test --- tests/pytest/leapfrogai_api/test_api.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/tests/pytest/leapfrogai_api/test_api.py b/tests/pytest/leapfrogai_api/test_api.py index 2de15bab1..a1aea7c95 100644 --- a/tests/pytest/leapfrogai_api/test_api.py +++ b/tests/pytest/leapfrogai_api/test_api.py @@ -554,13 +554,13 @@ def test_token_count(dummy_auth_middleware): def test_configure(dummy_auth_middleware): """Test the RAG configuration endpoints.""" with TestClient(app) as client: - token_count_request = { + rag_configuration_request = { "enable_reranking": True, "ranking_model": "rankllm", "rag_top_k_when_reranking": 50, } response = client.patch( - "/leapfrogai/v1/rag/configure", json=token_count_request + "/leapfrogai/v1/rag/configure", json=rag_configuration_request ) assert response.status_code == 200 @@ -576,3 +576,23 @@ def test_configure(dummy_auth_middleware): assert response_data["enable_reranking"] is True assert response_data["ranking_model"] == "rankllm" assert response_data["rag_top_k_when_reranking"] == 50 + + # Update only some of the configs to see if the existing ones persist + rag_configuration_request = {"ranking_model": "flashrank"} + response = client.patch( + "/leapfrogai/v1/rag/configure", json=rag_configuration_request + ) + assert response.status_code == 200 + + response = client.get("/leapfrogai/v1/rag/configure") + assert response.status_code == 200 + response_data = response.json() + assert "enable_reranking" in response_data + assert "ranking_model" in response_data + assert "rag_top_k_when_reranking" in response_data + assert isinstance(response_data["enable_reranking"], bool) + assert isinstance(response_data["ranking_model"], str) + assert isinstance(response_data["rag_top_k_when_reranking"], int) + assert response_data["enable_reranking"] is True + assert response_data["ranking_model"] == "flashrank" + assert response_data["rag_top_k_when_reranking"] == 50 From f8e6d209b2d646b8901ae63cec53ed5bf6c56300 Mon Sep 17 00:00:00 2001 From: Gato <115658935+CollectiveUnicorn@users.noreply.github.com> Date: Mon, 30 Sep 2024 09:05:38 -0700 Subject: [PATCH 55/59] Update test_rag_files.py --- tests/integration/api/test_rag_files.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/api/test_rag_files.py b/tests/integration/api/test_rag_files.py index c8273f519..09ba8fd71 100644 --- a/tests/integration/api/test_rag_files.py +++ b/tests/integration/api/test_rag_files.py @@ -1,5 +1,4 @@ import os -from pathlib import Path from typing import Optional import requests From 5a23a511e6433be2a97f98927c26541d9643d6b5 Mon Sep 17 00:00:00 2001 From: gharvey Date: Mon, 30 Sep 2024 14:09:42 -0700 Subject: [PATCH 56/59] Adds small fixes --- tests/integration/api/test_rag_files.py | 19 +++++++------------ tests/utils/client.py | 2 ++ 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/tests/integration/api/test_rag_files.py b/tests/integration/api/test_rag_files.py index 09ba8fd71..864565a28 100644 --- a/tests/integration/api/test_rag_files.py +++ b/tests/integration/api/test_rag_files.py @@ -7,7 +7,7 @@ from tests.utils.data_path import data_path from leapfrogai_api.typedef.rag.rag_types import ConfigurationPayload -from tests.utils.client import client_config_factory +from tests.utils.client import ROOT_URL, client_config_factory def make_test_assistant(client, model, vector_store_id): @@ -84,7 +84,6 @@ def test_rag_needle_haystack(): def configure_rag( - base_url: str, enable_reranking: bool, ranking_model: str, rag_top_k_when_reranking: int, @@ -93,13 +92,11 @@ def configure_rag( Configures the RAG settings. Args: - base_url: The base URL of the API (e.g., "http://localhost:8000"). enable_reranking: Whether to enable reranking. ranking_model: The ranking model to use. rag_top_k_when_reranking: The top-k results to return before reranking. """ - - url = f"{base_url}/leapfrogai/v1/rag/configure" + url = f"{ROOT_URL}/leapfrogai/v1/rag/configure" configuration = ConfigurationPayload( enable_reranking=enable_reranking, ranking_model=ranking_model, @@ -114,7 +111,7 @@ def configure_rag( print(f"Error configuring RAG: {e}") -def get_rag_configuration(base_url: str) -> Optional[ConfigurationPayload]: +def get_rag_configuration() -> Optional[ConfigurationPayload]: """ Retrieves the current RAG configuration. @@ -124,7 +121,7 @@ def get_rag_configuration(base_url: str) -> Optional[ConfigurationPayload]: Returns: The RAG configuration, or None if there was an error. """ - url = f"{base_url}/leapfrogai/v1/rag/configure" + url = f"{ROOT_URL}/leapfrogai/v1/rag/configure" try: response = requests.get(url) @@ -142,10 +139,8 @@ def get_rag_configuration(base_url: str) -> Optional[ConfigurationPayload]: reason="LFAI_RUN_NIAH_TESTS envvar was not set to true", ) def test_rag_needle_haystack_with_reranking(): - base_url = os.getenv( - "LEAPFROGAI_API_URL", "https://leapfrogai-api.uds.dev/openai/v1" - ) - configure_rag(base_url, True, "flashrank", 100) - config_result = get_rag_configuration(base_url) + configure_rag(True, "flashrank", 100) + config_result = get_rag_configuration() + assert config_result is not None assert config_result.enable_reranking is True test_rag_needle_haystack() diff --git a/tests/utils/client.py b/tests/utils/client.py index 8411d5077..599445a66 100644 --- a/tests/utils/client.py +++ b/tests/utils/client.py @@ -3,6 +3,8 @@ LEAPFROGAI_MODEL = os.getenv("LEAPFROGAI_MODEL", "llama-cpp-python") OPENAI_MODEL = "gpt-4o-mini" +ROOT_URL = "https://leapfrogai-api.uds.dev" + def openai_client(): From 0db333edf9fa018eb241ff26469577ace3d3cab6 Mon Sep 17 00:00:00 2001 From: gharvey Date: Mon, 30 Sep 2024 21:39:14 -0700 Subject: [PATCH 57/59] Ruff linting --- tests/utils/client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utils/client.py b/tests/utils/client.py index c25c8c3c4..6abdc8b01 100644 --- a/tests/utils/client.py +++ b/tests/utils/client.py @@ -8,6 +8,7 @@ OPENAI_MODEL = "gpt-4o-mini" ROOT_URL = "https://leapfrogai-api.uds.dev" + def get_leapfrogai_model() -> str: """Get the model to use for LeapfrogAI. From 2cbe4bdca9093207994d791b6adcc5cb9c31837e Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 1 Oct 2024 10:21:28 -0700 Subject: [PATCH 58/59] Fix unnecessary environment variables --- tests/utils/client.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/utils/client.py b/tests/utils/client.py index 6abdc8b01..6fe598514 100644 --- a/tests/utils/client.py +++ b/tests/utils/client.py @@ -4,10 +4,6 @@ import requests from requests import Response -LEAPFROGAI_MODEL = os.getenv("LEAPFROGAI_MODEL", "llama-cpp-python") -OPENAI_MODEL = "gpt-4o-mini" -ROOT_URL = "https://leapfrogai-api.uds.dev" - def get_leapfrogai_model() -> str: """Get the model to use for LeapfrogAI. From 01e77df25f4656cc93405737fb530bf0038a190d Mon Sep 17 00:00:00 2001 From: gharvey Date: Tue, 1 Oct 2024 11:35:56 -0700 Subject: [PATCH 59/59] Swaps env out with new helper function --- tests/integration/api/test_rag_files.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/api/test_rag_files.py b/tests/integration/api/test_rag_files.py index 864565a28..7520ddbcc 100644 --- a/tests/integration/api/test_rag_files.py +++ b/tests/integration/api/test_rag_files.py @@ -7,7 +7,7 @@ from tests.utils.data_path import data_path from leapfrogai_api.typedef.rag.rag_types import ConfigurationPayload -from tests.utils.client import ROOT_URL, client_config_factory +from tests.utils.client import client_config_factory, get_leapfrogai_api_url_base def make_test_assistant(client, model, vector_store_id): @@ -96,7 +96,7 @@ def configure_rag( ranking_model: The ranking model to use. rag_top_k_when_reranking: The top-k results to return before reranking. """ - url = f"{ROOT_URL}/leapfrogai/v1/rag/configure" + url = f"{get_leapfrogai_api_url_base()}/leapfrogai/v1/rag/configure" configuration = ConfigurationPayload( enable_reranking=enable_reranking, ranking_model=ranking_model, @@ -121,7 +121,7 @@ def get_rag_configuration() -> Optional[ConfigurationPayload]: Returns: The RAG configuration, or None if there was an error. """ - url = f"{ROOT_URL}/leapfrogai/v1/rag/configure" + url = f"{get_leapfrogai_api_url_base()}/leapfrogai/v1/rag/configure" try: response = requests.get(url)