Commit d7d0b92 (1 parent: dda6e4e): adding support for mixtral and mistral 0.2

Showing 5 changed files with 164 additions and 1 deletion.
1 change: 1 addition & 0 deletions
...l-interfaces/langchain/functions/request-handler/adapters/sagemaker/mistralai/__init__.py

@@ -1 +1,2 @@
 from .mistral_instruct import *
+from .mixtral_instruct import *
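The star import is doing real work here: the new adapter module calls registry.register(...) at module scope, so the Mixtral adapter only becomes known to the request handler once this __init__.py imports it. A minimal sketch of that registration-on-import pattern, assuming a simple list-backed registry (the resolve() lookup shown is hypothetical; this commit only shows the register() side):

    import re

    class Registry:
        def __init__(self):
            self._adapters = []

        def register(self, pattern, adapter_cls):
            # Runs at import time, once per adapter module.
            self._adapters.append((re.compile(pattern), adapter_cls))

        def resolve(self, model_id):  # hypothetical name, not shown in the commit
            # First registered pattern that matches the model id wins.
            for compiled, adapter_cls in self._adapters:
                if compiled.match(model_id):
                    return adapter_cls
            raise ValueError(f"no adapter registered for {model_id!r}")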
88 changes: 88 additions & 0 deletions
...aces/langchain/functions/request-handler/adapters/sagemaker/mistralai/mixtral_instruct.py

@@ -0,0 +1,88 @@
import json
import os

from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
from langchain.prompts.prompt import PromptTemplate

from ...base import ModelAdapter
from ...registry import registry


class MixtralInstructContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt, model_kwargs) -> bytes:
        input_str = json.dumps(
            {
                "inputs": prompt,
                "parameters": {
                    "do_sample": True,
                    "max_new_tokens": model_kwargs.get("max_new_tokens", 32768),
                    "top_p": model_kwargs.get("top_p", 0.9),
                    "temperature": model_kwargs.get("temperature", 0.6),
                    "return_full_text": False,
                    "stop": ["###", "</s>"],
                },
            }
        )
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes):
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


content_handler = MixtralInstructContentHandler()


class SMMixtralInstructAdapter(ModelAdapter):
    def __init__(self, model_id, **kwargs):
        self.model_id = model_id

        super().__init__(**kwargs)

    def get_llm(self, model_kwargs={}):
        params = {}
        if "temperature" in model_kwargs:
            params["temperature"] = model_kwargs["temperature"]
        if "topP" in model_kwargs:
            params["top_p"] = model_kwargs["topP"]
        if "maxTokens" in model_kwargs:
            params["max_new_tokens"] = model_kwargs["maxTokens"]

        return SagemakerEndpoint(
            endpoint_name=self.model_id,
            region_name=os.environ["AWS_REGION"],
            content_handler=content_handler,
            model_kwargs=params,
            callbacks=[self.callback_handler],
        )

    def get_qa_prompt(self):
        template = """<s>[INST] Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.[/INST]
{context}
</s>[INST] {question} [/INST]"""

        return PromptTemplate.from_template(template)

    def get_prompt(self):
        template = """<s>[INST] The following is a friendly conversation between a human and an AI. If the AI does not know the answer to a question, it truthfully says it does not know.[/INST]
{chat_history}
<s>[INST] {input} [/INST]"""

        return PromptTemplate.from_template(template)

    def get_condense_question_prompt(self):
        template = """<s>[INST] Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.[/INST]
{chat_history}
</s>[INST] {question} [/INST]"""

        return PromptTemplate.from_template(template)


# Register the adapter
registry.register(r"(?i)sagemaker\.mistralai-Mixtral*", SMMixtralInstructAdapter)
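For reference, a minimal sketch of the wire format this handler produces and consumes, and of what the registration pattern matches. It uses only the standard library; the model id below is made up for illustration:

    import json
    import re

    # Request body built by transform_input() (defaults from the handler above).
    body = {
        "inputs": "<s>[INST] Hello [/INST]",
        "parameters": {
            "do_sample": True,
            "max_new_tokens": 32768,
            "top_p": 0.9,
            "temperature": 0.6,
            "return_full_text": False,
            "stop": ["###", "</s>"],
        },
    }
    payload = json.dumps(body).encode("utf-8")

    # transform_output() reads the response body and expects a JSON list:
    response = b'[{"generated_text": "Hi there!"}]'
    print(json.loads(response.decode("utf-8"))[0]["generated_text"])  # Hi there!

    # The registration regex is case-insensitive and matched from the start
    # of the model id, so an id like this (hypothetical) routes to the adapter:
    print(bool(re.match(r"(?i)sagemaker\.mistralai-Mixtral*",
                        "sagemaker.mistralai-mixtral-8x7b-instruct-v0-1")))  # True

Note that the trailing * in the pattern binds to the final "l", not to ".*", so it effectively matches any model id beginning with "sagemaker.mistralai-Mixtra" (case-insensitive).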