Skip to content

Commit

Permalink
feat: Added Mistral-7B-Instruct-v0.3 support using Jumpstart (#553)
Browse files Browse the repository at this point in the history
* chore: Upgraded dependencies + fix code analytics warning

* test: Add sagemaker integ test.

* chore: Migrate file upload script to langchain 0.2

---------

Co-authored-by: Nikolai Grinko <grinko.nikolai@gmail.com>
  • Loading branch information
charles-marion and grinko authored Aug 27, 2024
1 parent 3075d2c commit 21de272
Show file tree
Hide file tree
Showing 18 changed files with 154 additions and 34 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ jobs:
pip install -r pytest_requirements.txt
flake8 .
bandit -r .
pip-audit -r pytest_requirements.txt || true
pip-audit -r lib/shared/web-crawler-batch-job/requirements.txt || true
pip-audit -r lib/shared/file-import-batch-job/requirements.txt || true
pip-audit -r pytest_requirements.txt
pip-audit -r lib/shared/web-crawler-batch-job/requirements.txt
pip-audit -r lib/shared/file-import-batch-job/requirements.txt
pytest tests/
- name: Frontend
working-directory: ./lib/user-interface/react-app
Expand Down
4 changes: 2 additions & 2 deletions NOTICE
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The following Python packages may be included in this product:
- cfnresponse==1.1.2
- opensearch-py==2.3.1
- openai==0.28.0
- requests==2.31.0
- requests==2.32.0
- huggingface-hub
- hf-transfer
- aws_xray_sdk==2.12.1
Expand Down Expand Up @@ -363,7 +363,7 @@ SOFTWARE.

The following Python packages may be included in this product:

- langchain==0.1.5
- langchain==0.2.14

These packages each contain the following license and notice below:

Expand Down
9 changes: 3 additions & 6 deletions integtests/chatbot-api/kendra_workspace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def test_add_file(client: AppSyncClient):

fields = result.get("fields")
cleaned_fields = fields.replace("{", "").replace("}", "")
pairs = [pair.strip() for pair in cleaned_fields.split(',')]
fields_dict = dict(pair.split('=', 1) for pair in pairs)
pairs = [pair.strip() for pair in cleaned_fields.split(",")]
fields_dict = dict(pair.split("=", 1) for pair in pairs)
files = {"file": b"The Integ Test flower is yellow."}
response = requests.post(result.get("url"), data=fields_dict, files=files)
assert response.status_code == 204
Expand All @@ -78,10 +78,7 @@ def test_add_file(client: AppSyncClient):
assert syncInProgress == False

documents = client.list_documents(
input={
"workspaceId": pytest.workspace.get("id"),
"documentType": "file"
}
input={"workspaceId": pytest.workspace.get("id"), "documentType": "file"}
)
pytest.document = documents.get("items")[0]
assert pytest.document.get("status") == "processed"
Expand Down
68 changes: 68 additions & 0 deletions integtests/chatbot-api/sagemaker_session_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# This test only runs if the Mistral SageMaker endpoint was created.
# It aims to validate the SageMaker flow.
import json
import time
import uuid

import pytest


def _wait_for_answer(client, session_id, expected_length, expected_substring):
    """Poll the chat session until the history reaches *expected_length*
    and the latest message contains *expected_substring* (case-insensitive).

    Polls once per second for up to 20 seconds. Returns True on success,
    False if the condition was never met.
    """
    for _ in range(20):
        time.sleep(1)
        session = client.get_session(session_id)
        if (
            session is not None
            and len(session.get("history")) == expected_length
            and expected_substring
            in session.get("history")[expected_length - 1].get("content").lower()
        ):
            return True
    return False


def test_jumpstart_sagemaker_endpoint(client):
    """End-to-end check of the SageMaker JumpStart Mistral endpoint:
    send a message, verify the response, then verify that a follow-up
    question uses the conversation history.
    """
    model_name = "mistralai/Mistral-7B-Instruct-v0.3"
    models = client.list_models()
    # next() with a default avoids StopIteration when the model is absent,
    # so the skip branch below is actually reachable.
    model = next((i for i in models if i.get("name") == model_name), None)
    if model is None:
        pytest.skip("Mistral v0.3 is not enabled.")
    session_id = str(uuid.uuid4())
    request = {
        "action": "run",
        "modelInterface": "langchain",
        "data": {
            "mode": "chain",
            "text": "Hello, my name is Tom.",
            "files": [],
            "modelName": model_name,
            "provider": "sagemaker",
            "sessionId": session_id,
        },
        "modelKwargs": {"maxTokens": 150},
    }

    client.send_query(json.dumps(request))
    assert _wait_for_answer(client, session_id, 2, "tom")

    # The goal here is to test the conversation history: the model should
    # recall the name given in the first message.
    request["data"]["text"] = "What is my name?"
    client.send_query(json.dumps(request))
    assert _wait_for_answer(client, session_id, 4, "tom")
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
from enum import Enum
from aws_lambda_powertools import Logger
from langchain.callbacks.base import BaseCallbackHandler
Expand Down Expand Up @@ -56,6 +57,13 @@ def __bind_callbacks(self):
if method in valid_callback_names:
setattr(self.callback_handler, method, getattr(self, method))

def get_endpoint(self, model_id):
    """Resolve the SageMaker endpoint name for *model_id*.

    The CDK stack exposes endpoints via environment variables named
    ``SAGEMAKER_ENDPOINT_<MODEL>`` where ``<MODEL>`` is the model id
    uppercased with whitespace, ``.``, ``/``, ``-`` and ``_`` removed
    (mirrors the cleanName computation in lib/model-interfaces/langchain/index.ts).
    Falls back to the raw model id when no such variable is set (or is empty).
    """
    clean_name = "SAGEMAKER_ENDPOINT_" + re.sub(r"[\s.\/\-_]", "", model_id).upper()
    # Single lookup instead of calling os.getenv twice; an empty value is
    # treated as unset, matching the original truthiness check.
    endpoint = os.getenv(clean_name)
    return endpoint if endpoint else model_id

def get_llm(self, model_kwargs={}):
raise ValueError("llm must be implemented")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_llm(self, model_kwargs={}):
params["max_new_tokens"] = model_kwargs["maxTokens"]

return SagemakerEndpoint(
endpoint_name=self.model_id,
endpoint_name=self.get_endpoint(self.model_id),
region_name=os.environ["AWS_REGION"],
content_handler=content_handler,
model_kwargs=params,
Expand Down Expand Up @@ -89,3 +89,4 @@ def get_condense_question_prompt(self):

# Register the adapter
registry.register(r"(?i)sagemaker\.mistralai-Mistral*", SMMistralInstructAdapter)
registry.register(r"(?i)sagemaker\.mistralai/Mistral*", SMMistralInstructAdapter)
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ def handle_failed_records(records):
"timestamp": str(int(round(datetime.now().timestamp()))),
"data": {
"sessionId": session_id,
"content": str(error),
# Log a vague message because the error can contain
# internal information
"content": "Something went wrong",
"type": "text",
},
}
Expand All @@ -166,7 +168,12 @@ def handler(event, context: LambdaContext):
except BatchProcessingError as e:
logger.error(e)

logger.info(processed_messages)
for message in processed_messages:
logger.info(
"Request compelte with status " + message[0],
status=message[0],
cause=message[1],
)
handle_failed_records(
message for message in processed_messages if message[0] == "fail"
)
Expand Down
2 changes: 1 addition & 1 deletion lib/model-interfaces/langchain/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ export class LangChainInterface extends Construct {
resources: [endpoint.ref],
})
);
const cleanName = name.replace(/[\s.\-_]/g, "").toUpperCase();
const cleanName = name.replace(/[\s./\-_]/g, "").toUpperCase();
this.requestHandler.addEnvironment(
`SAGEMAKER_ENDPOINT_${cleanName}`,
endpoint.attrEndpointName
Expand Down
36 changes: 36 additions & 0 deletions lib/models/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,42 @@ export class Models extends Construct {
});
}

