
Commit

Adding support for Mixtral and Mistral 0.2
Rob-Powell authored and bigadsoleiman committed Jan 20, 2024
1 parent dda6e4e commit d7d0b92
Showing 5 changed files with 164 additions and 1 deletion.
@@ -1 +1,2 @@
from .mistral_instruct import *
from .mixtral_instruct import *
@@ -0,0 +1,88 @@
import json
import os

from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
from langchain.prompts.prompt import PromptTemplate

from ...base import ModelAdapter
from ...registry import registry


class MixtralInstructContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt, model_kwargs) -> bytes:
        input_str = json.dumps(
            {
                "inputs": prompt,
                "parameters": {
                    "do_sample": True,
                    "max_new_tokens": model_kwargs.get("max_new_tokens", 32768),
                    "top_p": model_kwargs.get("top_p", 0.9),
                    "temperature": model_kwargs.get("temperature", 0.6),
                    "return_full_text": False,
                    "stop": ["###", "</s>"],
                },
            }
        )
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes):
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


content_handler = MixtralInstructContentHandler()


class SMMixtralInstructAdapter(ModelAdapter):
    def __init__(self, model_id, **kwargs):
        self.model_id = model_id

        super().__init__(**kwargs)

    def get_llm(self, model_kwargs={}):
        params = {}
        if "temperature" in model_kwargs:
            params["temperature"] = model_kwargs["temperature"]
        if "topP" in model_kwargs:
            params["top_p"] = model_kwargs["topP"]
        if "maxTokens" in model_kwargs:
            params["max_new_tokens"] = model_kwargs["maxTokens"]

        return SagemakerEndpoint(
            endpoint_name=self.model_id,
            region_name=os.environ["AWS_REGION"],
            content_handler=content_handler,
            model_kwargs=params,
            callbacks=[self.callback_handler],
        )

    def get_qa_prompt(self):
        template = """<s>[INST] Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.[/INST]
{context}
</s>[INST] {question} [/INST]"""

        return PromptTemplate.from_template(template)

    def get_prompt(self):
        template = """<s>[INST] The following is a friendly conversation between a human and an AI. If the AI does not know the answer to a question, it truthfully says it does not know.[/INST]
{chat_history}
<s>[INST] {input} [/INST]"""

        return PromptTemplate.from_template(template)

    def get_condense_question_prompt(self):
        template = """<s>[INST] Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.[/INST]
{chat_history}
</s>[INST] {question} [/INST]"""

        return PromptTemplate.from_template(template)


# Register the adapter
registry.register(r"(?i)sagemaker\.mistralai-Mixtral*", SMMixtralInstructAdapter)
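
The handler and registry entry above can be sanity-checked locally, without a deployed endpoint. The sketch below is not part of this commit; the fake TGI response and the endpoint id are made-up values used only to illustrate the request/response shapes and the regex-based routing, assuming the classes above are importable.

# Local sketch (illustration only): exercises MixtralInstructContentHandler and the
# registration regex without calling SageMaker.
import io
import json
import re

handler = MixtralInstructContentHandler()

# Request body the handler would send to the TGI container.
body = handler.transform_input("<s>[INST] Hello [/INST]", {"temperature": 0.5})
print(json.loads(body)["parameters"]["temperature"])  # 0.5

# transform_output reads a file-like body (SageMaker's streaming response),
# so BytesIO stands in for it here with a fake TGI-style payload.
fake_body = io.BytesIO(json.dumps([{"generated_text": "Hi there"}]).encode("utf-8"))
print(handler.transform_output(fake_body))  # Hi there

# A hypothetical endpoint id like this one would be routed to SMMixtralInstructAdapter.
pattern = r"(?i)sagemaker\.mistralai-Mixtral*"
print(bool(re.match(pattern, "sagemaker.mistralai-Mixtral-8x7B-Instruct-v0-1")))  # True
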
70 changes: 70 additions & 0 deletions lib/models/index.ts
@@ -102,6 +102,76 @@ export class Models extends Construct {
      });
    }

    if (
      props.config.llms?.sagemaker.includes(
        SupportedSageMakerModels.Mistral7b_Instruct2
      )
    ) {
      const mistral7bInstruct2 = new SageMakerModel(this, "Mistral7BInstruct2", {
        vpc: props.shared.vpc,
        region: cdk.Aws.REGION,
        model: {
          type: DeploymentType.Container,
          modelId: "mistralai/Mistral-7B-Instruct-v0.2",
          container: ContainerImages.HF_PYTORCH_LLM_TGI_INFERENCE_1_3_3,
          instanceType: "ml.g5.2xlarge",
          containerStartupHealthCheckTimeoutInSeconds: 300,
          env: {
            SM_NUM_GPUS: JSON.stringify(1),
            MAX_INPUT_LENGTH: JSON.stringify(2048),
            MAX_TOTAL_TOKENS: JSON.stringify(4096),
            MAX_CONCURRENT_REQUESTS: JSON.stringify(4),
          },
        },
      });

      models.push({
        name: mistral7bInstruct2.endpoint.endpointName!,
        endpoint: mistral7bInstruct2.endpoint,
        responseStreamingSupported: false,
        inputModalities: [Modality.Text],
        outputModalities: [Modality.Text],
        interface: ModelInterface.LangChain,
        ragSupported: true,
      });
    }

    if (
      props.config.llms?.sagemaker.includes(
        SupportedSageMakerModels.Mixtral_8x7b_Instruct
      )
    ) {
      const mixtral8x7binstruct = new SageMakerModel(this, "Mixtral8x7binstruct", {
        vpc: props.shared.vpc,
        region: cdk.Aws.REGION,
        model: {
          type: DeploymentType.Container,
          modelId: "mistralai/Mixtral-8x7B-Instruct-v0.1",
          container: ContainerImages.HF_PYTORCH_LLM_TGI_INFERENCE_1_3_3,
          instanceType: "ml.g5.48xlarge",
          containerStartupHealthCheckTimeoutInSeconds: 300,
          env: {
            SM_NUM_GPUS: JSON.stringify(8),
            MAX_INPUT_LENGTH: JSON.stringify(24576),
            MAX_TOTAL_TOKENS: JSON.stringify(32768),
            MAX_BATCH_PREFILL_TOKENS: JSON.stringify(24576),
            MAX_CONCURRENT_REQUESTS: JSON.stringify(4),
          },
        },
      });

      models.push({
        name: mixtral8x7binstruct.endpoint.endpointName!,
        endpoint: mixtral8x7binstruct.endpoint,
        responseStreamingSupported: false,
        inputModalities: [Modality.Text],
        outputModalities: [Modality.Text],
        interface: ModelInterface.LangChain,
        ragSupported: true,
      });
    }
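
For the Mixtral deployment, MAX_INPUT_LENGTH (24576) caps prompt tokens while MAX_TOTAL_TOKENS (32768) caps prompt plus completion, so each request keeps roughly 8K tokens of generation headroom. Once the endpoint is up it accepts the same JSON payload shape that MixtralInstructContentHandler builds; a minimal smoke test might look like the sketch below (not part of this commit; the endpoint name and region are placeholders).

# Smoke-test sketch (illustration only): the endpoint name and region below are
# placeholders for whatever the stack actually deploys.
import json

import boto3

runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")
payload = {
    "inputs": "<s>[INST] What is Mixtral 8x7B? [/INST]",
    "parameters": {"max_new_tokens": 256, "temperature": 0.6, "top_p": 0.9},
}
response = runtime.invoke_endpoint(
    EndpointName="Mixtral8x7binstruct-endpoint",  # placeholder
    ContentType="application/json",
    Body=json.dumps(payload),
)
print(json.loads(response["Body"].read())[0]["generated_text"])
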
// To get Jumpstart model ARNs do the following
// 1. Identify the modelId via https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html
// 2. Run the following code
2 changes: 2 additions & 0 deletions lib/sagemaker-model/container-images.ts
@@ -24,6 +24,8 @@ export class ContainerImages {
"huggingface-pytorch-tgi-inference:2.0.1-tgi1.0.3-gpu-py39-cu118-ubuntu20.04";
static readonly HF_PYTORCH_LLM_TGI_INFERENCE_1_1_0 =
"huggingface-pytorch-tgi-inference:2.0.1-tgi1.1.0-gpu-py39-cu118-ubuntu20.04";
static readonly HF_PYTORCH_LLM_TGI_INFERENCE_1_3_3 =
"huggingface-pytorch-tgi-inference:2.1.1-tgi1.3.3-gpu-py310-cu121-ubuntu20.04";
static readonly HF_PYTORCH_LLM_TGI_INFERENCE_LATEST =
ContainerImages.HF_PYTORCH_LLM_TGI_INFERENCE_1_1_0;
/*
4 changes: 3 additions & 1 deletion lib/shared/types.ts
@@ -5,7 +5,9 @@ export type ModelProvider = "sagemaker" | "bedrock" | "openai";
export enum SupportedSageMakerModels {
  FalconLite = "FalconLite",
  Llama2_13b_Chat = "Llama2_13b_Chat",
- Mistral7b_Instruct = "Mistral7b_Instruct",
+ Mistral7b_Instruct = "Mistral7b_Instruct 0.1",
+ Mistral7b_Instruct2 = "Mistral7b_Instruct 0.2",
+ Mixtral_8x7b_Instruct = "Mixtral-8x7B Instruct 0.1",
  Idefics_9b = "Idefics_9b (Multimodal)",
  Idefics_80b = "Idefics_80b (Multimodal)",
}
