Commit: improve logs and model_interaction sources

pieroit committed Jul 25, 2024
1 parent 1d0f057 commit 9ec53d2

Showing 7 changed files with 73 additions and 56 deletions.
23 changes: 1 addition & 22 deletions core/cat/agents/base_agent.py
@@ -15,25 +15,4 @@ class BaseAgent(ABC):
 
     @abstractmethod
     async def execute(*args, **kwargs) -> AgentOutput:
-        pass
-
-    # TODO: this is here to debug langchain, take it away
-    def _log_prompt(self, langchain_prompt, title):
-        print("\n")
-        print(get_colored_text(f"==================== {title} ====================", "green"))
-        for m in langchain_prompt.messages:
-            print(get_colored_text(type(m).__name__, "green"))
-            print(m.content)
-        print(get_colored_text("========================================", "green"))
-        return langchain_prompt
-
-    # TODO: this is here to debug langchain, take it away
-    def _log_output(self, langchain_output, title):
-        print("\n")
-        print(get_colored_text(f"==================== {title} ====================", "blue"))
-        if hasattr(langchain_output, 'content'):
-            print(langchain_output.content)
-        else:
-            print(langchain_output)
-        print(get_colored_text("========================================", "blue"))
-        return langchain_output
+        pass
1 change: 0 additions & 1 deletion core/cat/agents/form_agent.py
@@ -1,6 +1,5 @@
 import traceback
 from cat.experimental.form import CatForm, CatFormState
-from cat.looking_glass.callbacks import NewTokenHandler, ModelInteractionHandler
 from cat.agents.base_agent import BaseAgent, AgentOutput
 from cat.log import log
 
4 changes: 2 additions & 2 deletions core/cat/agents/memory_agent.py
@@ -30,9 +30,9 @@ async def execute(self, stray, prompt_prefix, prompt_suffix) -> AgentOutput:
 
         chain = (
             prompt
-            | RunnableLambda(lambda x: self._log_prompt(x, "MAIN PROMPT"))
+            | RunnableLambda(lambda x: utils.langchain_log_prompt(x, "MAIN PROMPT"))
             | stray._llm
-            | RunnableLambda(lambda x: self._log_output(x, "MAIN PROMPT OUTPUT"))
+            | RunnableLambda(lambda x: utils.langchain_log_output(x, "MAIN PROMPT OUTPUT"))
             | StrOutputParser()
         )
 
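Both helpers return their input unchanged, so the RunnableLambda taps can sit between chain stages without altering what flows through them. A minimal self-contained sketch of the same tap pattern, assuming a recent langchain-core (the fake model and the simplified logger below are stand-ins, not project code):

from langchain_core.language_models import FakeListChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda

def log_passthrough(value, title):
    # Print a labeled snapshot, then hand the value on unchanged,
    # so the chain behaves as if the tap were not there.
    print(f"==== {title} ====")
    print(value)
    return value

prompt = ChatPromptTemplate.from_messages([("system", "You are a helpful cat.")])
llm = FakeListChatModel(responses=["Meow."])  # stand-in for stray._llm

chain = (
    prompt
    | RunnableLambda(lambda x: log_passthrough(x, "MAIN PROMPT"))
    | llm
    | RunnableLambda(lambda x: log_passthrough(x, "MAIN PROMPT OUTPUT"))
    | StrOutputParser()
)

print(chain.invoke({}))  # logs the prompt and the AIMessage, then prints "Meow."

procedures_agent.py applies the identical tap pattern to its tool-selection chain below.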
4 changes: 2 additions & 2 deletions core/cat/agents/procedures_agent.py
@@ -123,9 +123,9 @@ async def execute_chain(self, stray, procedures_prompt_template, allowed_procedu
 
         chain = (
             prompt
-            | RunnableLambda(lambda x: self._log_prompt(x, "TOOL PROMPT"))
+            | RunnableLambda(lambda x: utils.langchain_log_prompt(x, "TOOL PROMPT"))
             | stray._llm
-            | RunnableLambda(lambda x: self._log_output(x, "TOOL PROMPT OUTPUT"))
+            | RunnableLambda(lambda x: utils.langchain_log_output(x, "TOOL PROMPT OUTPUT"))
             | ChooseProcedureOutputParser() # ensures output is a LLMAction
         )
 
3 changes: 1 addition & 2 deletions core/cat/factory/llm.py
@@ -1,11 +1,10 @@
 from langchain_openai import AzureChatOpenAI
 from langchain_openai import AzureOpenAI
 from langchain_community.llms import (
-    OpenAI,
     HuggingFaceTextGenInference,
     HuggingFaceEndpoint,
 )
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, OpenAI
 from langchain_cohere import ChatCohere
 from langchain_google_genai import ChatGoogleGenerativeAI
 
55 changes: 28 additions & 27 deletions core/cat/looking_glass/stray_cat.py
@@ -5,9 +5,7 @@
 from typing import Literal, get_args, List, Dict, Union, Any
 
 from langchain.docstore.document import Document
-from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_community.llms import BaseLLM
-from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, BaseMessage
 from langchain_core.runnables import RunnableConfig, RunnableLambda
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers.string import StrOutputParser
@@ -20,8 +18,7 @@
 from cat.memory.working_memory import WorkingMemory
 from cat.convo.messages import CatMessage, UserMessage, MessageWhy, Role, EmbedderModelInteraction
 from cat.agents.base_agent import AgentOutput
-
-from cat.utils import levenshtein_distance
+from cat import utils
 
 MSG_TYPES = Literal["notification", "chat", "error", "chat_token"]
 
@@ -292,30 +289,34 @@ def llm(self, prompt: str, stream: bool = False) -> str:
             callbacks.append(NewTokenHandler(self))
 
         # Add a token counter to the callbacks
-        callbacks.append(ModelInteractionHandler(self, self.__class__.__name__))
+        caller = utils.get_caller_info()
+        callbacks.append(ModelInteractionHandler(self, caller or "StrayCat"))
 
-        # TODO: add here optional convo history passed to the method, or taken from working memory
-        messages=[
-            HumanMessage(content=prompt)
-        ]
-
-        # Check if self._llm is a completion model and generate a response
-        # TODOV2: do not support non-chat models
-        #if isinstance(self._llm, BaseLLM):
-        #    log.critical("LLM")
-        #    return self._llm.invoke(
-        #        prompt,
-        #        config=RunnableConfig(callbacks=callbacks)
-        #    )
-
-        # Check if self._llm is a chat model and call it as a completion model
-        if True:#isinstance(self._llm, BaseChatModel):
-            log.critical("CHAT LLM")
-            return self._llm.invoke(
-                messages,
-                config=RunnableConfig(callbacks=callbacks)
-            ).content # returns AIMessage
+        # here we deal with motherfucking langchain
+        prompt = ChatPromptTemplate(
+            # TODO: add here optional convo history passed to the method, or taken from working memory
+            messages=[
+                SystemMessage(content=prompt)
+            ]
+        )
+
+        chain = (
+            prompt
+            | RunnableLambda(lambda x: utils.langchain_log_prompt(x, f"{caller} prompt"))
+            | self._llm
+            | RunnableLambda(lambda x: utils.langchain_log_output(x, f"{caller} prompt output"))
+            | StrOutputParser()
+        )
+
+        output = chain.invoke(
+            {}, # in case we need to pass info to the template
+            config=RunnableConfig(callbacks=callbacks)
+        )
+
+        return output
 
 
     async def __call__(self, message_dict):
         """Call the Cat instance.
@@ -516,7 +517,7 @@ def classify(
 
         # find the closest match and its score with levenshtein distance
         best_label, score = min(
-            ((label, levenshtein_distance(response, label)) for label in labels_names),
+            ((label, utils.levenshtein_distance(response, label)) for label in labels_names),
             key=lambda x: x[1],
         )
 
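With this change, StrayCat.llm() funnels one-off completions through the same LCEL pipeline style the agents use: wrap the raw string in a chat prompt, pipe it through the model and a string parser, and attach callbacks at invoke time. A rough standalone approximation of the new control flow, with a fake chat model standing in for the configured self._llm and the logging taps omitted for brevity:

from langchain_core.language_models import FakeListChatModel
from langchain_core.messages import SystemMessage
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig

def llm(prompt: str) -> str:
    # Wrap the raw string in a one-message chat prompt, as the new llm() does.
    chat_prompt = ChatPromptTemplate(messages=[SystemMessage(content=prompt)])

    # FakeListChatModel is a stand-in for the configured LLM.
    chain = chat_prompt | FakeListChatModel(responses=["42"]) | StrOutputParser()

    # In the real method, NewTokenHandler and ModelInteractionHandler
    # would be passed in the callbacks list here.
    return chain.invoke({}, config=RunnableConfig(callbacks=[]))

print(llm("What is the answer to everything?"))  # -> 42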
39 changes: 39 additions & 0 deletions core/cat/utils.py
@@ -2,6 +2,7 @@
 
 import os
 import traceback
+import inspect
 from datetime import timedelta
 from urllib.parse import urlparse
 from typing import Dict, Tuple
@@ -10,6 +11,7 @@
 from langchain.evaluation import StringDistance, load_evaluator, EvaluatorType
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.prompts import PromptTemplate
+from langchain_core.utils import get_colored_text
 
 from cat.log import log
 from cat.env import get_env
@@ -204,6 +206,43 @@ def match_prompt_variables(
     return prompt_variables, prompt_template
 
 
+def get_caller_info():
+    # go 2 steps up the stack
+    try:
+        calling_frame = inspect.currentframe()
+        grand_father_frame = calling_frame.f_back.f_back
+        grand_father_name = grand_father_frame.f_code.co_name
+
+        # check if the grand_father_frame is in a class method
+        if 'self' in grand_father_frame.f_locals:
+            return grand_father_frame.f_locals['self'].__class__.__name__ + "." + grand_father_name
+        return grand_father_name
+    except Exception as e:
+        log.error(e)
+        return None
+
+
+def langchain_log_prompt(langchain_prompt, title):
+    print("\n")
+    print(get_colored_text(f"==================== {title} ====================", "green"))
+    for m in langchain_prompt.messages:
+        print(get_colored_text(type(m).__name__, "green"))
+        print(m.content)
+    print(get_colored_text("========================================", "green"))
+    return langchain_prompt
+
+
+def langchain_log_output(langchain_output, title):
+    print("\n")
+    print(get_colored_text(f"==================== {title} ====================", "blue"))
+    if hasattr(langchain_output, 'content'):
+        print(langchain_output.content)
+    else:
+        print(langchain_output)
+    print(get_colored_text("========================================", "blue"))
+    return langchain_output
+
+
 # This is our masterwork during tea time
 class singleton:
     instances = {}
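get_caller_info walks two frames up the stack so that, when StrayCat.llm() calls it, the name reported to ModelInteractionHandler is whoever invoked llm(), qualified with its class when that frame belongs to a method. A small illustration of the frame walking; helper and MemoryAgent are hypothetical stand-ins, not project code:

import inspect

def get_caller_info():
    # Two frames up: skip this function and its direct caller.
    frame = inspect.currentframe().f_back.f_back
    name = frame.f_code.co_name
    if 'self' in frame.f_locals:
        # Qualify method names with their class, e.g. "MemoryAgent.execute".
        return frame.f_locals['self'].__class__.__name__ + "." + name
    return name

def helper():
    # Pretend this is StrayCat.llm(): it reports who called *it*.
    return get_caller_info()

class MemoryAgent:
    def execute(self):
        return helper()

print(MemoryAgent().execute())  # -> "MemoryAgent.execute"
print(helper())                 # -> "<module>" when called at top level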
