diff --git a/core/cat/agents/base_agent.py b/core/cat/agents/base_agent.py
index c53dfe13..28f02f77 100644
--- a/core/cat/agents/base_agent.py
+++ b/core/cat/agents/base_agent.py
@@ -15,25 +15,4 @@ class BaseAgent(ABC):
 
     @abstractmethod
     async def execute(*args, **kwargs) -> AgentOutput:
-        pass
-
-    # TODO: this is here to debug langchain, take it away
-    def _log_prompt(self, langchain_prompt, title):
-        print("\n")
-        print(get_colored_text(f"==================== {title} ====================", "green"))
-        for m in langchain_prompt.messages:
-            print(get_colored_text(type(m).__name__, "green"))
-            print(m.content)
-        print(get_colored_text("========================================", "green"))
-        return langchain_prompt
-
-    # TODO: this is here to debug langchain, take it away
-    def _log_output(self, langchain_output, title):
-        print("\n")
-        print(get_colored_text(f"==================== {title} ====================", "blue"))
-        if hasattr(langchain_output, 'content'):
-            print(langchain_output.content)
-        else:
-            print(langchain_output)
-        print(get_colored_text("========================================", "blue"))
-        return langchain_output
\ No newline at end of file
+        pass
\ No newline at end of file
diff --git a/core/cat/agents/form_agent.py b/core/cat/agents/form_agent.py
index b0a22c46..93dc44b4 100644
--- a/core/cat/agents/form_agent.py
+++ b/core/cat/agents/form_agent.py
@@ -1,6 +1,5 @@
 import traceback
 from cat.experimental.form import CatForm, CatFormState
-from cat.looking_glass.callbacks import NewTokenHandler, ModelInteractionHandler
 from cat.agents.base_agent import BaseAgent, AgentOutput
 from cat.log import log
 
diff --git a/core/cat/agents/memory_agent.py b/core/cat/agents/memory_agent.py
index 4e6be1b2..84d8a040 100644
--- a/core/cat/agents/memory_agent.py
+++ b/core/cat/agents/memory_agent.py
@@ -30,9 +30,9 @@ async def execute(self, stray, prompt_prefix, prompt_suffix) -> AgentOutput:
 
         chain = (
             prompt
-            | RunnableLambda(lambda x: self._log_prompt(x, "MAIN PROMPT"))
+            | RunnableLambda(lambda x: utils.langchain_log_prompt(x, "MAIN PROMPT"))
             | stray._llm
-            | RunnableLambda(lambda x: self._log_output(x, "MAIN PROMPT OUTPUT"))
+            | RunnableLambda(lambda x: utils.langchain_log_output(x, "MAIN PROMPT OUTPUT"))
             | StrOutputParser()
         )
 
diff --git a/core/cat/agents/procedures_agent.py b/core/cat/agents/procedures_agent.py
index 588899d1..2729586b 100644
--- a/core/cat/agents/procedures_agent.py
+++ b/core/cat/agents/procedures_agent.py
@@ -123,9 +123,9 @@ async def execute_chain(self, stray, procedures_prompt_template, allowed_procedu
 
         chain = (
             prompt
-            | RunnableLambda(lambda x: self._log_prompt(x, "TOOL PROMPT"))
+            | RunnableLambda(lambda x: utils.langchain_log_prompt(x, "TOOL PROMPT"))
             | stray._llm
-            | RunnableLambda(lambda x: self._log_output(x, "TOOL PROMPT OUTPUT"))
+            | RunnableLambda(lambda x: utils.langchain_log_output(x, "TOOL PROMPT OUTPUT"))
             | ChooseProcedureOutputParser() # ensures output is a LLMAction
         )
 
diff --git a/core/cat/factory/llm.py b/core/cat/factory/llm.py
index 7041c40c..d2d85f08 100644
--- a/core/cat/factory/llm.py
+++ b/core/cat/factory/llm.py
@@ -1,11 +1,10 @@
 from langchain_openai import AzureChatOpenAI
 from langchain_openai import AzureOpenAI
 from langchain_community.llms import (
-    OpenAI,
    HuggingFaceTextGenInference,
    HuggingFaceEndpoint,
 )
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, OpenAI
 from langchain_cohere import ChatCohere
 from langchain_google_genai import ChatGoogleGenerativeAI
 
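The agent changes above all apply the same pattern: the per-class `_log_prompt`/`_log_output` debug methods are replaced by shared `utils` helpers that are spliced into the LCEL chain through `RunnableLambda` — each one prints its input and returns it unchanged, so the taps can be added or removed without affecting chain behavior. A minimal self-contained sketch of that pattern (not project code; `FakeListChatModel` and `log_passthrough` are illustrative stand-ins for `stray._llm` and the `utils` helpers):

```python
from langchain_community.chat_models import FakeListChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda

def log_passthrough(value, title):
    # print the intermediate chain value, then hand it on unchanged (identity)
    print(f"==== {title} ====")
    print(value)
    return value

prompt = ChatPromptTemplate.from_messages(
    [("system", "You are a helpful cat."), ("human", "{input}")]
)
llm = FakeListChatModel(responses=["Meow."])  # deterministic stand-in model

chain = (
    prompt
    | RunnableLambda(lambda x: log_passthrough(x, "MAIN PROMPT"))
    | llm
    | RunnableLambda(lambda x: log_passthrough(x, "MAIN PROMPT OUTPUT"))
    | StrOutputParser()
)

print(chain.invoke({"input": "Who are you?"}))  # -> "Meow."
```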
diff --git a/core/cat/looking_glass/stray_cat.py b/core/cat/looking_glass/stray_cat.py
index 0a21a251..dc530a4c 100644
--- a/core/cat/looking_glass/stray_cat.py
+++ b/core/cat/looking_glass/stray_cat.py
@@ -5,9 +5,7 @@
 from typing import Literal, get_args, List, Dict, Union, Any
 
 from langchain.docstore.document import Document
-from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_community.llms import BaseLLM
-from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, BaseMessage
 from langchain_core.runnables import RunnableConfig, RunnableLambda
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers.string import StrOutputParser
@@ -20,8 +18,7 @@
 from cat.memory.working_memory import WorkingMemory
 from cat.convo.messages import CatMessage, UserMessage, MessageWhy, Role, EmbedderModelInteraction
 from cat.agents.base_agent import AgentOutput
-
-from cat.utils import levenshtein_distance
+from cat import utils
 
 MSG_TYPES = Literal["notification", "chat", "error", "chat_token"]
 
@@ -292,30 +289,34 @@ def llm(self, prompt: str, stream: bool = False) -> str:
             callbacks.append(NewTokenHandler(self))
 
         # Add a token counter to the callbacks
-        callbacks.append(ModelInteractionHandler(self, self.__class__.__name__))
+        caller = utils.get_caller_info()
+        callbacks.append(ModelInteractionHandler(self, caller or "StrayCat"))
 
-        # TODO: add here optional convo history passed to the method, or taken from working memory
-        messages=[
-            HumanMessage(content=prompt)
-        ]
-        # Check if self._llm is a completion model and generate a response
-        # TODOV2: do not support non-chat models
-        #if isinstance(self._llm, BaseLLM):
-        #    log.critical("LLM")
-        #    return self._llm.invoke(
-        #        prompt,
-        #        config=RunnableConfig(callbacks=callbacks)
-        #    )
-
-        # Check if self._llm is a chat model and call it as a completion model
-        if True:#isinstance(self._llm, BaseChatModel):
-            log.critical("CHAT LLM")
-            return self._llm.invoke(
-                messages,
-                config=RunnableConfig(callbacks=callbacks)
-            ).content # returns AIMessage
+        # here we deal with motherfucking langchain
+        prompt = ChatPromptTemplate(
+            # TODO: add here optional convo history passed to the method, or taken from working memory
+            messages=[
+                SystemMessage(content=prompt)
+            ]
+        )
+
+        chain = (
+            prompt
+            | RunnableLambda(lambda x: utils.langchain_log_prompt(x, f"{caller} prompt"))
+            | self._llm
+            | RunnableLambda(lambda x: utils.langchain_log_output(x, f"{caller} prompt output"))
+            | StrOutputParser()
+        )
+
+        output = chain.invoke(
+            {}, # in case we need to pass info to the template
+            config=RunnableConfig(callbacks=callbacks)
+        )
+
+        return output
+
     async def __call__(self, message_dict):
         """Call the Cat instance.
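The rewritten `StrayCat.llm()` no longer branches on model type: the raw prompt string is wrapped in a `SystemMessage` inside a `ChatPromptTemplate` (which therefore has no input variables, hence the empty dict passed to `invoke`), and `StrOutputParser` replaces the manual `.content` access on the returned `AIMessage`. A standalone sketch of the same flow (not the project's code; `FakeListChatModel` stands in for `self._llm`, and the hard-coded `caller` label mimics what `utils.get_caller_info()` returns):

```python
from langchain_community.chat_models import FakeListChatModel
from langchain_core.messages import SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig, RunnableLambda

def llm(prompt: str) -> str:
    caller = "StrayCat.llm"  # in the patch this comes from utils.get_caller_info()
    chat_prompt = ChatPromptTemplate(messages=[SystemMessage(content=prompt)])
    model = FakeListChatModel(responses=["Purr."])  # stand-in for self._llm

    chain = (
        chat_prompt
        | RunnableLambda(lambda x: print(f"{caller} prompt:", x) or x)  # log, pass through
        | model
        | RunnableLambda(lambda x: print(f"{caller} output:", x) or x)
        | StrOutputParser()  # AIMessage -> str, replacing the old .content access
    )
    # empty dict: the template contains a prebuilt message and no input variables
    return chain.invoke({}, config=RunnableConfig(callbacks=[]))

print(llm("You are the Cheshire Cat."))  # -> "Purr."
```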
@@ -516,7 +517,7 @@ def classify(
 
         # find the closest match and its score with levenshtein distance
         best_label, score = min(
-            ((label, levenshtein_distance(response, label)) for label in labels_names),
+            ((label, utils.levenshtein_distance(response, label)) for label in labels_names),
             key=lambda x: x[1],
         )
 
diff --git a/core/cat/utils.py b/core/cat/utils.py
index 7edf35d5..ba46745a 100644
--- a/core/cat/utils.py
+++ b/core/cat/utils.py
@@ -2,6 +2,7 @@
 
 import os
 import traceback
+import inspect
 from datetime import timedelta
 from urllib.parse import urlparse
 from typing import Dict, Tuple
@@ -10,6 +11,7 @@
 from langchain.evaluation import StringDistance, load_evaluator, EvaluatorType
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.prompts import PromptTemplate
+from langchain_core.utils import get_colored_text
 
 from cat.log import log
 from cat.env import get_env
@@ -204,6 +206,43 @@ def match_prompt_variables(
 
     return prompt_variables, prompt_template
 
+
+def get_caller_info():
+    # go 2 steps up the stack
+    try:
+        calling_frame = inspect.currentframe()
+        grand_father_frame = calling_frame.f_back.f_back
+        grand_father_name = grand_father_frame.f_code.co_name
+
+        # check if the grand_father_frame is in a class method
+        if 'self' in grand_father_frame.f_locals:
+            return grand_father_frame.f_locals['self'].__class__.__name__ + "." + grand_father_name
+        return grand_father_name
+    except Exception as e:
+        log.error(e)
+        return None
+
+
+def langchain_log_prompt(langchain_prompt, title):
+    print("\n")
+    print(get_colored_text(f"==================== {title} ====================", "green"))
+    for m in langchain_prompt.messages:
+        print(get_colored_text(type(m).__name__, "green"))
+        print(m.content)
+    print(get_colored_text("========================================", "green"))
+    return langchain_prompt
+
+
+def langchain_log_output(langchain_output, title):
+    print("\n")
+    print(get_colored_text(f"==================== {title} ====================", "blue"))
+    if hasattr(langchain_output, 'content'):
+        print(langchain_output.content)
+    else:
+        print(langchain_output)
+    print(get_colored_text("========================================", "blue"))
+    return langchain_output
+
+
 # This is our masterwork during tea time
 class singleton:
     instances = {}
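To see why `get_caller_info()` walks two frames up: the first `f_back` reaches the function that called it (e.g. `StrayCat.llm`), and the second reaches that function's caller — the name worth reporting in logs. When the grandfather frame has `self` in its locals, the result is qualified with the class name. A runnable sketch reimplementing the patch's logic (the helper and class names here are illustrative only):

```python
import inspect

def get_caller_info():
    # same logic as the patch: two f_back hops, then qualify with the
    # enclosing class name when 'self' is in the frame's locals
    frame = inspect.currentframe().f_back.f_back
    name = frame.f_code.co_name
    if 'self' in frame.f_locals:
        return frame.f_locals['self'].__class__.__name__ + "." + name
    return name

def helper():
    # one hop: helper's caller becomes the "grandfather" seen above
    return get_caller_info()

class StrayCat:
    def llm(self):
        return helper()

print(StrayCat().llm())  # -> "StrayCat.llm"
print(helper())          # -> "<module>" when called at module top level
```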