if (
props.config.llms?.sagemaker.includes(
SupportedSageMakerModels.Mistral7b_Instruct3
)
) {
const MISTRACL_7B_3_ENDPOINT_NAME = "mistralai/Mistral-7B-Instruct-v0.3";

const mistral7BInstruct3 = new JumpStartSageMakerEndpoint(
this,
"Mistral7b_Instruct3",
{
model: JumpStartModel.HUGGINGFACE_LLM_MISTRAL_7B_INSTRUCT_3_0_0,
instanceType: SageMakerInstanceType.ML_G5_2XLARGE,
vpcConfig: {
securityGroupIds: [props.shared.vpc.vpcDefaultSecurityGroup],
subnets: props.shared.vpc.privateSubnets.map(
(subnet) => subnet.subnetId
),
},
endpointName: "Mistral-7B-Instruct-v0-3",
}
);

this.suppressCdkNagWarningForEndpointRole(mistral7BInstruct3.role);

models.push({
name: MISTRACL_7B_3_ENDPOINT_NAME,
endpoint: mistral7BInstruct3.cfnEndpoint,
responseStreamingSupported: false,
inputModalities: [Modality.Text],
outputModalities: [Modality.Text],
interface: ModelInterface.LangChain,
ragSupported: true,
});
}

if (
props.config.llms?.sagemaker.includes(
SupportedSageMakerModels.Llama2_13b_Chat
Expand Down
2 changes: 1 addition & 1 deletion lib/shared/file-import-batch-job/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import genai_core.documents
import genai_core.workspaces
import genai_core.aurora.create
from langchain.document_loaders import S3FileLoader
from langchain_community.document_loaders import S3FileLoader

WORKSPACE_ID = os.environ.get("WORKSPACE_ID")
DOCUMENT_ID = os.environ.get("DOCUMENT_ID")
Expand Down
3 changes: 2 additions & 1 deletion lib/shared/file-import-batch-job/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ numpy==1.26.0
cfnresponse==1.1.2
aws_requests_auth==0.4.3
requests-aws4auth==1.2.3
langchain==0.1.11
langchain==0.2.14
langchain-community==0.2.12
opensearch-py==2.3.1
psycopg2-binary==2.9.7
pgvector==0.2.2
Expand Down
4 changes: 2 additions & 2 deletions lib/shared/layers/common/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ numpy==1.26.0
cfnresponse==1.1.2
aws_requests_auth==0.4.3
requests-aws4auth==1.2.3
langchain==0.2.3
langchain-community==0.2.4
langchain==0.2.14
langchain-community==0.2.12
langchain-aws==0.1.6
opensearch-py==2.4.2
psycopg2-binary==2.9.7
Expand Down
5 changes: 3 additions & 2 deletions lib/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ export type ModelProvider = "sagemaker" | "bedrock" | "openai";

export enum SupportedSageMakerModels {
FalconLite = "FalconLite [ml.g5.12xlarge]",
Idefics_9b = "Idefics_9b (Multimodal) [ml.g5.12xlarge]",
Idefics_80b = "Idefics_80b (Multimodal) [ml.g5.48xlarge]",
Llama2_13b_Chat = "Llama2_13b_Chat [ml.g5.12xlarge]",
Mistral7b_Instruct = "Mistral7b_Instruct 0.1 [ml.g5.2xlarge]",
Mistral7b_Instruct2 = "Mistral7b_Instruct 0.2 [ml.g5.2xlarge]",
Mistral7b_Instruct3 = "Mistral7b_Instruct 0.3 [ml.g5.2xlarge]",
Mixtral_8x7b_Instruct = "Mixtral_8x7B_Instruct 0.1 [ml.g5.48xlarge]",
Idefics_9b = "Idefics_9b (Multimodal) [ml.g5.12xlarge]",
Idefics_80b = "Idefics_80b (Multimodal) [ml.g5.48xlarge]",
}

export enum SupportedRegion {
Expand Down
4 changes: 2 additions & 2 deletions lib/shared/web-crawler-batch-job/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ numpy==1.26.0
cfnresponse==1.1.2
aws_requests_auth==0.4.3
requests-aws4auth==1.2.3
langchain==0.1.11
langchain==0.2.14
opensearch-py==2.3.1
psycopg2-binary==2.9.7
pgvector==0.2.2
pydantic==2.4.0
urllib3<2
openai==0.28.0
beautifulsoup4==4.12.2
requests==2.31.0
requests==2.32.0
attrs==23.1.0
feedparser==6.0.11
aws_xray_sdk==2.12.1
Expand Down
8 changes: 4 additions & 4 deletions lib/user-interface/react-app/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pytest_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ black==24.8.0
flake8==7.1.0
selenium==4.16
pdfplumber==0.11.0
pyopenssl==23.3.0
pyopenssl==24.2.1
cryptography==42.0.4
-r lib/shared/layers/common/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ def test_parse_url(mocker):
["text/html"],
)
assert "Release v.4.0.7 " in reponse[0]
assert "https://github.com/" in reponse[1]
assert "https://docs.github.com/" in reponse[2]
assert len(reponse[1]) > 0 # Found urls from the same domain
assert len(reponse[2]) > 0 # Found urls from a differnt domain

0 comments on commit 21de272

Please sign in to comment.