diff --git a/README.md b/README.md
index dd7e4b86..88fd052c 100644
--- a/README.md
+++ b/README.md
@@ -49,6 +49,7 @@ Follow instructions on how to run it with [docker compose and volumes](https://c
 ```python
 from cat.mad_hatter.decorators import tool, hook
 
+# hooks are an event system to get fine-grained control over your assistant
 @hook
 def agent_prompt_prefix(prefix, cat):
     prefix = """You are Marvin the socks seller, a poetic vendor of socks.
@@ -56,7 +57,7 @@ You are an expert in socks, and you reply with exactly one rhyme.
 """
     return prefix
 
-
+# LangChain-inspired tools (function calling)
 @tool(return_direct=True)
 def socks_prices(color, cat):
     """How much do socks cost? Input is the sock color."""
@@ -65,10 +66,45 @@ def socks_prices(color, cat):
         "white": 10,
         "pink": 50,
     }
-    if color not in prices.keys():
-        return f"No {color} socks"
-    else:
-        return f"{prices[color]} €"
+
+    price = prices.get(color, 0)
+    return f"{price} bucks, meeeow!"
+```
+
+## Conversational form example
+
+```python
+from pydantic import BaseModel
+from cat.experimental.form import form, CatForm
+
+# data structure to fill up
+class PizzaOrder(BaseModel):
+    pizza_type: str
+    phone: int
+
+# forms let you control goal-oriented conversations
+@form
+class PizzaForm(CatForm):
+    description = "Pizza Order"
+    model_class = PizzaOrder
+    start_examples = [
+        "order a pizza!",
+        "I want pizza"
+    ]
+    stop_examples = [
+        "stop pizza order",
+        "not hungry anymore",
+    ]
+    ask_confirm = True
+
+    def submit(self, form_data):
+
+        # do the actual order here!
+
+        # return to the conversation
+        return {
+            "output": f"Pizza order on its way: {form_data}"
+        }
 ```
 
 ## Docs and Resources
@@ -82,10 +118,11 @@ def socks_prices(color, cat):
 ## Why use the Cat
 
 - ⚡️ API first, so you get a microservice to easily add a conversational layer to your app
-- 🚀 Extensible via plugins (AI can connect to your APIs or execute custom python code)
-- 🛠 Easy to use admin panel
-- 🌍 Supports any language model (works with OpenAI, Google, Ollama, HuggingFace, custom services)
 - 🐘 Remembers conversations and documents and uses them in conversation
+- 🚀 Extensible via plugins (public plugin registry + private plugins allowed)
+- 🎚 Event callbacks, function calling (tools), conversational forms
+- 🛠 Easy to use admin panel (chat, visualize memory and plugins, adjust settings)
+- 🌍 Supports any language model (works with OpenAI, Google, Ollama, HuggingFace, custom services)
 - 🐋 Production ready - 100% [dockerized](https://docs.docker.com/get-docker/)
 - 👩‍👧‍👦 Active [Discord community](https://discord.gg/bHX5sNFCYU) and easy to understand [docs](https://cheshire-cat-ai.github.io/docs/)
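Worth noting while reading the README examples: the `@tool` decorator in this release also accepts an `examples` list (used later in this diff for the core `get_the_time` tool). A minimal, hypothetical sketch — `socks_warmth` is not part of this patch:

```python
from cat.mad_hatter.decorators import tool

# `examples` become "start_example" triggers embedded in procedural memory,
# so the agent can recall this tool from similar user queries
@tool(return_direct=True, examples=["how warm are wool socks?"])
def socks_warmth(material, cat):
    """How warm are socks? Input is the sock material."""
    warmth = {"wool": "very warm", "cotton": "mildly warm"}
    return warmth.get(material, "unknown, meeeow!")
```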
diff --git a/core/cat/experimental/__init__.py b/core/cat/experimental/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/core/cat/experimental/form/__init__.py b/core/cat/experimental/form/__init__.py
new file mode 100644
index 00000000..f43f2c55
--- /dev/null
+++ b/core/cat/experimental/form/__init__.py
@@ -0,0 +1,2 @@
+from .cat_form import CatForm, CatFormState
+from .form_decorator import form
\ No newline at end of file
diff --git a/core/cat/experimental/form/cat_form.py b/core/cat/experimental/form/cat_form.py
new file mode 100644
index 00000000..6c19b56f
--- /dev/null
+++ b/core/cat/experimental/form/cat_form.py
@@ -0,0 +1,318 @@
+from typing import List, Dict
+from dataclasses import dataclass
+from pydantic import BaseModel, ConfigDict, ValidationError
+
+from langchain.chains import LLMChain
+from langchain_core.prompts.prompt import PromptTemplate
+
+#from cat.looking_glass.prompts import MAIN_PROMPT_PREFIX
+from enum import Enum
+from cat.log import log
+import json
+
+
+# Conversational Form State
+class CatFormState(Enum):
+    INCOMPLETE = "incomplete"
+    COMPLETE = "complete"
+    WAIT_CONFIRM = "wait_confirm"
+    CLOSED = "closed"
+
+
+class CatForm: # base model of forms
+
+    model_class: BaseModel
+    procedure_type: str = "form"
+    name: str = None
+    description: str
+    start_examples: List[str]
+    stop_examples: List[str] = []
+    ask_confirm: bool = False
+    triggers_map = None
+    _autopilot = False
+
+    def __init__(self, cat) -> None:
+        self._state = CatFormState.INCOMPLETE
+        self._model: Dict = {}
+
+        self._cat = cat
+
+        self._errors: List[str] = []
+        self._missing_fields: List[str] = []
+
+    @property
+    def cat(self):
+        return self._cat
+
+    def submit(self, form_data) -> str:
+        raise NotImplementedError
+
+    # Check whether the user confirms the form data
+    def confirm(self) -> bool:
+
+        # Get user message
+        user_message = self.cat.working_memory["user_message_json"]["text"]
+
+        # Confirm prompt
+        confirm_prompt = \
+f"""Your task is to produce a JSON representing whether a user is confirming or not.
+JSON must be in this format:
+```json
+{{
+    "confirm": // type boolean, must be `true` or `false`
+}}
+```
+
+User said "{user_message}"
+
+JSON:
+```json
+{{
+    "confirm": """
+
+        # Query the LLM and check whether the user agrees or not
+        response = self.cat.llm(confirm_prompt, stream=True)
+        return "true" in response.lower()
+
+    # Check if the user wants to exit the form
+    # it is run at the beginning of every form.next()
+    def check_exit_intent(self) -> bool:
+
+        # Get conversation history
+        history = self.stringify_convo_history()
+
+        # Stop examples
+        stop_examples = """
+Examples where {"exit": true}:
+- exit form
+- stop it"""
+
+        for se in self.stop_examples:
+            stop_examples += f"\n- {se}"
+
+        # Check exit prompt
+        check_exit_prompt = \
+f"""Your task is to produce a JSON representing whether a user wants to exit or not.
+JSON must be in this format:
+```json
+{{
+    "exit": // type boolean, must be `true` or `false`
+}}
+```
+
+{stop_examples}
+
+This is the conversation:
+
+{history}
+
+JSON:
+```json
+{{
+    "exit": """
+
+        # Query the LLM and check whether the user wants to exit or not
+        response = self.cat.llm(check_exit_prompt, stream=True)
+        return "true" in response.lower()
+
+    # Execute the dialogue step
+    def next(self):
+
+        # could we enrich prompt completion with episodic/declarative memories?
+        #self.cat.working_memory["episodic_memories"] = []
+
+        if self.check_exit_intent():
+            self._state = CatFormState.CLOSED
+
+        # If state is WAIT_CONFIRM, check user confirmation response
+        if self._state == CatFormState.WAIT_CONFIRM:
+            if self.confirm():
+                self._state = CatFormState.CLOSED
+                return self.submit(self._model)
+            else:
+                self._state = CatFormState.INCOMPLETE
+
+        # If the state is INCOMPLETE, execute model update
+        # (and change state based on validation result)
+        if self._state == CatFormState.INCOMPLETE:
+            self._model = self.update()
+
+        # If state is COMPLETE, ask confirm (or execute action directly)
+        if self._state == CatFormState.COMPLETE:
+            if self.ask_confirm:
+                self._state = CatFormState.WAIT_CONFIRM
+            else:
+                self._state = CatFormState.CLOSED
+                return self.submit(self._model)
+
+        # if state is still INCOMPLETE, recap and ask for new info
+        return self.message()
+
+    # Updates the form with the information extracted from the user's response
+    # (returns the updated model)
+    def update(self):
+
+        # Conversation to JSON
+        json_details = self.extract()
+        json_details = self.sanitize(json_details)
+
+        # model merge old and new
+        new_model = self._model | json_details
+
+        # Validate new model
+        new_model = self.validate(new_model)
+
+        return new_model
+
+    def message(self):
+
+        if self._state == CatFormState.CLOSED:
+            return {
+                "output": f"Form {type(self).__name__} closed"
+            }
+
+        separator = "\n - "
+        missing_fields = ""
+        if self._missing_fields:
+            missing_fields = "\nMissing fields:"
+            missing_fields += separator + separator.join(self._missing_fields)
+        invalid_fields = ""
+        if self._errors:
+            invalid_fields = "\nInvalid fields:"
+            invalid_fields += separator + separator.join(self._errors)
+
+        out = f"""Info until now:
+
+```json
+{json.dumps(self._model, indent=4)}
+```
+{missing_fields}
+{invalid_fields}
+"""
+
+        if self._state == CatFormState.WAIT_CONFIRM:
+            out += "\n --> Confirm? Yes or no?"
+
+        return {
+            "output": out
+        }
+
+    def stringify_convo_history(self):
+
+        user_message = self.cat.working_memory["user_message_json"]["text"]
+        chat_history = self.cat.working_memory["history"][-10:] # last n messages
+
+        # stringify history
+        history = ""
+        for turn in chat_history:
+            history += f"\n - {turn['who']}: {turn['message']}"
+        history += f"\n - Human: {user_message}"
+
+        return history
+
+    # Extract model information from the user message
+    def extract(self):
+
+        prompt = self.extraction_prompt()
+        log.debug(prompt)
+
+        # Invoke LLM chain
+        extraction_chain = LLMChain(
+            prompt = PromptTemplate.from_template(prompt),
+            llm = self._cat._llm,
+            verbose = True,
+            output_key = "output"
+        )
+        json_str = extraction_chain.invoke({"stop": ["```"]})["output"]
+
+        log.debug(f"Form JSON after parser:\n{json_str}")
+
+        # json parser
+        try:
+            output_model = json.loads(json_str)
+        except Exception as e:
+            output_model = {}
+            log.warning(e)
+
+        return output_model
+
+    def extraction_prompt(self):
+
+        history = self.stringify_convo_history()
+
+        # JSON structure
+        # BaseModel.__fields__['my_field'].type_
+        JSON_structure = "{"
+        for field_name, field in self.model_class.model_fields.items():
+            if field.description:
+                description = field.description
+            else:
+                description = ""
+            JSON_structure += f'\n\t"{field_name}": // {description} Must be of type `{field.annotation.__name__}` or `null`' # field.required?
+        JSON_structure += "\n}"
+
+        # TODO: reintroduce examples
+        prompt = \
+f"""Your task is to fill up a JSON out of a conversation.
+The JSON must have this format:
+```json
+{JSON_structure}
+```
+
+This is the current JSON:
+```json
+{json.dumps(self._model, indent=4)}
+```
+
+This is the conversation:
+
+{history}
+
+Updated JSON:
+```json
+"""
+
+        # TODO: convo example (optional but supported)
+
+        prompt_escaped = prompt.replace("{", "{{").replace("}", "}}")
+        return prompt_escaped
+
+    # Sanitize model (take away unwanted keys and null values)
+    # NOTE: unwanted keys are automatically taken away by pydantic
+    def sanitize(self, model):
+
+        # preserve only non-null fields
+        null_fields = [None, '', 'None', 'null', 'lower-case', 'unknown', 'missing']
+        model = {key: value for key, value in model.items() if value not in null_fields}
+
+        return model
+
+    # Validate model
+    def validate(self, model):
+
+        self._missing_fields = []
+        self._errors = []
+
+        try:
+            # INFO TODO: in this case optional fields are always ignored
+
+            # Attempt to create the model object to update the default values and validate it
+            model = self.model_class(**model).model_dump(mode="json")
+
+            # If model is valid change state to COMPLETE
+            self._state = CatFormState.COMPLETE
+
+        except ValidationError as e:
+            # Collect missing fields and error messages
+            for error_message in e.errors():
+                field_name = error_message['loc'][0]
+                if error_message['type'] == 'missing':
+                    self._missing_fields.append(field_name)
+                else:
+                    self._errors.append(f'{field_name}: {error_message["msg"]}')
+                    del model[field_name]
+
+            # Set state to INCOMPLETE
+            self._state = CatFormState.INCOMPLETE
+
+        return model
\ No newline at end of file
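Since `validate()` above is what drives the form's state, here is a standalone sketch of just that step, reusing the README's `PizzaOrder` model with hypothetical partial data:

```python
from pydantic import BaseModel, ValidationError

class PizzaOrder(BaseModel):
    pizza_type: str
    phone: int

# hypothetical partial extraction: phone present, pizza_type still missing
partial = {"phone": 123456789}

missing_fields, errors = [], []
try:
    PizzaOrder(**partial).model_dump(mode="json")
except ValidationError as e:
    for err in e.errors():
        field_name = err["loc"][0]
        if err["type"] == "missing":
            missing_fields.append(field_name)
        else:
            errors.append(f"{field_name}: {err['msg']}")

print(missing_fields)  # ['pizza_type'] -> the form stays INCOMPLETE and asks again
```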
diff --git a/core/cat/experimental/form/form_decorator.py b/core/cat/experimental/form/form_decorator.py
new file mode 100644
index 00000000..430568a7
--- /dev/null
+++ b/core/cat/experimental/form/form_decorator.py
@@ -0,0 +1,17 @@
+from .cat_form import CatForm
+
+# form decorator
+def form(Form: CatForm) -> CatForm:
+    Form._autopilot = True
+    if Form.name is None:
+        Form.name = Form.__name__
+
+    if Form.triggers_map is None:
+        Form.triggers_map = {
+            "start_example": Form.start_examples,
+            "description": [
+                f"{Form.name}: {Form.description}"
+            ]
+        }
+
+    return Form
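For reference, this is roughly what the decorator computes for the README's `PizzaForm` (values derived from the class attributes above):

```python
from cat.experimental.form import CatForm, form

@form
class PizzaForm(CatForm):
    description = "Pizza Order"
    model_class = None  # a real form would point at a pydantic model
    start_examples = ["order a pizza!", "I want pizza"]

print(PizzaForm.name)          # "PizzaForm"
print(PizzaForm.triggers_map)
# {'start_example': ['order a pizza!', 'I want pizza'],
#  'description': ['PizzaForm: Pizza Order']}
```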
""" def __init__(self): + # Get all printable characters numbers excluded and make everything lowercase chars = [p.lower() for p in string.printable[10:]] # Make the vocabulary with all possible combinations of 2 characters - voc = {f"{k[0]}{k[1]}": v for v, k in enumerate(combinations(chars, 2))} - - # Re-index the tokens - for i, k in enumerate(voc.keys()): - voc[k] = i + voc = [] + for k in combinations(chars, 2): + voc.append(f"{k[0]}{k[1]}") + voc = sorted(set(voc)) # Naive embedder that counts occurrences of couple of characters in text - self.embedder = CountVectorizer(vocabulary=voc, analyzer="char_wb", ngram_range=(2, 2)) + self.embedder = CountVectorizer( + vocabulary=voc, + analyzer=lambda s: re.findall("..", s), + binary=True + ) def embed_documents(self, texts: List[str]) -> List[List[float]]: """Embed a list of text and returns the embedding vectors that are lists of floats.""" @@ -43,7 +48,8 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_query(self, text: str) -> List[float]: """Embed a string of text and returns the embedding vector as a list of floats.""" - return self.embedder.transform([text]).astype(float).todense().tolist()[0] + return self.embed_documents([text])[0] + class CustomOpenAIEmbeddings(Embeddings): diff --git a/core/cat/factory/embedder.py b/core/cat/factory/embedder.py index 716d053d..ad05020d 100644 --- a/core/cat/factory/embedder.py +++ b/core/cat/factory/embedder.py @@ -3,8 +3,12 @@ import langchain from pydantic import BaseModel, ConfigDict, Field -from langchain_community.embeddings import FakeEmbeddings, FastEmbedEmbeddings, CohereEmbeddings -from langchain_openai import OpenAIEmbeddings +from langchain_community.embeddings import ( + FakeEmbeddings, + FastEmbedEmbeddings, + CohereEmbeddings, +) +from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings from langchain_google_genai import GoogleGenerativeAIEmbeddings from fastembed.embedding import Embedding from cat.factory.custom_embedder import DumbEmbedder, CustomOpenAIEmbeddings @@ -18,9 +22,7 @@ class EmbedderSettings(BaseModel): # This is related to pydantic, because "model_*" attributes are protected. # We deactivate the protection because langchain relies on several "model_*" named attributes - model_config = ConfigDict( - protected_namespaces=() - ) + model_config = ConfigDict(protected_namespaces=()) # instantiate an Embedder from configuration @classmethod @@ -88,12 +90,12 @@ class EmbedderOpenAIConfig(EmbedderSettings): class EmbedderAzureOpenAIConfig(EmbedderSettings): openai_api_key: str model: str - openai_api_base: str + azure_endpoint: str openai_api_type: str openai_api_version: str deployment: str - _pyclass: Type = OpenAIEmbeddings + _pyclass: Type = AzureOpenAIEmbeddings model_config = ConfigDict( json_schema_extra={ @@ -119,8 +121,13 @@ class EmbedderCohereConfig(EmbedderSettings): # Enum for menu selection in the admin! -FastEmbedModels = Enum("FastEmbedModels", {item['model'].replace('/', '_').replace('-', '_'): item["model"] for item in - Embedding.list_supported_models()}) +FastEmbedModels = Enum( + "FastEmbedModels", + { + item["model"].replace("/", "_").replace("-", "_"): item["model"] + for item in Embedding.list_supported_models() + }, +) class EmbedderQdrantFastEmbedConfig(EmbedderSettings): @@ -144,6 +151,7 @@ class EmbedderGeminiChatConfig(EmbedderSettings): This class contains the configuration for the Gemini Embedder. 
""" + google_api_key: str model: str = "models/embedding-001" # Default model https://python.langchain.com/docs/integrations/text_embedding/google_generative_ai @@ -171,12 +179,14 @@ def get_allowed_embedder_models(): ] mad_hatter_instance = MadHatter() - list_embedder = mad_hatter_instance.execute_hook("factory_allowed_embedders", list_embedder_default, cat=None) + list_embedder = mad_hatter_instance.execute_hook( + "factory_allowed_embedders", list_embedder_default, cat=None + ) return list_embedder def get_embedder_from_name(name_embedder: str): - """ Find the llm adapter class by name""" + """Find the llm adapter class by name""" for cls in get_allowed_embedder_models(): if cls.__name__ == name_embedder: return cls diff --git a/core/cat/factory/llm.py b/core/cat/factory/llm.py index a931c289..62e627a8 100644 --- a/core/cat/factory/llm.py +++ b/core/cat/factory/llm.py @@ -1,5 +1,11 @@ from langchain_community.chat_models import AzureChatOpenAI -from langchain_community.llms import OpenAI, AzureOpenAI, Cohere, Ollama, HuggingFaceTextGenInference, HuggingFaceEndpoint +from langchain_community.llms import ( + OpenAI, + AzureOpenAI, + Cohere, + HuggingFaceTextGenInference, + HuggingFaceEndpoint, +) from langchain_openai import ChatOpenAI from langchain_google_genai import ChatGoogleGenerativeAI @@ -19,9 +25,7 @@ class LLMSettings(BaseModel): # This is related to pydantic, because "model_*" attributes are protected. # We deactivate the protection because langchain relies on several "model_*" named attributes - model_config = ConfigDict( - protected_namespaces=() - ) + model_config = ConfigDict(protected_namespaces=()) # instantiate an LLM from configuration @classmethod @@ -39,9 +43,8 @@ class LLMDefaultConfig(LLMSettings): model_config = ConfigDict( json_schema_extra={ "humanReadableName": "Default Language Model", - "description": - "A dumb LLM just telling that the Cat is not configured. " - "There will be a nice LLM here once consumer hardware allows it.", + "description": "A dumb LLM just telling that the Cat is not configured. " + "There will be a nice LLM here once consumer hardware allows it.", "link": "", } ) @@ -97,7 +100,7 @@ class LLMOpenAICompatibleConfig(LLMSettings): class LLMOpenAIChatConfig(LLMSettings): openai_api_key: str model_name: str = "gpt-3.5-turbo" - temperature: float = 0.7 # default value, from 0 to 1. Higher value create more creative and randomic answers, lower value create more focused and deterministc answers + temperature: float = 0.7 # default value, from 0 to 1. Higher value create more creative and randomic answers streaming: bool = True _pyclass: Type = ChatOpenAI @@ -113,7 +116,7 @@ class LLMOpenAIChatConfig(LLMSettings): class LLMOpenAIConfig(LLMSettings): openai_api_key: str model_name: str = "gpt-3.5-turbo-instruct" # used instead of text-davinci-003 since it deprecated - temperature: float = 0.7 # default value, from 0 to 1. Higher value create more creative and randomic answers, lower value create more focused and deterministc answers + temperature: float = 0.7 # default value, from 0 to 1. Higher value create more creative and randomic answers streaming: bool = True _pyclass: Type = OpenAI @@ -130,12 +133,12 @@ class LLMOpenAIConfig(LLMSettings): class LLMAzureChatOpenAIConfig(LLMSettings): openai_api_key: str model_name: str = "gpt-35-turbo" # or gpt-4, use only chat models ! 
-    openai_api_base: str
+    azure_endpoint: str
     openai_api_type: str = "azure"
     # Dont mix api versions https://github.com/hwchase17/langchain/issues/4775
     openai_api_version: str = "2023-05-15"
-    deployment_name: str
+    azure_deployment: str
     streaming: bool = True
     _pyclass: Type = AzureChatOpenAI
 
@@ -151,13 +154,13 @@ class LLMAzureChatOpenAIConfig(LLMSettings):
 # https://python.langchain.com/en/latest/modules/models/llms/integrations/azure_openai_example.html
 class LLMAzureOpenAIConfig(LLMSettings):
     openai_api_key: str
-    openai_api_base: str
+    azure_endpoint: str
     api_type: str = "azure"
     # https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#completions
     # Current supported versions 2022-12-01, 2023-03-15-preview, 2023-05-15
     # Don't mix api versions: https://github.com/hwchase17/langchain/issues/4775
     api_version: str = "2023-05-15"
-    deployment_name: str = "gpt-35-turbo-instruct" # Model "comming soon" according to microsoft
+    azure_deployment: str = "gpt-35-turbo-instruct"
     model_name: str = "gpt-35-turbo-instruct" # Use only completion models !
     streaming: bool = True
     _pyclass: Type = AzureOpenAI
@@ -205,11 +208,16 @@ class LLMHuggingFaceTextGenInferenceConfig(LLMSettings):
         }
     )
 
-
+# https://api.python.langchain.com/en/latest/llms/langchain_community.llms.huggingface_endpoint.HuggingFaceEndpoint.html
 class LLMHuggingFaceEndpointConfig(LLMSettings):
     endpoint_url: str
     huggingfacehub_api_token: str
-    task: str = "text2text-generation"
+    task: str = "text-generation"
+    max_new_tokens: int = 512
+    top_k: int = None
+    top_p: float = 0.95
+    temperature: float = 0.8
+    return_full_text: bool = False
     _pyclass: Type = HuggingFaceEndpoint
 
     model_config = ConfigDict(
@@ -220,10 +228,12 @@ class LLMHuggingFaceEndpointConfig(LLMSettings):
         }
     )
 
-# monkey patch to fix stops sequences
-ollama_fix: Type = CustomOllama
-ollama_fix._create_stream = _create_stream_patch
-ollama_fix._acreate_stream = _acreate_stream_patch
+
+# monkey patch to fix stop sequences
+OllamaFix: Type = CustomOllama
+OllamaFix._create_stream = _create_stream_patch
+OllamaFix._acreate_stream = _acreate_stream_patch
+
 
 class LLMOllamaConfig(LLMSettings):
     base_url: str
@@ -233,7 +243,7 @@ class LLMOllamaConfig(LLMSettings):
     repeat_penalty: float = 1.1
     temperature: float = 0.8
 
-    _pyclass: Type = ollama_fix
+    _pyclass: Type = OllamaFix
 
     model_config = ConfigDict(
         json_schema_extra={
@@ -251,18 +261,19 @@ class LLMGeminiChatConfig(LLMSettings):
 
     * `google_api_key`: The Google API key used to access the Google Natural Language Processing (NLP) API.
     * `model`: The name of the LLM model to use. In this case, it is set to "gemini".
-    * `temperature`: The temperature of the model, which controls the creativity and variety of the generated responses.
-    * `top_p`: The top-p truncation value, which controls the probability of the generated words.
-    * `top_k`: The top-k truncation value, which controls the number of candidate words to consider during generation.
+    * `temperature`: The temperature of the model, which controls the creativity and variety of the generated responses.
+    * `top_p`: The top-p truncation value, which controls the probability of the generated words.
+    * `top_k`: The top-k truncation value, which controls the number of candidate words to consider during generation.
     * `max_output_tokens`: The maximum number of tokens to generate in a single response.
 
     The `LLMGeminiChatConfig` class is used to create an instance of the Gemini LLM model, which can be used to generate text in natural language.
""" - google_api_key: str + + google_api_key: str model: str = "gemini-pro" - temperature: float = 0.1 + temperature: float = 0.1 top_p: int = 1 - top_k: int = 1 + top_k: int = 1 max_output_tokens: int = 29000 _pyclass: Type = ChatGoogleGenerativeAI @@ -276,9 +287,8 @@ class LLMGeminiChatConfig(LLMSettings): ) - def get_allowed_language_models(): - + list_llms_default = [ LLMOpenAIChatConfig, LLMOpenAIConfig, @@ -293,14 +303,16 @@ def get_allowed_language_models(): LLMCustomConfig, LLMDefaultConfig, ] - + mad_hatter_instance = MadHatter() - list_llms = mad_hatter_instance.execute_hook("factory_allowed_llms", list_llms_default, cat=None) + list_llms = mad_hatter_instance.execute_hook( + "factory_allowed_llms", list_llms_default, cat=None + ) return list_llms def get_llm_from_name(name_llm: str): - """ Find the llm adapter class by name""" + """Find the llm adapter class by name""" for cls in get_allowed_language_models(): if cls.__name__ == name_llm: return cls diff --git a/core/cat/headers.py b/core/cat/headers.py index e38ff74a..2151643a 100644 --- a/core/cat/headers.py +++ b/core/cat/headers.py @@ -63,5 +63,5 @@ def session(request: Request) -> str: event_loop = request.app.state.event_loop if user_id not in strays.keys(): - strays[user_id] = StrayCat(user_id=user_id, event_loop=event_loop) + strays[user_id] = StrayCat(user_id=user_id, main_loop=event_loop) return strays[user_id] \ No newline at end of file diff --git a/core/cat/log.py b/core/cat/log.py index 6deac930..42f0aaf4 100644 --- a/core/cat/log.py +++ b/core/cat/log.py @@ -202,9 +202,11 @@ def log(self, msg, level="DEBUG"): ) # prettify - # TODO: newlines lose coloring :( - if type(msg) in [dict, list, str]: - msg = json.dumps(msg, indent=4) + if type(msg) in [dict, list, str]: # TODO: should be recursive + try: + msg = json.dumps(msg, indent=4) + except: + pass else: msg = pformat(msg) diff --git a/core/cat/looking_glass/agent_manager.py b/core/cat/looking_glass/agent_manager.py index 32559e5c..ee7e134a 100644 --- a/core/cat/looking_glass/agent_manager.py +++ b/core/cat/looking_glass/agent_manager.py @@ -6,18 +6,23 @@ from copy import deepcopy +from langchain_core.runnables import RunnableConfig from langchain.docstore.document import Document from langchain.prompts import PromptTemplate from langchain.chains import LLMChain from langchain.agents import AgentExecutor, LLMSingleActionAgent +from cat.mad_hatter.plugin import Plugin from cat.mad_hatter.mad_hatter import MadHatter +from cat.mad_hatter.decorators.tool import CatTool from cat.looking_glass import prompts from cat.looking_glass.callbacks import NewTokenHandler -from cat.looking_glass.output_parser import ToolOutputParser +from cat.looking_glass.output_parser import ChooseProcedureOutputParser, AgentAction, AgentFinish from cat.utils import verbal_timedelta from cat.log import log +from cat.experimental.form import CatForm, CatFormState + class AgentManager: """Manager of Langchain Agent. 
@@ -39,21 +44,40 @@ def __init__(self):
         else:
             self.verbose = False
 
+    async def execute_procedures_agent(self, agent_input, stray):
 
-    def execute_tool_agent(self, agent_input, allowed_tools, stray):
-
-        # fix tools so they have an instance of the cat
-        allowed_tools_copy = deepcopy(allowed_tools)
-        for t in allowed_tools_copy:
-            # Prepare the tool to be used in the Cat (adding properties)
-            t.assign_cat(stray)
-
-        allowed_tools_names = [t.name for t in allowed_tools_copy]
-
-        # TODO: dynamic input_variables as in the main prompt
+        # gather recalled procedures
+        recalled_procedures_names = set()
+        for p in stray.working_memory["procedural_memories"]:
+            procedure = p[0]
+            if procedure.metadata["type"] in ["tool", "form"] and procedure.metadata["trigger_type"] in ["description", "start_example"]:
+                recalled_procedures_names.add(procedure.metadata["source"])
 
+        # Get tools and forms with those names from the MadHatter
+        allowed_procedures: Dict[str, CatTool | CatForm] = {}
+        allowed_tools: List[CatTool] = []
+        return_direct_tools: List[str] = []
+
+        for p in self.mad_hatter.procedures:
+
+            if p.name in recalled_procedures_names:
+                # Prepare the tool to be used in the Cat (adding properties)
+                if Plugin._is_cat_tool(p):
+                    tool = deepcopy(p)
+                    tool.assign_cat(stray)
+                    allowed_tools.append(tool)
+                    allowed_procedures[tool.name] = tool
+
+                    # cache if the tool is return_direct
+                    if p.return_direct:
+                        return_direct_tools.append(tool.name)
+                else:
+                    # form
+                    allowed_procedures[p.name] = p
+
         prompt = prompts.ToolPromptTemplate(
             template = self.mad_hatter.execute_hook("agent_prompt_instructions", prompts.TOOL_PROMPT, cat=stray),
-            tools=allowed_tools_copy,
+            procedures=allowed_procedures,
             # This omits the `agent_scratchpad`, `tools`, and `tool_names` variables because those are generated dynamically
             # This includes the `intermediate_steps` variable because it is needed to fill the scratchpad
             input_variables=["input", "intermediate_steps"]
         )
@@ -69,25 +93,63 @@ def execute_tool_agent(self, agent_input, allowed_tools, stray):
 
         # init agent
         agent = LLMSingleActionAgent(
             llm_chain=agent_chain,
-            output_parser=ToolOutputParser(),
+            output_parser=ChooseProcedureOutputParser(),
             stop=["\nObservation:"],
-            allowed_tools=allowed_tools_names,
             verbose=self.verbose
         )
 
         # agent executor
         agent_executor = AgentExecutor.from_agent_and_tools(
             agent=agent,
-            tools=allowed_tools_copy,
+            tools=allowed_tools,
             return_intermediate_steps=True,
             verbose=self.verbose
         )
 
-        out = agent_executor(agent_input)
+        # agent RUN
+        out = await agent_executor.ainvoke(agent_input)
+
+        # Extract intermediate steps in the format ((tool_name, tool_input), output)
+        # Also check if we have a return_direct tool
+        # TODO: only works with tools at the moment
+        out["return_direct"] = False
+        intermediate_steps = []
+        for step in out.get("intermediate_steps", []):
+            intermediate_steps.append(
+                ((step[0].tool, step[0].tool_input), step[1])
+            )
+
+            # If a tool was decorated with return_direct=True, indicate it in output
+            if step[0].tool in return_direct_tools:
+                out["return_direct"] = True
+
+        out["intermediate_steps"] = intermediate_steps
+
+        # if a form was selected, build it and store it in working memory
+        if "form" in out.keys():
+            FormClass = allowed_procedures.get(out["form"], None)
+            f = FormClass(stray)
+            stray.working_memory["forms"] = f
+            # let the form reply directly
+            out = f.next()
+            out["return_direct"] = True
+
         return out
 
-
-    def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix, stray):
+    async def execute_form_agent(self, stray):
+
+        active_form = stray.working_memory.get("forms", None)
+        if active_form:
+            log.warning(active_form._state)
+            # closing form if state is closed
+            if active_form._state == CatFormState.CLOSED:
+                del stray.working_memory["forms"]
+            else:
+                # continue form
+                return active_form.next()
+
+        return None # no active form
+
+    async def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix, stray):
 
         input_variables = [i for i in agent_input.keys() if i in prompt_prefix + prompt_suffix]
 
         # memory chain (second step)
@@ -99,100 +161,78 @@ def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix, stray)
         memory_chain = LLMChain(
             prompt=memory_prompt,
             llm=stray._llm,
-            verbose=self.verbose
+            verbose=self.verbose,
+            output_key="output"
         )
 
-        out = memory_chain(agent_input, callbacks=[NewTokenHandler(stray)])
-        out["output"] = out["text"]
-        del out["text"]
-        return out
-
+        return await memory_chain.ainvoke(agent_input, config=RunnableConfig(callbacks=[NewTokenHandler(stray)]))
 
-    def execute_agent(self, stray):
+    async def execute_agent(self, stray):
         """Instantiate the Agent with tools.
 
-        The method formats the main prompt and gather the allowed tools. It also instantiates a conversational Agent
+        The method formats the main prompt and gathers the allowed tools/forms. It also instantiates a conversational Agent
         from Langchain.
 
         Returns
         -------
-        agent_executor : AgentExecutor
-            Instance of the Agent provided with a set of tools.
+        agent_executor : agent reply
+            Reply of the Agent in the format `{"output": ..., "intermediate_steps": ...}`.
 
         """
 
         # prepare input to be passed to the agent.
         # Info will be extracted from working memory
         agent_input = self.format_agent_input(stray.working_memory)
         agent_input = self.mad_hatter.execute_hook("before_agent_starts", agent_input, cat=stray)
 
+        # should we run the default agent?
         fast_reply = {}
         fast_reply = self.mad_hatter.execute_hook("agent_fast_reply", fast_reply, cat=stray)
         if len(fast_reply.keys()) > 0:
             return fast_reply
+
+        # obtain prompt parts from plugins
         prompt_prefix = self.mad_hatter.execute_hook("agent_prompt_prefix", prompts.MAIN_PROMPT_PREFIX, cat=stray)
         prompt_suffix = self.mad_hatter.execute_hook("agent_prompt_suffix", prompts.MAIN_PROMPT_SUFFIX, cat=stray)
-
-
-        # tools currently recalled in working memory
-        recalled_tools = stray.working_memory["procedural_memories"]
-        # Get the tools names only
-        tools_names = [t[0].metadata["name"] for t in recalled_tools]
-        tools_names = self.mad_hatter.execute_hook("agent_allowed_tools", tools_names, cat=stray)
-        # Get tools with that name from mad_hatter
-        allowed_tools = [i for i in self.mad_hatter.tools if i.name in tools_names]
-
-        # Try to get information from tools if there is some allowed
-        if len(allowed_tools) > 0:
-
-            log.debug(f"{len(allowed_tools)} allowed tools retrived.")
+
+        # Run active form if present
+        form_result = await self.execute_form_agent(stray)
+        if form_result:
+            return form_result # exit agent with form output
+
+        # Select and run useful procedures
+        intermediate_steps = []
+        procedural_memories = stray.working_memory["procedural_memories"]
+        if len(procedural_memories) > 0:
+
+            log.debug(f"Procedural memories retrieved: {len(procedural_memories)}.")
 
             try:
-                tools_result = self.execute_tool_agent(agent_input, allowed_tools, stray)
-
-                # If tools_result["output"] is None the LLM has used the fake tool none_of_the_others
-                # so no relevant information has been obtained from the tools.
-                if tools_result["output"] is not None:
-
-                    # Extract of intermediate steps in the format ((tool_name, tool_input), output)
-                    used_tools = list(map(lambda x:((x[0].tool, x[0].tool_input), x[1]), tools_result["intermediate_steps"]))
-
-                    # Get the name of the tools that have return_direct
-                    return_direct_tools = []
-                    for t in allowed_tools:
-                        if t.return_direct:
-                            return_direct_tools.append(t.name)
-
-                    # execute_tool_agent returns immediately when a tool with return_direct is called,
-                    # so if one is used it is definitely the last one used
-                    if used_tools[-1][0][0] in return_direct_tools:
-                        # intermediate_steps still contains the information of all the tools used even if their output is not returned
-                        tools_result["intermediate_steps"] = used_tools
-                        return tools_result
-
-                    #Adding the tools_output key in agent input, needed by the memory chain
-                    agent_input["tools_output"] = "## Tools output: \n" + tools_result["output"] if tools_result["output"] else ""
-
-                    # Execute the memory chain
-                    out = self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix, stray)
-
-                    # If some tools are used the intermediate step are added to the agent output
-                    out["intermediate_steps"] = used_tools
-
-                    #Early return
-                    return out
-
+                procedures_result = await self.execute_procedures_agent(agent_input, stray)
+                if procedures_result.get("return_direct"):
+                    # exit agent if a return_direct procedure was executed
+                    return procedures_result
+
+                # Adding the tools_output key in agent input, needed by the memory chain
+                if procedures_result.get("output"):
+                    agent_input["tools_output"] = "## Tools output: \n" + procedures_result["output"]
+
+                # store intermediate steps to enrich memory chain
+                intermediate_steps = procedures_result["intermediate_steps"]
+
             except Exception as e:
                 log.error(e)
                 traceback.print_exc()
 
-            #If an exeption occur in the execute_tool_agent or there is no allowed tools execute only the memory chain
-
-            #Adding the tools_output key in agent input, needed by the memory chain
-            agent_input["tools_output"] = ""
-            # Execute the memory chain
-            out = self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix, stray)
+        # we run the memory chain if:
+        # - no procedures were recalled or selected, or
+        # - the selected procedures all have return_direct=False, or
+        # - the procedures agent crashed big time
+        if "tools_output" not in agent_input:
+            agent_input["tools_output"] = ""
+        memory_chain_output = await self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix, stray)
+        memory_chain_output["intermediate_steps"] = intermediate_steps
 
-        return out
+        return memory_chain_output
 
     def format_agent_input(self, working_memory):
         """Format the input for the Agent.
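The rewritten `execute_agent` boils down to a three-step fallback chain. A standalone sketch of that control flow, with hypothetical stand-ins for the sub-agents:

```python
import asyncio

# hypothetical stand-in for the memory chain
async def memory_chain():
    return {"output": "reply built from prompt prefix/suffix and memories"}

async def execute_agent_sketch(form_result, procedures_result):
    # 1. an active form owns the conversation until it closes
    if form_result is not None:
        return form_result
    # 2. a return_direct tool or form replies straight to the user
    if procedures_result and procedures_result.get("return_direct"):
        return procedures_result
    # 3. default path: memory chain, enriched with any tools_output
    return await memory_chain()

print(asyncio.run(execute_agent_sketch(None, {"return_direct": False})))
```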
diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py
index 0ec4833d..cf14b16f 100644
--- a/core/cat/looking_glass/cheshire_cat.py
+++ b/core/cat/looking_glass/cheshire_cat.py
@@ -1,4 +1,6 @@
 import time
+from typing import List, Dict
+from typing_extensions import Protocol
 
 from langchain_core.language_models.llms import BaseLLM
 from langchain.base_language import BaseLanguageModel
@@ -8,7 +10,6 @@
 from langchain_openai import ChatOpenAI
 from langchain_google_genai import ChatGoogleGenerativeAI
 
-
 from cat.db import crud
 from cat.factory.custom_llm import CustomOpenAI
 from cat.factory.embedder import get_embedder_from_name
@@ -22,6 +23,15 @@
 from cat.rabbit_hole import RabbitHole
 from cat.utils import singleton
 
+class Procedure(Protocol):
+    name: str
+    procedure_type: str # "tool" or "form"
+
+    # {
+    #     "description": [],
+    #     "start_examples": [],
+    # }
+    triggers_map: Dict[str, List[str]]
 
 # main class
 @singleton
@@ -33,7 +43,7 @@ class CheshireCat():
     Attributes
     ----------
     todo : list
-        TODO TODO TODO.
+        Yet to be written.
 
     """
 
@@ -57,9 +67,9 @@ def __init__(self):
         self.load_memory()
 
         # After memory is loaded, we can get/create tools embeddings
-        # every time the mad_hatter finishes syncing hooks and tools, it will notify the Cat (so it can embed tools in vector memory)
-        self.mad_hatter.on_finish_plugins_sync_callback = self.embed_tools
-        self.embed_tools()
+        # every time the mad_hatter finishes syncing hooks, tools and forms, it will notify the Cat (so it can embed tools in vector memory)
+        self.mad_hatter.on_finish_plugins_sync_callback = self.embed_procedures
+        self.embed_procedures() # first time launched manually
 
         # Agent manager instance (for reasoning)
         self.agent_manager = AgentManager()
@@ -246,54 +256,70 @@ def load_memory(self):
         }
         self.memory = LongTermMemory(vector_memory_config=vector_memory_config)
 
-    def embed_tools(self):
-        # loops over tools and assigns an embedding each. If an embedding is not present in vectorDB,
-        # it is created and saved
-
-        # retrieve from vectorDB all tool embeddings
-        embedded_tools = self.memory.vectors.procedural.get_all_points()
-
-        # easy access to (point_id, tool_description)
-        embedded_tools_ids = [t.id for t in embedded_tools]
-        embedded_tools_descriptions = [t.payload["page_content"] for t in embedded_tools]
-
-        # loop over mad_hatter tools
-        for tool in self.mad_hatter.tools:
-            # if the tool is not embedded
-            if tool.description not in embedded_tools_descriptions:
-                # embed the tool and save it to DB
-                tool_embedding = self.embedder.embed_documents([tool.description])
-                self.memory.vectors.procedural.add_point(
-                    tool.description,
-                    tool_embedding[0],
-                    {
-                        "source": "tool",
-                        "when": time.time(),
-                        "name": tool.name,
-                        "examples": tool.examples,
-                        "docstring": tool.docstring
-                    },
-                )
-
-                log.warning(f"Newly embedded {repr(tool)}")
-
-        # easy access to mad hatter tools (found in plugins)
-        mad_hatter_tools_descriptions = [t.description for t in self.mad_hatter.tools]
-
-        # loop over embedded tools and delete the ones not present in active plugins
-        points_to_be_deleted = []
-        for id, descr in zip(embedded_tools_ids, embedded_tools_descriptions):
-            # if the tool is not active, it inserts it in the list of points to be deleted
-            if descr not in mad_hatter_tools_descriptions:
-                log.warning(f"Deleting embedded CatTool: {descr}")
-                points_to_be_deleted.append(id)
-
-        # delete not active tools
-        if len(points_to_be_deleted) > 0:
-            self.memory.vectors.vector_db.delete(
-                collection_name="procedural",
-                points_selector=points_to_be_deleted
+    def build_embedded_procedures_hashes(self, embedded_procedures):
+
+        hashes = {}
+        for ep in embedded_procedures:
+            #log.warning(ep)
+            metadata = ep.payload["metadata"]
+            content = ep.payload["page_content"]
+            source = metadata["source"]
+            trigger_type = metadata.get("trigger_type", "unsupported") # there may be legacy points with no trigger_type
+
+            p_hash = f"{source}.{trigger_type}.{content}"
+            hashes[p_hash] = ep.id
+
+        return hashes
+
+    def build_active_procedures_hashes(self, active_procedures):
+
+        hashes = {}
+        for ap in active_procedures:
+            for trigger_type, trigger_list in ap.triggers_map.items():
+                for trigger_content in trigger_list:
+                    p_hash = f"{ap.name}.{trigger_type}.{trigger_content}"
+                    hashes[p_hash] = {
+                        "obj": ap,
+                        "source": ap.name,
+                        "type": ap.procedure_type,
+                        "trigger_type": trigger_type,
+                        "content": trigger_content,
+                    }
+        return hashes
+
+    def embed_procedures(self):
+
+        # Retrieve from vectorDB all procedural embeddings
+        embedded_procedures = self.memory.vectors.procedural.get_all_points()
+        embedded_procedures_hashes = self.build_embedded_procedures_hashes(embedded_procedures)
+
+        # Easy access to active procedures in mad_hatter (source of truth!)
+        active_procedures_hashes = self.build_active_procedures_hashes(self.mad_hatter.procedures)
+
+        # points_to_be_kept = set(active_procedures_hashes.keys()) and set(embedded_procedures_hashes.keys()) not necessary
+        points_to_be_deleted = set(embedded_procedures_hashes.keys()) - set(active_procedures_hashes.keys())
+        points_to_be_embedded = set(active_procedures_hashes.keys()) - set(embedded_procedures_hashes.keys())
+
+        points_to_be_deleted_ids = [embedded_procedures_hashes[p] for p in points_to_be_deleted]
+        if points_to_be_deleted_ids:
+            log.warning(f"Deleting triggers: {points_to_be_deleted}")
+            self.memory.vectors.procedural.delete_points(points_to_be_deleted_ids)
+
+        active_triggers_to_be_embedded = [active_procedures_hashes[p] for p in points_to_be_embedded]
+        for t in active_triggers_to_be_embedded:
+            trigger_embedding = self.embedder.embed_documents([t["content"]])
+            self.memory.vectors.procedural.add_point(
+                t["content"],
+                trigger_embedding[0],
+                {
+                    "source": t["source"],
+                    "type": t["type"],
+                    "trigger_type": t["trigger_type"],
+                    "when": time.time(),
+                }
             )
+            log.warning(f"Newly embedded {t['type']} trigger: {t['source']}, {t['trigger_type']}, {t['content']}")
 
     def send_ws_message(self, content: str, msg_type="notification"):
         log.error("No websocket connection open")
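The trigger hashes used for this diffing are plain `"<source>.<trigger_type>.<content>"` strings. A small sketch with the README's `PizzaForm` triggers:

```python
# triggers declared by the README's PizzaForm
triggers_map = {
    "start_example": ["order a pizza!", "I want pizza"],
    "description": ["PizzaForm: Pizza Order"],
}

# same "<source>.<trigger_type>.<content>" keys built by build_active_procedures_hashes
hashes = {
    f"PizzaForm.{trigger_type}.{content}"
    for trigger_type, contents in triggers_map.items()
    for content in contents
}

# embedded points whose key is not in `hashes` get deleted; keys not yet
# in the vector DB get embedded - keeping memory in sync with active plugins
print(sorted(hashes))
```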
diff --git a/core/cat/looking_glass/output_parser.py b/core/cat/looking_glass/output_parser.py
index b43cff5a..07504522 100644
--- a/core/cat/looking_glass/output_parser.py
+++ b/core/cat/looking_glass/output_parser.py
@@ -3,8 +3,11 @@
 from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from typing import Union
 
+from cat.mad_hatter.mad_hatter import MadHatter
+from cat.log import log
 
-class ToolOutputParser(AgentOutputParser):
+
+class ChooseProcedureOutputParser(AgentOutputParser):
 
     def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
 
@@ -35,5 +38,17 @@ def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
                 log=llm_output,
             )
 
+        mh = MadHatter()
+
+        for Form in mh.forms:
+            if Form.name == action:
+                return AgentFinish(
+                    return_values={
+                        "output": None,
+                        "form": action
+                    },
+                    log=llm_output,
+                )
+
         # Return the action and action input
         return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
\ No newline at end of file
diff --git a/core/cat/looking_glass/prompts.py b/core/cat/looking_glass/prompts.py
index 695a16f5..1fe4f170 100644
--- a/core/cat/looking_glass/prompts.py
+++ b/core/cat/looking_glass/prompts.py
@@ -1,15 +1,17 @@
-from typing import List
+from typing import Union, Dict
 
 from langchain.agents.tools import BaseTool
 from langchain.prompts import StringPromptTemplate
 
+from cat.experimental.form import CatForm
+
 
 class ToolPromptTemplate(StringPromptTemplate):
     # The template to use
     template: str
     # The list of tools available
-    tools: List[BaseTool]
+    procedures: Dict[str, Union[BaseTool, CatForm.__class__]]
 
     def format(self, **kwargs) -> str:
         # Get the intermediate steps (AgentAction, Observation tuples)
@@ -23,14 +25,15 @@ def format(self, **kwargs) -> str:
         kwargs["agent_scratchpad"] = thoughts
         # Create a tools variable from the list of tools provided
         kwargs["tools"] = ""
-        for tool in self.tools:
-            kwargs["tools"] += f" - {tool.description}\n"
-            if len(tool.examples) > 0:
-                kwargs["tools"] += f"\tExamples of questions for {tool.name}:\n"
-                for example in tool.examples:
-                    kwargs["tools"] += f"\t - \"{example}\"\n"
+        for proc in self.procedures.values():
+            kwargs["tools"] += f" - {proc.name}: {proc.description}\n"
+            # if len(tool.examples) > 0:
+            #     kwargs["tools"] += f"\tExamples of questions for {tool.name}:\n"
+            #     for example in tool.examples:
+            #         kwargs["tools"] += f"\t - \"{example}\"\n"
+            #kwargs["tools"] += "\n"
         # Create a list of tool names for the tools provided
-        kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
+        kwargs["tool_names"] = ", ".join(self.procedures.keys())
 
         return self.template.format(**kwargs)
 
@@ -38,8 +41,7 @@ def format(self, **kwargs) -> str:
 TOOL_PROMPT = """Answer the following question: `{input}`
 You can only reply using these tools:
-{tools}
- - none_of_the_others(): Use this tool if none of the others tools help. Input is always None.
+{tools} - none_of_the_others: Use this tool if none of the other tools help. Input is always None.
 If you want to use tools, use the following format:
 Action: the name of the action to take, should be one of [{tool_names}]
 
diff --git a/core/cat/looking_glass/stray_cat.py b/core/cat/looking_glass/stray_cat.py
index e30f99d3..ec957d67 100644
--- a/core/cat/looking_glass/stray_cat.py
+++ b/core/cat/looking_glass/stray_cat.py
@@ -25,8 +25,8 @@ class StrayCat:
     def __init__(
         self,
         user_id: str,
+        main_loop,
         ws: WebSocket = None,
-        event_loop = None
     ):
         self.__user_id = user_id
         self.__ws_messages = asyncio.Queue()
@@ -35,12 +35,9 @@ def __init__(
         # attribute to store ws connection
         self.ws = ws
 
-        # event loop
-        if event_loop is None:
-            self.__loop = asyncio.get_event_loop()
-        else:
-            self.__loop = event_loop
+        self.__main_loop = main_loop
 
+        self.__loop = asyncio.new_event_loop()
 
     def send_ws_message(self, content: str, msg_type: MSG_TYPES="notification"):
 
@@ -66,26 +63,27 @@ def send_ws_message(self, content: str, msg_type: MSG_TYPES="notification"):
             raise ValueError(f"The message type `{msg_type}` is not valid. Valid types: {', '.join(options)}")
 
         if msg_type == "error":
-            self.__loop.create_task(
-                self.__ws_messages.put(
-                    {
-                        "type": msg_type,
-                        "name": "GenericError",
-                        "description": content
-                    }
-                )
+
+            # Calling put_nowait from the uvicorn main loop is necessary,
+            # as the ws_messages queue belongs to that loop
+            self.__main_loop.call_soon_threadsafe(
+                self.__ws_messages.put_nowait,
+                {
+                    "type": msg_type,
+                    "name": "GenericError",
+                    "description": content
+                }
             )
         else:
-            self.__loop.create_task(
-                self.__ws_messages.put(
-                    {
-                        "type": msg_type,
-                        "content": content
-                    }
-                )
+            self.__main_loop.call_soon_threadsafe(
+                self.__ws_messages.put_nowait,
+                {
+                    "type": msg_type,
+                    "content": content
+                }
             )
 
-
     def recall_relevant_memories_to_working_memory(self, query=None):
         """Retrieve context from memory.
 
@@ -120,7 +118,7 @@ def recall_relevant_memories_to_working_memory(self, query=None):
 
         # We may want to search in memory
         recall_query = self.mad_hatter.execute_hook("cat_recall_query", recall_query, cat=self)
-        log.info(f'Recall query: "{recall_query}"')
+        log.info(f"Recall query: '{recall_query}'")
 
         # Embed recall query
         recall_query_embedding = self.embedder.embed_query(recall_query)
@@ -176,7 +174,6 @@ def recall_relevant_memories_to_working_memory(self, query=None):
         # hook to modify/enrich retrieved memories
         self.mad_hatter.execute_hook("after_cat_recalls_memories", cat=self)
 
-
     def llm(self, prompt: str, stream: bool = False) -> str:
         """Generate a response using the LLM model.
@@ -207,8 +204,7 @@ def llm(self, prompt: str, stream: bool = False) -> str:
         if isinstance(self._llm, BaseChatModel):
             return self._llm.call_as_llm(prompt, callbacks=callbacks)
 
-
-    def __call__(self, user_message_json):
+    async def __call__(self, user_message_json):
         """Call the Cat instance.
 
         This method is called on the user's message received from the client.
@@ -267,7 +263,7 @@ def __call__(self, user_message_json):
 
         # reply with agent
         try:
-            cat_message = self.agent_manager.execute_agent(self)
+            cat_message = await self.agent_manager.execute_agent(self)
         except Exception as e:
             # This error happens when the LLM
             # does not respect prompt instructions.
@@ -323,7 +319,7 @@ def __call__(self, user_message_json):
             "content": str(cat_message.get("output")),
             "why": {
                 "input": cat_message.get("input"),
-                "intermediate_steps": cat_message.get("intermediate_steps"),
+                "intermediate_steps": cat_message.get("intermediate_steps", []),
                 "memory": {
                     "episodic": episodic_report,
                     "declarative": declarative_report,
@@ -340,6 +336,11 @@ def __call__(self, user_message_json):
 
         return final_output
 
+    def run(self, user_message_json):
+        return self.loop.run_until_complete(
+            self.__call__(user_message_json)
+        )
+
     def send_long_message_to_declarative(self):
         #Split input after MAX_TEXT_INPUT tokens, on a whitespace, if any, and send it to the rabbit hole
         index = MAX_TEXT_INPUT
@@ -364,10 +365,6 @@ def send_long_message_to_declarative(self):
     @property
     def user_id(self):
         return self.__user_id
-
-    @property
-    def ws_messages(self):
-        return self.__ws_messages
 
     @property
     def _llm(self):
@@ -392,3 +389,7 @@ def mad_hatter(self):
     @property
     def agent_manager(self):
         return CheshireCat().agent_manager
+
+    @property
+    def loop(self):
+        return self.__loop
\ No newline at end of file
diff --git a/core/cat/mad_hatter/core_plugin/tools.py b/core/cat/mad_hatter/core_plugin/tools.py
index bbe72b1e..d566c29f 100644
--- a/core/cat/mad_hatter/core_plugin/tools.py
+++ b/core/cat/mad_hatter/core_plugin/tools.py
@@ -3,10 +3,8 @@
 from cat.mad_hatter.decorators import tool
 
 
-@tool
+@tool(examples=["what time is it", "get the time"])
 def get_the_time(tool_input, cat):
-    """Replies to "what time is it", "get the clock" and similar questions. Input is always None."""
+    """Useful to get the current time when asked. Input is always None."""
 
     return str(datetime.now())
-
-
diff --git a/core/cat/mad_hatter/decorators/__init__.py b/core/cat/mad_hatter/decorators/__init__.py
index 4607435c..5f9b0f2d 100644
--- a/core/cat/mad_hatter/decorators/__init__.py
+++ b/core/cat/mad_hatter/decorators/__init__.py
@@ -1,3 +1,3 @@
 from cat.mad_hatter.decorators.tool import CatTool, tool
 from cat.mad_hatter.decorators.hook import CatHook, hook
-from cat.mad_hatter.decorators.plugin_decorator import CatPluginDecorator, plugin
\ No newline at end of file
+from cat.mad_hatter.decorators.plugin_decorator import CatPluginDecorator, plugin
diff --git a/core/cat/mad_hatter/decorators/tool.py b/core/cat/mad_hatter/decorators/tool.py
index b7913544..ead4f1d5 100644
--- a/core/cat/mad_hatter/decorators/tool.py
+++ b/core/cat/mad_hatter/decorators/tool.py
@@ -1,43 +1,66 @@
-from typing import Union, Callable, List
+import inspect
+
+from typing import Union, Callable, List
 from inspect import signature
 
-from langchain.agents import Tool
+from langchain_core.tools import BaseTool
 
 # All @tool decorated functions in plugins become a CatTool.
 # The difference between base langchain Tool and CatTool is that CatTool has an instance of the cat as attribute (set by the MadHatter)
-class CatTool(Tool):
+class CatTool(BaseTool):
+
+    def __init__(self, name: str, func: Callable, return_direct: bool = False, examples: List[str] = []):
 
-    def __init__(self, name: str, func: Callable, description: str,
-                 return_direct: bool = False, examples: List[str] = []):
+        description = func.__doc__.strip()
 
         # call parent contructor
-        super().__init__(name=name, func=func, description=description,
-                         return_direct=return_direct)
+        super().__init__(name=name, func=func, description=description, return_direct=return_direct)
 
         # StrayCat instance will be set by AgentManager
         self.cat = None
 
+        self.func = func
+        self.procedure_type = "tool"
         self.name = name
+        self.description = description
         self.return_direct = return_direct
-        self.func = func
-        self.examples = examples
-        self.docstring = self.func.__doc__.strip()
 
-        # remove cat argument from description signature so it does not end up in prompts
-        self.description = self.description.replace(", cat)", ")")
+        self.triggers_map = {
+            "description" : [
+                f"{name}: {description}"
+            ],
+            "start_example": examples
+        }
 
+        # remove cat argument from signature so it does not end up in prompts
+        self.signature = f"{signature(self.func)}".replace(", cat)", ")")
+
+    @property
+    def start_examples(self):
+        return self.triggers_map["start_example"]
 
     def __repr__(self) -> str:
-        return f"CatTool(name={self.name}, return_direct={self.return_direct}, description={self.docstring})"
+        return f"CatTool(name={self.name}, return_direct={self.return_direct}, description={self.description})"
 
     # used by the AgentManager to let a Tool access the cat instance
     def assign_cat(self, cat):
        self.cat = cat
 
     def _run(self, input_by_llm):
+        if inspect.iscoroutinefunction(self.func):
+            raise NotImplementedError("Tool does not support sync")
+
         return self.func(input_by_llm, cat=self.cat)
 
     async def _arun(self, input_by_llm):
-        # should be used for async Tools, just using sync here
-        return self._run(input_by_llm)
+        if inspect.iscoroutinefunction(self.func):
+            return await self.func(input_by_llm, cat=self.cat)
+
+        return await self.cat.loop.run_in_executor(
+            None,
+            self.func,
+            input_by_llm,
+            self.cat
+        )
 
     # override `extra = 'forbid'` for Tool pydantic model in langchain
     class Config:
@@ -71,11 +94,9 @@ def search_api(query: str, cat) -> str:
 def _make_with_name(tool_name: str) -> Callable:
     def _make_tool(func: Callable[[str], str]) -> CatTool:
         assert func.__doc__, "Function must have a docstring"
-        description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
         tool_ = CatTool(
             name=tool_name,
             func=func,
-            description=description,
             return_direct=return_direct,
             examples=examples,
         )
diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py
index f2e00e60..c9df5b2d 100644
--- a/core/cat/mad_hatter/mad_hatter.py
+++ b/core/cat/mad_hatter/mad_hatter.py
@@ -1,17 +1,25 @@
+import os
 import glob
 import shutil
-import os
+import inspect
 import traceback
 from copy import deepcopy
+from typing import List, Dict
 
 from cat.log import log
+
 import cat.utils as utils
 from cat.utils import singleton
+
 from cat.db import crud
 from cat.db.models import Setting
+
 from cat.mad_hatter.plugin_extractor import PluginExtractor
 from cat.mad_hatter.plugin import Plugin
-import inspect
+
+from cat.mad_hatter.decorators.hook import CatHook
+from cat.mad_hatter.decorators.tool import CatTool
+
+from cat.experimental.form import CatForm
 
 # This class is responsible for plugins functionality:
 # - loading
@@ -27,12 +35,13 @@ class MadHatter:
 
     def __init__(self):
 
-        self.plugins = {} # plugins dictionary
+        self.plugins: Dict[str, Plugin] = {} # plugins dictionary
 
-        self.hooks = {} # dict of active plugins hooks ( hook_name -> [CatHook, CatHook, ...])
-        self.tools = [] # list of active plugins tools
+        self.hooks: Dict[str, List[CatHook]] = {} # dict of active plugins hooks ( hook_name -> [CatHook, CatHook, ...])
+        self.tools: List[CatTool] = [] # list of active plugins tools
+        self.forms: List[CatForm] = [] # list of active plugins forms
 
-        self.active_plugins = []
+        self.active_plugins: List[str] = []
 
         self.plugins_folder = utils.get_plugins_path()
 
@@ -102,7 +111,7 @@ def find_plugins(self):
             if plugin_id in self.active_plugins:
                 self.plugins[plugin_id].activate()
 
-        self.sync_hooks_and_tools()
+        self.sync_hooks_tools_and_forms()
 
     def load_plugin(self, plugin_path):
         # Instantiate plugin.
@@ -117,20 +126,23 @@ def load_plugin(self, plugin_path):
             # Print the error and go on with the others.
             log.error(str(e))
 
-    # Load hooks and tools of the active plugins into MadHatter
-    def sync_hooks_and_tools(self):
+    # Load hooks, tools and forms of the active plugins into MadHatter
+    def sync_hooks_tools_and_forms(self):
 
-        # emptying tools and hooks
+        # emptying tools, hooks and forms
         self.hooks = {}
         self.tools = []
+        self.forms = []
 
         for _, plugin in self.plugins.items():
-            # load hooks and tools
+            # load hooks, tools and forms from active plugins
             if plugin.id in self.active_plugins:
 
                 # cache tools
                 self.tools += plugin.tools
 
+                self.forms += plugin.forms
+
                 # cache hooks (indexed by hook name)
                 for h in plugin.hooks:
                     if h.name not in self.hooks.keys():
@@ -212,7 +224,7 @@ def toggle_plugin(self, plugin_id):
             self.save_active_plugins_to_db(list(set(self.active_plugins)))
 
             # update cache and embeddings
-            self.sync_hooks_and_tools()
+            self.sync_hooks_tools_and_forms()
 
         else:
             raise Exception("Plugin {plugin_id} not present in plugins folder")
@@ -286,3 +298,6 @@ def get_plugin(self):
         name = plugin_suffix.split("/")[0]
         return self.plugins[name]
 
+    @property
+    def procedures(self):
+        return self.tools + self.forms
\ No newline at end of file
diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py
index a4fdb3e4..efd68ee1 100644
--- a/core/cat/mad_hatter/plugin.py
+++ b/core/cat/mad_hatter/plugin.py
@@ -6,12 +6,13 @@
 import traceback
 import importlib
 import subprocess
-from typing import Dict
-from inspect import getmembers
+from typing import Dict, List
+from inspect import getmembers, isclass
 from pydantic import BaseModel, ValidationError
 from packaging.requirements import Requirement
 
 from cat.mad_hatter.decorators import CatTool, CatHook, CatPluginDecorator
+from cat.experimental.form import CatForm
 from cat.utils import to_camel_case
 from cat.log import log
 
@@ -49,11 +50,12 @@ def __init__(self, plugin_path: str):
         # plugin manifest (name, decription, thumb, etc.)
         self._manifest = self._load_manifest()
 
-        # list of tools and hooks contained in the plugin.
+        # list of tools, forms and hooks contained in the plugin.
# The MadHatter will cache them for easier access, # but they are created and stored in each plugin instance - self._hooks = [] - self._tools = [] + self._hooks: List[CatHook] = [] # list of plugin hooks + self._tools: List[CatTool] = [] # list of plugin tools + self._forms: List[CatForm] = [] # list of plugin forms # list of @plugin decorated functions overriding default plugin behaviour self._plugin_overrides = [] # TODO: make this a dictionary indexed by func name, for faster access @@ -64,8 +66,9 @@ def __init__(self, plugin_path: str): def activate(self): # install plugin requirements on activation self._install_requirements() - # lists of hooks and tools - self._hooks, self._tools, self._plugin_overrides = self._load_decorated_functions() + + # Load hooks, tools and forms + self._load_decorated_functions() # by default, plugin settings are saved inside the plugin folder # in a JSON file called settings.json @@ -275,6 +278,7 @@ def _install_requirements(self): def _load_decorated_functions(self): hooks = [] tools = [] + forms = [] plugin_overrides = [] for py_file in self.py_files: @@ -285,8 +289,10 @@ def _load_decorated_functions(self): # save a reference to decorated functions try: plugin_module = importlib.import_module(py_filename) + hooks += getmembers(plugin_module, self._is_cat_hook) tools += getmembers(plugin_module, self._is_cat_tool) + forms += getmembers(plugin_module, self._is_cat_form) plugin_overrides += getmembers(plugin_module, self._is_cat_plugin_override) except Exception as e: log.error(f"Error in {py_filename}: {str(e)}. Unable to load plugin {self._id}") @@ -294,29 +300,34 @@ def _load_decorated_functions(self): traceback.print_exc() # clean and enrich instances - hooks = list(map(self._clean_hook, hooks)) - tools = list(map(self._clean_tool, tools)) - plugin_overrides = list(map(self._clean_plugin_override, plugin_overrides)) - - return hooks, tools, plugin_overrides + self._hooks = list(map(self._clean_hook, hooks)) + self._tools = list(map(self._clean_tool, tools)) + self._forms = list(map(self._clean_form, forms)) + self._plugin_overrides = list(map(self._clean_plugin_override, plugin_overrides)) def plugin_specific_error_message(self): name = self.manifest.get("name") url = self.manifest.get("plugin_url") return f"To resolve any problem related to {name} plugin, contact the creator using github issue at the link {url}" - def _clean_hook(self, hook): + def _clean_hook(self, hook: CatHook): # getmembers returns a tuple h = hook[1] h.plugin_id = self._id return h - def _clean_tool(self, tool): + def _clean_tool(self, tool: CatTool): # getmembers returns a tuple t = tool[1] t.plugin_id = self._id return t + def _clean_form(self, form: CatForm): + # getmembers returns a tuple + f = form[1] + f.plugin_id = self._id + return f + def _clean_plugin_override(self, plugin_override): # getmembers returns a tuple return plugin_override[1] @@ -326,6 +337,17 @@ def _clean_plugin_override(self, plugin_override): @staticmethod def _is_cat_hook(obj): return isinstance(obj, CatHook) + + @staticmethod + def _is_cat_form(obj): + + if not isclass(obj) or obj is CatForm: + return False + + if not issubclass(obj, CatForm) or not obj._autopilot: + return False + + return True # a plugin tool function has to be decorated with @tool # (which returns an instance of CatTool) @@ -362,3 +384,7 @@ def hooks(self): @property def tools(self): return self._tools + + @property + def forms(self): + return self._forms diff --git a/core/cat/main.py b/core/cat/main.py index 1557651a..8e797ba8 100644
--- a/core/cat/main.py +++ b/core/cat/main.py @@ -113,6 +113,7 @@ async def validation_exception_handler(request, exc): "cat.main:cheshire_cat_api", host="0.0.0.0", port=80, + use_colors=True, log_level=log_level.lower(), **debug_config ) diff --git a/core/cat/memory/working_memory.py b/core/cat/memory/working_memory.py index 290fc0d0..270a9e6e 100644 --- a/core/cat/memory/working_memory.py +++ b/core/cat/memory/working_memory.py @@ -1,7 +1,3 @@ -import asyncio -from typing import get_args, Literal - -from cat.log import log class WorkingMemory(dict): """Cat's volatile memory. @@ -24,11 +20,6 @@ def __init__(self): # and the asyncio queue to manage the session notifications super().__init__(history=[]) - def get_user_id(self): - """Get current user id.""" - - return self["user_message_json"]["user_id"] - def update_conversation_history(self, who, message, why={}): """Update the conversation history. diff --git a/core/cat/routes/websocket.py b/core/cat/routes/websocket.py index b0aa105b..9428fe10 100644 --- a/core/cat/routes/websocket.py +++ b/core/cat/routes/websocket.py @@ -1,7 +1,7 @@ import traceback import asyncio -from fastapi import Depends, APIRouter, WebSocketDisconnect, WebSocket +from fastapi import APIRouter, WebSocketDisconnect, WebSocket from fastapi.concurrency import run_in_threadpool from cat.looking_glass.stray_cat import StrayCat @@ -21,7 +21,7 @@ async def receive_message(websocket: WebSocket, stray: StrayCat): user_message["user_id"] = stray.user_id # Run the `ccat` object's method in a threadpool since it might be a CPU-bound operation. - cat_message = await run_in_threadpool(stray, user_message) + cat_message = await run_in_threadpool(stray.run, user_message) # Send the response message back to the user. await websocket.send_json(cat_message) @@ -34,7 +34,7 @@ async def check_messages(websoket: WebSocket, stray: StrayCat): while True: # extract from FIFO list websocket notification - notification = await stray.ws_messages.get() + notification = await stray._StrayCat__ws_messages.get() await websoket.send_json(notification) @@ -62,6 +62,7 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str = "user"): stray = StrayCat( ws=websocket, user_id=user_id, + main_loop=asyncio.get_running_loop() ) strays[user_id] = stray diff --git a/core/pyproject.toml b/core/pyproject.toml index d488ddb6..c904a43a 100644 --- a/core/pyproject.toml +++ b/core/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "Cheshire-Cat" description = "Production ready AI assistant framework" -version = "1.4.8" +version = "1.5.0" requires-python = ">=3.10" license = { file="LICENSE" } authors = [ @@ -15,8 +15,8 @@ dependencies = [ "pandas==1.5.3", "scikit-learn==1.2.1", "qdrant_client==1.7.2", - "langchain==0.1.4", - "langchain-community", + "langchain==0.1.10", + "langchain-community==0.0.25", "langchain-openai", "langchain-google-genai", "openai==1.10.0", diff --git a/core/tests/looking_glass/test_cheshire_cat.py b/core/tests/looking_glass/test_cheshire_cat.py index afd488b6..5e357d1f 100644 --- a/core/tests/looking_glass/test_cheshire_cat.py +++ b/core/tests/looking_glass/test_cheshire_cat.py @@ -61,14 +61,26 @@ def test_default_embedder_loaded(cheshire_cat): assert sample_embed == out -def test_tools_embedded(cheshire_cat): +def test_procedures_embedded(cheshire_cat): # get embedded tools - tools = cheshire_cat.memory.vectors.procedural.get_all_points() - assert len(tools) == 1 - - # some check on the embedding - assert "get_the_time" in tools[0].payload["page_content"] - assert 
isinstance(tools[0].vector, list) - sample_embed = DumbEmbedder().embed_query("I'm smarter than a random embedder BTW") - assert len(tools[0].vector) == len(sample_embed) # right embed size + procedures = cheshire_cat.memory.vectors.procedural.get_all_points() + assert len(procedures) == 3 + + for p in procedures: + assert p.payload["metadata"]["source"] == "get_the_time" + assert p.payload["metadata"]["type"] == "tool" + trigger_type = p.payload["metadata"]["trigger_type"] + content = p.payload["page_content"] + assert trigger_type in ["start_example", "description"] + + if trigger_type == "start_example": + assert content in ["what time is it", "get the time"] + if trigger_type == "description": + assert content == "get_the_time: Useful to get the current time when asked. Input is always None." + + # some check on the embedding + assert isinstance(p.vector, list) + expected_embed = cheshire_cat.embedder.embed_query(content) + assert len(p.vector) == len(expected_embed) # same embed + # assert p.vector == expected_embed TODO: Qdrant does unwanted normalization \ No newline at end of file diff --git a/core/tests/looking_glass/test_stray_cat.py b/core/tests/looking_glass/test_stray_cat.py index 4a0f4c25..629cf067 100644 --- a/core/tests/looking_glass/test_stray_cat.py +++ b/core/tests/looking_glass/test_stray_cat.py @@ -1,4 +1,5 @@ import pytest +import asyncio from cat.looking_glass.stray_cat import StrayCat from cat.memory.working_memory import WorkingMemory @@ -6,7 +7,7 @@ @pytest.fixture def stray(client): - yield StrayCat(user_id="Alice") + yield StrayCat(user_id="Alice", main_loop=asyncio.new_event_loop()) def test_stray_initialization(stray): diff --git a/core/tests/mad_hatter/test_mad_hatter.py b/core/tests/mad_hatter/test_mad_hatter.py index 3de09f78..1a33eb48 100644 --- a/core/tests/mad_hatter/test_mad_hatter.py +++ b/core/tests/mad_hatter/test_mad_hatter.py @@ -43,11 +43,12 @@ def test_instantiation_discovery(mad_hatter): assert tool.plugin_id == "core_plugin" assert tool.cat is None assert tool.name == "get_the_time" - assert "get_the_time" in tool.description - assert "what time is it" in tool.docstring + assert tool.description == "Useful to get the current time when asked. Input is always None." assert isfunction(tool.func) assert tool.return_direct == False - assert tool.examples == [] + assert len(tool.start_examples) == 2 + assert "what time is it" in tool.start_examples + assert "get the time" in tool.start_examples # list of active plugins in DB is correct active_plugins = mad_hatter.load_active_plugins_from_db() @@ -81,6 +82,10 @@ def test_plugin_install(mad_hatter: MadHatter, plugin_is_flat): new_tool = mad_hatter.plugins["mock_plugin"].tools[0] assert new_tool.plugin_id == "mock_plugin" assert id(new_tool) == id(mad_hatter.tools[1]) # cached and same object in memory! + # tool examples found + assert len(new_tool.start_examples) == 2 + assert "mock tool example 1" in new_tool.start_examples + assert "mock tool example 2" in new_tool.start_examples # hooks found new_hooks = mad_hatter.plugins["mock_plugin"].hooks diff --git a/core/tests/mad_hatter/test_plugin.py b/core/tests/mad_hatter/test_plugin.py index 1c01f7a2..f8ebbaf4 100644 --- a/core/tests/mad_hatter/test_plugin.py +++ b/core/tests/mad_hatter/test_plugin.py @@ -81,9 +81,13 @@ def test_activate_plugin(plugin): assert isinstance(tool, CatTool) assert tool.plugin_id == "mock_plugin" assert tool.name == "mock_tool" - assert "mock_tool" in tool.description + assert tool.description == "Used to test mock tools. 
Input is the topic." assert isfunction(tool.func) assert tool.return_direct is True + # tool examples found + assert len(tool.start_examples) == 2 + assert "mock tool example 1" in tool.start_examples + assert "mock tool example 2" in tool.start_examples def test_deactivate_plugin(plugin): diff --git a/core/tests/mocks/mock_plugin/mock_form.py b/core/tests/mocks/mock_plugin/mock_form.py new file mode 100644 index 00000000..8fef3d6c --- /dev/null +++ b/core/tests/mocks/mock_plugin/mock_form.py @@ -0,0 +1,39 @@ +from typing import List, Dict +from datetime import date, time +from enum import Enum +from pydantic import BaseModel, Field, ConfigDict +from cat.log import log +from cat.experimental.form import form, CatForm + +class PizzaBorderEnum(Enum): + HIGH = "high" + LOW = "low" + +# simple pydantic model +class PizzaOrder(BaseModel): + pizza_type: str + pizza_border: PizzaBorderEnum + phone: str = Field(max_length=10) + +@form +class PizzaForm(CatForm): + description = "Pizza Order" + model_class = PizzaOrder + start_examples = [ + "order a pizza", + "I want pizza" + ] + stop_examples = [ + "stop pizza order", + "I do not want a pizza anymore", + ] + + ask_confirm: bool = True + + def submit(self, form_data): + + msg = f"Form submitted: {form_data}" + #self.cat.send_ws_message(msg, msg_type="chat") + return { + "output": msg + } \ No newline at end of file diff --git a/core/tests/mocks/mock_plugin/mock_tool.py b/core/tests/mocks/mock_plugin/mock_tool.py index 501b10c3..10072fa4 100644 --- a/core/tests/mocks/mock_plugin/mock_tool.py +++ b/core/tests/mocks/mock_plugin/mock_tool.py @@ -1,7 +1,8 @@ from cat.mad_hatter.decorators import tool +tool_examples = ["mock tool example 1", "mock tool example 2"] -@tool(return_direct=True) +@tool(return_direct=True, examples=tool_examples) def mock_tool(topic, cat): """Used to test mock tools. 
Input is the topic.""" diff --git a/core/tests/routes/embedder/test_embedder_setting.py b/core/tests/routes/embedder/test_embedder_setting.py index b45fcf8c..f81ed306 100644 --- a/core/tests/routes/embedder/test_embedder_setting.py +++ b/core/tests/routes/embedder/test_embedder_setting.py @@ -2,7 +2,7 @@ from json import dumps from fastapi.encoders import jsonable_encoder from cat.factory.embedder import get_embedders_schemas -from tests.utils import get_embedded_tools +from tests.utils import get_procedural_memory_contents def test_get_all_embedder_settings(client): @@ -81,9 +81,9 @@ def test_upsert_embedder_settings(client): def test_upsert_embedder_settings_updates_collections(client): - tools = get_embedded_tools(client) - assert len(tools) == 1 - assert len(tools[0]["vector"]) == 2367 # default embedder + procedures = get_procedural_memory_contents(client) + assert len(procedures) == 3 + assert len(procedures[0]["vector"]) == 2367 # default embedder # set a different embedder from default one (same class different size) embedder_config = { @@ -92,8 +92,9 @@ response = client.put("/embedder/settings/EmbedderFakeConfig", json=embedder_config) assert response.status_code == 200 - tools = get_embedded_tools(client) - assert len(tools) == 1 - assert len(tools[0]["vector"]) == embedder_config["size"] + procedures = get_procedural_memory_contents(client) + assert len(procedures) == 3 + for vec in procedures: + assert len(vec["vector"]) == embedder_config["size"] diff --git a/core/tests/routes/memory/test_memory_collection.py b/core/tests/routes/memory/test_memory_collection.py index 8ea13edd..96d00a2d 100644 --- a/core/tests/routes/memory/test_memory_collection.py +++ b/core/tests/routes/memory/test_memory_collection.py @@ -15,7 +15,7 @@ def test_memory_collections_created(client): # check correct number of default points collections_n_points = { c["name"]: c["vectors_count"] for c in json["collections"]} # there is at least an embedded tool in procedural collection - assert collections_n_points["procedural"] == 1 + assert collections_n_points["procedural"] == 3 # all other collections should be empty assert collections_n_points["episodic"] == 0 assert collections_n_points["declarative"] == 0 @@ -76,7 +76,7 @@ def test_memory_collection_procedural_has_tools_after_clear(client): # procedural memory contains one tool (get_the_time), embedded as 3 points collections_n_points = get_collections_names_and_point_count(client) - assert collections_n_points["procedural"] == 1 + assert collections_n_points["procedural"] == 3 # delete procedural memory response = client.delete("/memory/collections/procedural") @@ -85,7 +85,7 @@ def test_memory_collection_procedural_has_tools_after_clear(client): # tool should be automatically re-embedded after memory deletion collections_n_points = get_collections_names_and_point_count(client) - assert collections_n_points["procedural"] == 1 # still 1! + assert collections_n_points["procedural"] == 3 # still 3!
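The `== 3` assertions throughout these tests follow from the new trigger embedding: every procedure (tool or form) is stored in procedural memory as one `description` point plus one point per `start_example`. A minimal sketch of the three points the core `get_the_time` tool expands into — illustrative only, mirroring the payloads asserted in `test_procedures_embedded` above, not a stable API:

```python
# Illustrative sketch (not the actual embedding code): the single core tool
# expands into 3 procedural points, as asserted in test_procedures_embedded.
get_the_time_points = [
    {
        "page_content": "get_the_time: Useful to get the current time when asked. Input is always None.",
        "metadata": {"source": "get_the_time", "type": "tool", "trigger_type": "description"},
    },
    {
        "page_content": "what time is it",
        "metadata": {"source": "get_the_time", "type": "tool", "trigger_type": "start_example"},
    },
    {
        "page_content": "get the time",
        "metadata": {"source": "get_the_time", "type": "tool", "trigger_type": "start_example"},
    },
]

assert len(get_the_time_points) == 3  # 1 description + 2 start examples
```

The `== 9` totals later in the suite are the same arithmetic applied to two tools plus one form, three points each.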
def test_memory_collections_wipe(client): @@ -106,7 +106,7 @@ def test_memory_collections_wipe(client): response = client.post("/rabbithole/", files=files) collections_n_points = get_collections_names_and_point_count(client) - assert collections_n_points["procedural"] == 1 # default tool + assert collections_n_points["procedural"] == 3 # default tool (3 procedural points) assert collections_n_points["episodic"] == 1 # websocket msg assert collections_n_points["declarative"] > 1 # several chunks @@ -116,6 +116,6 @@ assert response.status_code == 200 collections_n_points = get_collections_names_and_point_count(client) - assert collections_n_points["procedural"] == 1 # default tool is re-emebedded + assert collections_n_points["procedural"] == 3 # default tool is re-embedded assert collections_n_points["episodic"] == 0 assert collections_n_points["declarative"] == 0 \ No newline at end of file diff --git a/core/tests/routes/plugins/test_plugin_settings.py b/core/tests/routes/plugins/test_plugin_settings.py index 95fb65a5..c7c3bd8d 100644 --- a/core/tests/routes/plugins/test_plugin_settings.py +++ b/core/tests/routes/plugins/test_plugin_settings.py @@ -1,4 +1,3 @@ -from tests.utils import get_embedded_tools from fixture_just_installed_plugin import just_installed_plugin diff --git a/core/tests/routes/plugins/test_plugin_toggle.py b/core/tests/routes/plugins/test_plugin_toggle.py index 13074517..22c97b62 100644 --- a/core/tests/routes/plugins/test_plugin_toggle.py +++ b/core/tests/routes/plugins/test_plugin_toggle.py @@ -1,5 +1,5 @@ -from tests.utils import get_embedded_tools +from tests.utils import get_procedural_memory_contents from fixture_just_installed_plugin import just_installed_plugin @@ -29,11 +29,20 @@ def test_deactivate_plugin(client, just_installed_plugin): assert response.json()["data"]["active"] == False # tool has been taken away - tools = get_embedded_tools(client) - assert len(tools) == 1 - tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "mock_tool" not in tool_names - assert "get_the_time" in tool_names # from core_plugin + procedures = get_procedural_memory_contents(client) + assert len(procedures) == 3 + procedures_sources = list(map(lambda t: t["metadata"]["source"], procedures)) + assert "mock_tool" not in procedures_sources + assert "PizzaForm" not in procedures_sources + assert "get_the_time" in procedures_sources # from core_plugin + + # only the core tool's triggers remain + procedures_types = list(map(lambda t: t["metadata"]["type"], procedures)) + assert procedures_types.count("tool") == 3 + assert procedures_types.count("form") == 0 + procedures_triggers = list(map(lambda t: t["metadata"]["trigger_type"], procedures)) + assert procedures_triggers.count("start_example") == 2 + assert procedures_triggers.count("description") == 1 def test_reactivate_plugin(client, just_installed_plugin): @@ -55,9 +64,18 @@ response = client.get("/plugins/mock_plugin") assert response.json()["data"]["active"] == True - # tool has been re-embedded - tools = get_embedded_tools(client) - assert len(tools) == 2 - tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "mock_tool" in tool_names - assert "get_the_time" in tool_names # from core_plugin \ No newline at end of file + # check whether procedures have been re-embedded + procedures = get_procedural_memory_contents(client) + assert len(procedures) == 9 # 2 tool descriptions, 4 tool examples, 3 form triggers + procedures_names =
list(map(lambda t: t["metadata"]["source"], procedures)) + assert procedures_names.count("mock_tool") == 3 + assert procedures_names.count("get_the_time") == 3 + assert procedures_names.count("PizzaForm") == 3 + + procedures_types = list(map(lambda t: t["metadata"]["type"], procedures)) + assert procedures_types.count("tool") == 6 + assert procedures_types.count("form") == 3 + + procedures_triggers = list(map(lambda t: t["metadata"]["trigger_type"], procedures)) + assert procedures_triggers.count("start_example") == 6 + assert procedures_triggers.count("description") == 3 diff --git a/core/tests/routes/plugins/test_plugins_install_uninstall.py b/core/tests/routes/plugins/test_plugins_install_uninstall.py index 2e019238..4df35944 100644 --- a/core/tests/routes/plugins/test_plugins_install_uninstall.py +++ b/core/tests/routes/plugins/test_plugins_install_uninstall.py @@ -1,7 +1,7 @@ import os import time import shutil -from tests.utils import get_embedded_tools +from tests.utils import get_procedural_memory_contents from fixture_just_installed_plugin import just_installed_plugin @@ -34,11 +34,20 @@ def test_plugin_install_from_zip(client, just_installed_plugin): assert os.path.exists(mock_plugin_final_folder) # check whether new tool has been embedded - tools = get_embedded_tools(client) - assert len(tools) == 2 - tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "mock_tool" in tool_names - assert "get_the_time" in tool_names # from core_plugin + procedures = get_procedural_memory_contents(client) + assert len(procedures) == 9 # 2 tool descriptions, 4 tool examples, 3 form triggers + procedures_names = list(map(lambda t: t["metadata"]["source"], procedures)) + assert procedures_names.count("mock_tool") == 3 + assert procedures_names.count("get_the_time") == 3 + assert procedures_names.count("PizzaForm") == 3 + + procedures_types = list(map(lambda t: t["metadata"]["type"], procedures)) + assert procedures_types.count("tool") == 6 + assert procedures_types.count("form") == 3 + + procedures_triggers = list(map(lambda t: t["metadata"]["trigger_type"], procedures)) + assert procedures_triggers.count("start_example") == 6 + assert procedures_triggers.count("description") == 3 def test_plugin_uninstall(client, just_installed_plugin): @@ -57,11 +66,25 @@ assert not os.path.exists(mock_plugin_final_folder) # plugin folder removed from disk # plugin tool disappeared - tools = get_embedded_tools(client) - assert len(tools) == 1 - tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "mock_tool" not in tool_names - assert "get_the_time" in tool_names # from core_plugin - - - + procedures = get_procedural_memory_contents(client) + assert len(procedures) == 3 + procedures_names = set(map(lambda t: t["metadata"]["source"], procedures)) + assert procedures_names == {"get_the_time"} + + # only the core tool's triggers remain + # count procedure types + procedures_types = list(map(lambda t: t["metadata"]["type"], procedures)) + assert procedures_types.count("tool") == 3 + assert procedures_types.count("form") == 0 + + tool_start_examples = [] + form_start_examples = [] + for p in procedures: + if p["metadata"]["type"] == "tool" and p["metadata"]["trigger_type"] == "start_example": + tool_start_examples.append(p) + + if p["metadata"]["type"] == "form" and p["metadata"]["trigger_type"] == "start_example": + form_start_examples.append(p) + + assert len(tool_start_examples) == 2 + assert
len(form_start_examples) == 0 diff --git a/core/tests/routes/plugins/test_plugins_registry.py b/core/tests/routes/plugins/test_plugins_registry.py index f3966d1f..7b692fe5 100644 --- a/core/tests/routes/plugins/test_plugins_registry.py +++ b/core/tests/routes/plugins/test_plugins_registry.py @@ -1,6 +1,6 @@ import os import shutil -from tests.utils import get_embedded_tools, create_mock_plugin_zip +from tests.utils import create_mock_plugin_zip # TODO: registry responses here should be mocked, at the moment we are actually calling the service diff --git a/core/tests/routes/rabbithole/test_upload_memories.py b/core/tests/routes/rabbithole/test_upload_memories.py index 209d558c..4c9b3b53 100644 --- a/core/tests/routes/rabbithole/test_upload_memories.py +++ b/core/tests/routes/rabbithole/test_upload_memories.py @@ -29,7 +29,7 @@ def test_upload_memory(client): # new declarative memory was saved collections_n_points = get_collections_names_and_point_count(client) assert collections_n_points["declarative"] == 1 # new declarative memory (just uploaded) - assert collections_n_points["procedural"] == 1 # default tool + assert collections_n_points["procedural"] == 3 # default tool (3 procedural points) assert collections_n_points["episodic"] == 0 diff --git a/core/tests/utils.py b/core/tests/utils.py index 967e3a6a..1b12d49b 100644 --- a/core/tests/utils.py +++ b/core/tests/utils.py @@ -58,7 +58,7 @@ def create_mock_plugin_zip(flat: bool): # utility to retrieve embedded tools from endpoint -def get_embedded_tools(client): +def get_procedural_memory_contents(client): params = { "text": "random" } diff --git a/readme/CHANGELOG.md b/readme/CHANGELOG.md index 37fbfc4a..3afe21a7 100644 --- a/readme/CHANGELOG.md +++ b/readme/CHANGELOG.md @@ -1,5 +1,31 @@ # Changelog +## 1.4.8 ( 2024-02-10 ) + +New in version 1.4.8 + +- fix HuggingFace endpoint integration by @valentimarco +- optimize plugins' dependencies checks by @kodaline and @pingdred +- adapter for OpenAI compatible endpoints by @AlessandroSpallina +- optimizations for temp files and logs by @pingdred +- Levenshtein distance utility by @horw +- customizable query param for recall functionality by @pazoff +- alternative syntax for `@hook` by @zAlweNy26 +- `@tool` examples by @zAlweNy26 +- ENV variables for Qdrant endpoint by @lucapirrone +- endpoints' final `/` standardization by @zAlweNy26 +- logs refactoring by @giovannialbero1992 +- chunk size and overlap in RabbitHole based on tokens by @nickprock +- CustomOllama LLM adapter by @valentimarco +- plugin upgradeability flag by @bositalia +- FastEmbed base model and model enum by @nickprock and @valentimarco +- bump langchain and openai versions by @Pingdred +- new `before_cat_stores_episodic_memory` hook by @lucapirrone +- fix cat plugins folder bug in test suite by @nickprock +- bump qdrant client version by @nickprock + +## (long time passed here without changelog updates) + ## 0.0.5 ( 2023-06-05 ) ### Enhancements diff --git a/readme/ROADMAP.md b/readme/ROADMAP.md index a7150c1d..d9f6e8bd 100644 --- a/readme/ROADMAP.md +++ b/readme/ROADMAP.md @@ -1,46 +1,46 @@ -* **Version 2** +* **Version 1.5** * Technical * Plugins - * redesign hooks & tools signature - * tools with more than one arg (structured Tool) - * no cat argument - * registry online + * redesign hooks & tools signature (OK) + * tools with more than one arg (OK, working on forms) + * no cat argument (OK, cat is a StrayCat) + * registry online (OK) * Agent * Custom hookable agent * Async agent * Output dictionary retry (guardrails, kor, guidance) - * 
(streaming?) + * Streaming (OK) * Unit tests - * Half coverage (main classes) + * Half coverage (OK) * Admin - * sync / async calls consistent management - * adapt to design system - * show registry plugins (core should send them alongside the installed ones) + * sync / async calls consistent management (OK) + * adapt to design system (OK) + * show registry plugins (OK) * filters for memory search * Deploy - * docker image! - * compose with local LLM + embedder - ready to use + * docker image! (OK) + * compose with local LLM + embedder - ready to use (OK) * (nginx?) * LLM improvements * explicit support for chat vs completion * each LLM has its own default template * User support (not management) - * fix bugs - * sessions + * fix bugs (OK) + * sessions (OK) * Outreach * Community - * 1 live event - * 4 meow talk - * 1 challenge + * 1 live event (OK) + * 4 meow talks (OK) + * 1 challenge (OK) * Dissemination * use case examples * tutorials on hooks * hook discovery tool - * website analytics + * website analytics (OK) * Branding - * logo - * website + docs + admin design system + * logo (OK) + * website + docs + admin design system (OK) ---
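As a companion to the roadmap items above ("redesign hooks & tools signature", "tools with more than one arg"), here is a minimal sketch of a plugin tool using the reworked `@tool` decorator with trigger examples. The tool itself (`get_the_date`) is hypothetical; the decorator arguments and the docstring-derived description follow `mock_tool.py` and `CatTool.__init__` in this diff:

```python
from datetime import date

from cat.mad_hatter.decorators import tool

# Hypothetical example tool. The docstring becomes the "description" trigger;
# each entry in `examples` becomes a "start_example" trigger in procedural memory.
@tool(return_direct=True, examples=["what day is it", "tell me today's date"])
def get_the_date(tool_input, cat):
    """Useful to get the current date when asked. Input is always None."""
    return str(date.today())
```

Once embedded, this single tool would account for three procedural points (one description plus two start examples), consistent with the per-tool counts asserted in the test suite.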