From dd91d3f901d027772d56b76c80ba83e55708a82b Mon Sep 17 00:00:00 2001
From: michel-heon
Date: Fri, 11 Oct 2024 16:52:55 -0400
Subject: [PATCH] Merge with main id 77ec531e13100519e3dd0582be1c612aedee9524

---
 .../functions/request-handler/adapters/base/base.py    |  9 +++------
 .../functions/request-handler/adapters/bedrock/base.py | 10 +++++-----
 .../request-handler/adapters/bedrock/base_test.py      |  8 +++-----
 3 files changed, 11 insertions(+), 16 deletions(-)

diff --git a/lib/model-interfaces/langchain/functions/request-handler/adapters/base/base.py b/lib/model-interfaces/langchain/functions/request-handler/adapters/base/base.py
index bf2a2294..85b31960 100644
--- a/lib/model-interfaces/langchain/functions/request-handler/adapters/base/base.py
+++ b/lib/model-interfaces/langchain/functions/request-handler/adapters/base/base.py
@@ -40,6 +40,7 @@ class Mode(Enum):
     CHAIN = "chain"
 
 
+
 def get_guardrails() -> dict:
     if "BEDROCK_GUARDRAILS_ID" in os.environ:
         logger.debug("Guardrails ID found in environment variables.")
@@ -593,12 +594,8 @@ def format(self, **kwargs: Any) -> str:
 
 # Register the adapters
 registry.register(r"^bedrock.ai21.jamba*", BedrockChatAdapter)
-registry.register(
-    r"^bedrock.ai21.j2*", BedrockChatNoStreamingNoSystemPromptAdapter
-)
-registry.register(
-    r"^bedrock\.cohere\.command-(text|light-text).*", BedrockChatNoSystemPromptAdapter
-)
+registry.register(r"^bedrock.ai21.j2*", BedrockChatNoStreamingNoSystemPromptAdapter)
+registry.register(r"^bedrock\.cohere\.command-(text|light-text).*", BedrockChatNoSystemPromptAdapter)
 registry.register(r"^bedrock\.cohere\.command-r.*", BedrockChatAdapter)
 registry.register(r"^bedrock.anthropic.claude*", BedrockChatAdapter)
 registry.register(r"^bedrock.meta.llama*", BedrockChatAdapter)
diff --git a/lib/model-interfaces/langchain/functions/request-handler/adapters/bedrock/base.py b/lib/model-interfaces/langchain/functions/request-handler/adapters/bedrock/base.py
index 365055c2..ee29b238 100644
--- a/lib/model-interfaces/langchain/functions/request-handler/adapters/bedrock/base.py
+++ b/lib/model-interfaces/langchain/functions/request-handler/adapters/bedrock/base.py
@@ -94,24 +94,24 @@ def get_condense_question_prompt(self):
     def get_llm(self, model_kwargs={}, extra={}):
         bedrock = genai_core.clients.get_bedrock_client()
         params = {}
-        
+
         # Collect temperature, topP, and maxTokens if available
         temperature = model_kwargs.get("temperature")
         top_p = model_kwargs.get("topP")
         max_tokens = model_kwargs.get("maxTokens")
-        
+
         if temperature:
             params["temperature"] = temperature
         if top_p:
             params["top_p"] = top_p
         if max_tokens:
             params["max_tokens"] = max_tokens
-        
+
         # Fetch guardrails if any
         guardrails = get_guardrails()
         if len(guardrails.keys()) > 0:
             params["guardrails"] = guardrails
-        
+
         # Log all parameters in a single log entry, including full guardrails
         logger.info(
             f"Creating LLM chain for model {self.model_id}",
@@ -121,7 +121,7 @@ def get_llm(self, model_kwargs={}, extra={}):
             max_tokens=max_tokens,
             guardrails=guardrails,
         )
-        
+
         # Return ChatBedrockConverse instance with the collected params
         return ChatBedrockConverse(
             client=bedrock,
diff --git a/tests/model-interfaces/langchain/functions/request-handler/adapters/bedrock/base_test.py b/tests/model-interfaces/langchain/functions/request-handler/adapters/bedrock/base_test.py
index 208e6c85..23ea127f 100644
--- a/tests/model-interfaces/langchain/functions/request-handler/adapters/bedrock/base_test.py
+++ b/tests/model-interfaces/langchain/functions/request-handler/adapters/bedrock/base_test.py
@@ -7,7 +7,6 @@
 from adapters.shared.prompts.system_prompts import prompts  # Added import
 
-
 
 def test_registry():
     with pytest.raises(ValueError, match="not found"):
         registry.get_adapter("invalid")
@@ -37,7 +36,7 @@ def test_chat_adapter(mocker):
     result = model.get_qa_prompt().format(
         input="input", context="context", chat_history=[HumanMessage(content="history")]
     )
-    # Updated the assertion to match the English prompt in system_prompts.py
+
     assert "Use the following pieces of context" in result
     assert "Human: history" in result
     assert "Human: input" in result
@@ -45,7 +44,7 @@ def test_chat_adapter(mocker):
     result = model.get_prompt().format(
         input="input", chat_history=[HumanMessage(content="history")]
     )
-    # Updated the assertion to match the English prompt in system_prompts.py
+
     assert "The following is a friendly conversation" in result
     assert "Human: history" in result
     assert "Human: input" in result
@@ -53,7 +52,7 @@ def test_chat_adapter(mocker):
     result = model.get_condense_question_prompt().format(
         input="input", chat_history=[HumanMessage(content="history")]
     )
-    # Updated the assertion to match the English prompt in system_prompts.py
+
     assert "Given the conversation inside the tags" in result
     assert "Human: history" in result
     assert "Human: input" in result
@@ -119,4 +118,3 @@ def test_chat_without_system_adapter(mocker):
         model="model",
         callbacks=ANY,
     )
-
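
Note: the `registry.register(...)` lines reformatted in this patch bind a regex to an adapter class, and `test_registry` asserts that an unknown model id raises ValueError. A minimal sketch of that matching behavior follows, assuming patterns are tried in registration order against the model id; the PATTERNS list and get_adapter function below are illustrative stand-ins, not the project's actual registry implementation.

    import re

    # Illustrative stand-in for the adapter registry (not the real class):
    # each registered regex is tried in order against the requested model id.
    PATTERNS = [
        (r"^bedrock.ai21.jamba*", "BedrockChatAdapter"),
        (r"^bedrock.ai21.j2*", "BedrockChatNoStreamingNoSystemPromptAdapter"),
        (r"^bedrock\.cohere\.command-(text|light-text).*", "BedrockChatNoSystemPromptAdapter"),
        (r"^bedrock\.cohere\.command-r.*", "BedrockChatAdapter"),
    ]

    def get_adapter(model_id: str) -> str:
        for pattern, adapter in PATTERNS:
            if re.match(pattern, model_id):
                return adapter
        # test_registry above asserts this error for unknown model ids.
        raise ValueError(f"Adapter for model {model_id} not found")

    print(get_adapter("bedrock.cohere.command-r-v1"))  # -> BedrockChatAdapter
    print(get_adapter("bedrock.ai21.j2-ultra"))        # -> BedrockChatNoStreamingNoSystemPromptAdapter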