Skip to content

Commit

Permalink
feat: Disable Sagemaker endpoint (or cross-encoder per workspace) (#588)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Ajay Lamba <azaystudy@gmail.com>
Co-authored-by: Bigad Soleiman <bigadsoleiman@gmail.com>
  • Loading branch information
3 people authored Oct 21, 2024
1 parent b6a5d5a commit a1d2aa9
Show file tree
Hide file tree
Showing 35 changed files with 468 additions and 289 deletions.
8 changes: 7 additions & 1 deletion bin/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ import { existsSync, readFileSync } from "fs";

export function getConfig(): SystemConfig {
if (existsSync("./bin/config.json")) {
return JSON.parse(readFileSync("./bin/config.json").toString("utf8"));
return JSON.parse(
readFileSync("./bin/config.json").toString("utf8")
) as SystemConfig;
}
// Default config
return {
Expand Down Expand Up @@ -48,11 +50,13 @@ export function getConfig(): SystemConfig {
provider: "sagemaker",
name: "intfloat/multilingual-e5-large",
dimensions: 1024,
default: false,
},
{
provider: "sagemaker",
name: "sentence-transformers/all-MiniLM-L6-v2",
dimensions: 384,
default: false,
},
{
provider: "bedrock",
Expand Down Expand Up @@ -80,8 +84,10 @@ export function getConfig(): SystemConfig {
provider: "openai",
name: "text-embedding-ada-002",
dimensions: 1536,
default: false,
},
],
crossEncodingEnabled: false,
crossEncoderModels: [
{
provider: "sagemaker",
Expand Down
73 changes: 55 additions & 18 deletions cli/magic-config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
SupportedSageMakerModels,
SystemConfig,
SupportedBedrockRegion,
ModelConfig,
} from "../lib/shared/types";
import { LIB_VERSION } from "./version.js";
import * as fs from "fs";
Expand All @@ -34,7 +35,6 @@ function getTimeZonesWithCurrentTime(): { message: string; name: string }[] {
function getCountryCodesAndNames(): { message: string; name: string }[] {
// Use country-list to get an array of countries with their codes and names
const countries = getData();

// Map the country data to match the desired output structure
const countryInfo = countries.map(({ code, name }) => {
return { message: `${name} (${code})`, name: code };
Expand Down Expand Up @@ -88,21 +88,24 @@ const secretManagerArnRegExp = RegExp(
/arn:aws:secretsmanager:[\w-_]+:\d+:secret:[\w-_]+/
);

const embeddingModels = [
const embeddingModels: ModelConfig[] = [
{
provider: "sagemaker",
name: "intfloat/multilingual-e5-large",
dimensions: 1024,
default: false,
},
{
provider: "sagemaker",
name: "sentence-transformers/all-MiniLM-L6-v2",
dimensions: 384,
default: false,
},
{
provider: "bedrock",
name: "amazon.titan-embed-text-v1",
dimensions: 1536,
default: false,
},
//Support for inputImage is not yet implemented for amazon.titan-embed-image-v1
{
Expand All @@ -124,6 +127,7 @@ const embeddingModels = [
provider: "openai",
name: "text-embedding-ada-002",
dimensions: 1536,
default: false,
},
];

Expand Down Expand Up @@ -179,6 +183,8 @@ const embeddingModels = [
options.startScheduleEndDate =
config.llms?.sagemakerSchedule?.startScheduleEndDate;
options.enableRag = config.rag.enabled;
options.deployDefaultSagemakerModels =
config.rag.deployDefaultSagemakerModels;
options.ragsToEnable = Object.keys(config.rag.engines ?? {}).filter(
(v: string) =>
(
Expand Down Expand Up @@ -608,6 +614,16 @@ async function processCreateOptions(options: any): Promise<void> {
message: "Do you want to enable RAG",
initial: options.enableRag || false,
},
{
type: "confirm",
name: "deployDefaultSagemakerModels",
message:
"Do you want to deploy the default embedding and cross-encoder models via SageMaker?",
initial: options.deployDefaultSagemakerModels || false,
skip(): boolean {
return !(this as any).state.answers.enableRag;
},
},
{
type: "multiselect",
name: "ragsToEnable",
Expand Down Expand Up @@ -810,10 +826,17 @@ async function processCreateOptions(options: any): Promise<void> {
choices: embeddingModels.map((m) => ({ name: m.name, value: m })),
initial: options.defaultEmbedding,
validate(value: string) {
const embeding = embeddingModels.find((i) => i.name === value);
if (
embeding &&
(this as any).state.answers.deployDefaultSagemakerModels === false &&
embeding?.provider === "sagemaker"
) {
return "SageMaker default models are not enabled. Please select another model.";
}
if ((this as any).state.answers.enableRag) {
return value ? true : "Select a default embedding model";
}

return true;
},
skip() {
Expand Down Expand Up @@ -1219,6 +1242,7 @@ async function processCreateOptions(options: any): Promise<void> {
}
: undefined,
llms: {
enableSagemakerModels: answers.enableSagemakerModels,
rateLimitPerAIP: advancedSettings?.llmRateLimitPerIP
? Number(advancedSettings?.llmRateLimitPerIP)
: undefined,
Expand All @@ -1241,6 +1265,7 @@ async function processCreateOptions(options: any): Promise<void> {
},
rag: {
enabled: answers.enableRag,
deployDefaultSagemakerModels: answers.deployDefaultSagemakerModels,
engines: {
aurora: {
enabled: answers.ragsToEnable.includes("aurora"),
Expand All @@ -1259,28 +1284,40 @@ async function processCreateOptions(options: any): Promise<void> {
external: [{}],
},
},
embeddingsModels: [{}],
crossEncoderModels: [{}],
embeddingsModels: [] as ModelConfig[],
crossEncoderModels: [] as ModelConfig[],
},
};

if (config.rag.enabled && config.rag.deployDefaultSagemakerModels) {
config.rag.crossEncoderModels[0] = {
provider: "sagemaker",
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
};
config.rag.embeddingsModels = embeddingModels;
} else if (config.rag.enabled) {
config.rag.embeddingsModels = embeddingModels.filter(
(model) => model.provider !== "sagemaker"
);
for (const model of config.rag.embeddingsModels) {
model.default = model.name === models.defaultEmbedding;
}
} else {
config.rag.embeddingsModels = [];
}

// If we have not enabled rag the default embedding is set to the first model
if (!answers.enableRag) {
models.defaultEmbedding = embeddingModels[0].name;
(config.rag.embeddingsModels[0] as any).default = true;
} else {
config.rag.embeddingsModels.forEach((m: any) => {
if (m.name === models.defaultEmbedding) {
m.default = true;
}
});
}

config.rag.crossEncoderModels[0] = {
provider: "sagemaker",
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
};
config.rag.embeddingsModels = embeddingModels;
config.rag.embeddingsModels.forEach((m: any) => {
if (m.name === models.defaultEmbedding) {
m.default = true;
}
});

config.rag.engines.kendra.createIndex =
answers.ragsToEnable.includes("kendra");
config.rag.engines.kendra.enabled =
Expand Down
75 changes: 56 additions & 19 deletions integtests/chatbot-api/aurora_workspace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def run_before_and_after_tests(client: AppSyncClient):
for workspace in client.list_workspaces():
if (
workspace.get("name") == "INTEG_TEST_AURORA"
and workspace.get("status") == "ready"
):
or workspace.get("name") == "INTEG_TEST_AURORA_WITHOUT_RERANK"
) and workspace.get("status") == "ready":
client.delete_workspace(workspace.get("id"))


Expand All @@ -22,23 +22,25 @@ def test_create(client: AppSyncClient, default_embed_model):
if engine.get("enabled") == False:
pytest.skip_flag = True
pytest.skip("Aurora is not enabled.")
pytest.workspace = client.create_aurora_workspace(
input={
"kind": "auro2",
"name": "INTEG_TEST_AURORA",
"embeddingsModelProvider": "bedrock",
"embeddingsModelName": default_embed_model,
"crossEncoderModelName": "cross-encoder/ms-marco-MiniLM-L-12-v2",
"crossEncoderModelProvider": "sagemaker",
"languages": ["english"],
"index": True,
"hybridSearch": True,
"metric": "inner",
"chunkingStrategy": "recursive",
"chunkSize": 1000,
"chunkOverlap": 200,
}
)
input = {
"kind": "auro2",
"name": "INTEG_TEST_AURORA_WITHOUT_RERANK",
"embeddingsModelProvider": "bedrock",
"embeddingsModelName": default_embed_model,
"languages": ["english"],
"index": True,
"hybridSearch": True,
"metric": "inner",
"chunkingStrategy": "recursive",
"chunkSize": 1000,
"chunkOverlap": 200,
}
input_with_rerank = input.copy()
input_with_rerank["name"] = "INTEG_TEST_AURORA"
input_with_rerank["crossEncoderModelName"] = "cross-encoder/ms-marco-MiniLM-L-12-v2"
input_with_rerank["crossEncoderModelProvider"] = "sagemaker"
pytest.workspace = client.create_aurora_workspace(input=input_with_rerank)
pytest.workspace_no_re_rank = client.create_aurora_workspace(input=input)

ready = False
retries = 0
Expand All @@ -56,6 +58,7 @@ def test_create(client: AppSyncClient, default_embed_model):
def test_add_rss(client: AppSyncClient):
if pytest.skip_flag == True:
pytest.skip("Aurora is not enabled.")

pytest.document = client.add_rss_feed(
input={
"workspaceId": pytest.workspace.get("id"),
Expand All @@ -67,6 +70,17 @@ def test_add_rss(client: AppSyncClient):
"limit": 2,
}
)
client.add_rss_feed(
input={
"workspaceId": pytest.workspace_no_re_rank.get("id"),
"title": "INTEG_TEST_AURORA_TITLE",
"address": "https://github.com/aws-samples/aws-genai-llm-chatbot/"
+ "releases.atom",
"contentTypes": ["text/html"],
"followLinks": True,
"limit": 2,
}
)

ready = False
retries = 0
Expand Down Expand Up @@ -137,6 +151,29 @@ def test_search_document(client: AppSyncClient):
assert ready == True


def test_search_document_no_reank(client: AppSyncClient):
    """Semantic search against the workspace created WITHOUT a cross-encoder.

    Verifies that results still come back from the Aurora engine and that,
    absent a re-ranking model, items carry no re-rank score (ordering is
    left to Aurora's own similarity ranking).

    Depends on module state set by test_create (pytest.workspace_no_re_rank)
    and test_add_rss (the crawled RSS documents).
    """
    if pytest.skip_flag:
        pytest.skip("Aurora is not enabled.")
    ready = False
    retries = 0
    # Wait for the page to be crawled. This starts on a cron every 5 min,
    # so poll for up to ~12.5 minutes (50 * 15s).
    while not ready and retries < 50:
        time.sleep(15)
        retries += 1
        result = client.semantic_search(
            input={
                "workspaceId": pytest.workspace_no_re_rank.get("id"),
                "query": "Release github",
            }
        )
        if len(result.get("items")) > 1:
            ready = True
            assert result.get("engine") == "aurora"
            # Re-ranking score is not set, but the results are still
            # ordered by Aurora.
            assert result.get("items")[0].get("score") is None
    assert ready


def test_query_llm(client, default_model, default_provider):
if pytest.skip_flag == True:
pytest.skip("Aurora is not enabled.")
Expand Down
Loading

0 comments on commit a1d2aa9

Please sign in to comment.