Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature request: Option to disable cross encoder models #286

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
da037d1
Currently cross encoder models are used to rank the search results bu…
azaylamba Dec 23, 2023
1c3b8ce
Enhancement: Add user feedback for responses
azaylamba Dec 24, 2023
c8dc554
Revert "Enhancement: Add user feedback for responses"
azaylamba Dec 24, 2023
550d2d0
Merge branch 'main' into main
azaylamba Jan 17, 2024
8dd11d8
Merge branch 'aws-samples:main' into main
azaylamba Jan 25, 2024
42c6edd
Merge branch 'main' of https://github.com/azaylamba/aws-genai-llm-cha…
azaylamba Feb 4, 2024
efb1a99
Addressed review comments related to cross encoding.
azaylamba Feb 4, 2024
b58737d
Removed prompt for selecting embedding models as it is not required now.
azaylamba Feb 4, 2024
cb8793d
Resolving merge conflicts
azaylamba Feb 9, 2024
cf0dfc1
Resolving merge conflicts
azaylamba Feb 9, 2024
13ce71e
Derived value of crossEncodingEnabled based on enableEmbeddingModelsV…
azaylamba Feb 9, 2024
2522839
Reverted unwanted change
azaylamba Feb 9, 2024
4669419
Merge branch 'main' into main
bigadsoleiman Feb 13, 2024
1667e9c
Merge branch 'main' into main
azaylamba Feb 24, 2024
1102491
Default embeddings model prompt was not set
azaylamba Feb 24, 2024
2047641
Merge branch 'main' into main
bigadsoleiman Mar 8, 2024
a09713e
Merge branch 'main' into main
azaylamba Apr 13, 2024
dca47d0
Corrected the NagSuppression conditions
azaylamba Apr 20, 2024
c2eabf4
Merge branch 'main' into main
azaylamba Jul 13, 2024
6a7c92b
Addressed review comments
azaylamba Jul 13, 2024
494f3b1
Added default value for cross encoder models
azaylamba Jul 15, 2024
efa9fa8
Merge branch 'main' into main
azaylamba Jul 18, 2024
61b73d2
Used enableSagemakerModels config for SM models
azaylamba Jul 18, 2024
feb5752
Merge branch 'main' of https://github.com/azaylamba/aws-genai-llm-cha…
azaylamba Jul 18, 2024
6850a9a
Merge branch 'main' into main
azaylamba Aug 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions bin/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ export function getConfig(): SystemConfig {
},
llms: {
// sagemaker: [SupportedSageMakerModels.FalconLite]
enableSagemakerModels: false,
sagemaker: [],
},
rag: {
enabled: false,
enableEmbeddingModelsViaSagemaker: false,
engines: {
aurora: {
enabled: false,
Expand All @@ -45,11 +47,13 @@ export function getConfig(): SystemConfig {
provider: "sagemaker",
name: "intfloat/multilingual-e5-large",
dimensions: 1024,
default: false,
},
{
provider: "sagemaker",
name: "sentence-transformers/all-MiniLM-L6-v2",
dimensions: 384,
default: false,
},
{
provider: "bedrock",
Expand Down Expand Up @@ -77,8 +81,10 @@ export function getConfig(): SystemConfig {
provider: "openai",
name: "text-embedding-ada-002",
dimensions: 1536,
default: false,
},
],
crossEncodingEnabled: false,
crossEncoderModels: [
{
provider: "sagemaker",
Expand Down
51 changes: 38 additions & 13 deletions cli/magic-config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ function getTimeZonesWithCurrentTime(): { message: string; name: string }[] {
function getCountryCodesAndNames(): { message: string; name: string }[] {
// Use country-list to get an array of countries with their codes and names
const countries = getData();

// Map the country data to match the desired output structure
const countryInfo = countries.map(({ code, name }) => {
return { message: `${name} (${code})`, name: code };
Expand Down Expand Up @@ -93,16 +92,19 @@ const embeddingModels = [
provider: "sagemaker",
name: "intfloat/multilingual-e5-large",
dimensions: 1024,
default: false,
},
{
provider: "sagemaker",
name: "sentence-transformers/all-MiniLM-L6-v2",
dimensions: 384,
default: false,
},
{
provider: "bedrock",
name: "amazon.titan-embed-text-v1",
dimensions: 1536,
default: false,
},
//Support for inputImage is not yet implemented for amazon.titan-embed-image-v1
{
Expand All @@ -124,6 +126,7 @@ const embeddingModels = [
provider: "openai",
name: "text-embedding-ada-002",
dimensions: 1536,
default: false,
},
];

Expand Down Expand Up @@ -175,6 +178,8 @@ const embeddingModels = [
options.startScheduleEndDate =
config.llms?.sagemakerSchedule?.startScheduleEndDate;
options.enableRag = config.rag.enabled;
options.enableEmbeddingModelsViaSagemaker =
config.rag.enableEmbeddingModelsViaSagemaker;
options.ragsToEnable = Object.keys(config.rag.engines ?? {}).filter(
(v: string) =>
(
Expand Down Expand Up @@ -577,6 +582,16 @@ async function processCreateOptions(options: any): Promise<void> {
message: "Do you want to enable RAG",
initial: options.enableRag || false,
},
{
type: "confirm",
name: "enableEmbeddingModelsViaSagemaker",
message:
"Do you want to enable embedding and cross-encoder models via SageMaker?",
initial: options.enableEmbeddingModelsViaSagemaker || false,
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
skip(): boolean {
return !(this as any).state.answers.enableRag;
},
},
{
type: "multiselect",
name: "ragsToEnable",
Expand Down Expand Up @@ -705,7 +720,6 @@ async function processCreateOptions(options: any): Promise<void> {
if ((this as any).state.answers.enableRag) {
return value ? true : "Select a default embedding model";
}

return true;
},
skip() {
Expand Down Expand Up @@ -1046,6 +1060,7 @@ async function processCreateOptions(options: any): Promise<void> {
}
: undefined,
llms: {
enableSagemakerModels: answers.enableSagemakerModels,
sagemaker: answers.sagemakerModels,
huggingfaceApiSecretArn: answers.huggingfaceApiSecretArn,
sagemakerSchedule: answers.enableSagemakerModelsSchedule
Expand All @@ -1065,6 +1080,8 @@ async function processCreateOptions(options: any): Promise<void> {
},
rag: {
enabled: answers.enableRag,
enableEmbeddingModelsViaSagemaker:
answers.enableEmbeddingModelsViaSagemaker,
engines: {
aurora: {
enabled: answers.ragsToEnable.includes("aurora"),
Expand All @@ -1079,27 +1096,35 @@ async function processCreateOptions(options: any): Promise<void> {
enterprise: false,
},
},
crossEncodingEnabled: answers.enableEmbeddingModelsViaSagemaker,
embeddingsModels: [{}],
crossEncoderModels: [{}],
},
};

// If we have not enabled rag the default embedding is set to the first model
if (!answers.enableRag) {
models.defaultEmbedding = embeddingModels[0].name;
}

config.rag.crossEncoderModels[0] = {
provider: "sagemaker",
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
};
config.rag.embeddingsModels = embeddingModels;
config.rag.embeddingsModels.forEach((m: any) => {
if (m.name === models.defaultEmbedding) {
m.default = true;
}
});

if (!config.rag.enableEmbeddingModelsViaSagemaker) {
config.rag.embeddingsModels = embeddingModels.filter(
(model) => model.provider !== "sagemaker"
);
} else {
config.rag.embeddingsModels = embeddingModels;
}
// If we have not enabled rag the default embedding is set to the first model
if (!answers.enableRag) {
(config.rag.embeddingsModels[0] as any).default = true;
} else {
config.rag.embeddingsModels.forEach((m: any) => {
if (m.name === models.defaultEmbedding) {
m.default = true;
}
});
}

config.rag.engines.kendra.createIndex =
answers.ragsToEnable.includes("kendra");
Expand Down
116 changes: 56 additions & 60 deletions lib/aws-genai-llm-chatbot-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,9 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
identityPool: authentication.identityPool,
api: chatBotApi,
chatbotFilesBucket: chatBotApi.filesBucket,
crossEncodersEnabled:
typeof ragEngines?.sageMakerRagModels?.model !== "undefined",
crossEncodersEnabled: props.config.rag.crossEncodingEnabled,
sagemakerEmbeddingsEnabled:
typeof ragEngines?.sageMakerRagModels?.model !== "undefined",
props.config.rag.enableEmbeddingModelsViaSagemaker,
});

if (props.config.cognitoFederation?.enabled) {
Expand Down Expand Up @@ -331,10 +330,7 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
]
);

if (
props.config.rag.engines.aurora.enabled ||
props.config.rag.engines.opensearch.enabled
) {
if (props.config.llms.enableSagemakerModels) {
NagSuppressions.addResourceSuppressionsByPath(
this,
[
Expand Down Expand Up @@ -370,59 +366,59 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
},
]
);
if (props.config.rag.engines.aurora.enabled) {
NagSuppressions.addResourceSuppressionsByPath(
this,
`/${this.stackName}/RagEngines/AuroraPgVector/AuroraDatabase/Secret/Resource`,
[
{
id: "AwsSolutions-SMG4",
reason: "Secret created implicitly by CDK.",
},
]
);
NagSuppressions.addResourceSuppressionsByPath(
this,
[
`/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupFunction/ServiceRole/Resource`,
`/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupProvider/framework-onEvent/ServiceRole/Resource`,
`/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupProvider/framework-onEvent/ServiceRole/DefaultPolicy/Resource`,
`/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspaceFunction/ServiceRole/Resource`,
`/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspaceFunction/ServiceRole/DefaultPolicy/Resource`,
`/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspace/Role/DefaultPolicy/Resource`,
],
[
{
id: "AwsSolutions-IAM4",
reason: "IAM role implicitly created by CDK.",
},
{
id: "AwsSolutions-IAM5",
reason: "IAM role implicitly created by CDK.",
},
]
);
}
if (props.config.rag.engines.opensearch.enabled) {
NagSuppressions.addResourceSuppressionsByPath(
this,
[
`/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspaceFunction/ServiceRole/Resource`,
`/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspaceFunction/ServiceRole/DefaultPolicy/Resource`,
`/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspace/Role/DefaultPolicy/Resource`,
],
[
{
id: "AwsSolutions-IAM4",
reason: "IAM role implicitly created by CDK.",
},
{
id: "AwsSolutions-IAM5",
reason: "IAM role implicitly created by CDK.",
},
]
);
}
}
if (props.config.rag.engines.aurora.enabled) {
NagSuppressions.addResourceSuppressionsByPath(
this,
`/${this.stackName}/RagEngines/AuroraPgVector/AuroraDatabase/Secret/Resource`,
[
{
id: "AwsSolutions-SMG4",
reason: "Secret created implicitly by CDK.",
},
]
);
NagSuppressions.addResourceSuppressionsByPath(
this,
[
`/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupFunction/ServiceRole/Resource`,
`/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupProvider/framework-onEvent/ServiceRole/Resource`,
`/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupProvider/framework-onEvent/ServiceRole/DefaultPolicy/Resource`,
`/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspaceFunction/ServiceRole/Resource`,
`/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspaceFunction/ServiceRole/DefaultPolicy/Resource`,
`/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspace/Role/DefaultPolicy/Resource`,
],
[
{
id: "AwsSolutions-IAM4",
reason: "IAM role implicitly created by CDK.",
},
{
id: "AwsSolutions-IAM5",
reason: "IAM role implicitly created by CDK.",
},
]
);
}
if (props.config.rag.engines.opensearch.enabled) {
NagSuppressions.addResourceSuppressionsByPath(
this,
[
`/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspaceFunction/ServiceRole/Resource`,
`/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspaceFunction/ServiceRole/DefaultPolicy/Resource`,
`/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspace/Role/DefaultPolicy/Resource`,
],
[
{
id: "AwsSolutions-IAM4",
reason: "IAM role implicitly created by CDK.",
},
{
id: "AwsSolutions-IAM5",
reason: "IAM role implicitly created by CDK.",
},
]
);
}
if (props.config.rag.engines.kendra.enabled) {
NagSuppressions.addResourceSuppressionsByPath(
Expand Down
27 changes: 16 additions & 11 deletions lib/chatbot-api/functions/api-handler/routes/cross_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,19 @@ def models():
@tracer.capture_method
def cross_encoders(input: dict):
request = CrossEncodersRequest(**input)
selected_model = genai_core.cross_encoder.get_cross_encoder_model(
request.provider, request.model
)

if selected_model is None:
raise genai_core.types.CommonError("Model not found")

ret_value = genai_core.cross_encoder.rank_passages(
selected_model, request.reference, request.passages
)
return [{"score": v, "passage": p} for v, p in zip(ret_value, request.passages)]
config = genai_core.parameters.get_config()
crossEncodingEnabled = config["rag"]["crossEncodingEnabled"]
if (crossEncodingEnabled):
selected_model = genai_core.cross_encoder.get_cross_encoder_model(
request.provider, request.model
)

if selected_model is None:
raise genai_core.types.CommonError("Model not found")

ret_value = genai_core.cross_encoder.rank_passages(
selected_model, request.reference, request.passages
)
return [{"score": v, "passage": p} for v, p in zip(ret_value, request.passages)]

return [{"score": 0, "passage": p} for p in request.passages]
2 changes: 1 addition & 1 deletion lib/rag-engines/data-import/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ export class DataImport extends Construct {
processingBucket,
auroraDatabase: props.auroraDatabase,
ragDynamoDBTables: props.ragDynamoDBTables,
sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model.endpoint,
sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model?.endpoint,
openSearchVector: props.openSearchVector,
}
);
Expand Down
5 changes: 1 addition & 4 deletions lib/rag-engines/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@ export class RagEngines extends Construct {
const tables = new RagDynamoDBTables(this, "RagDynamoDBTables");

let sageMakerRagModels: SageMakerRagModels | null = null;
if (
props.config.rag.engines.aurora.enabled ||
props.config.rag.engines.opensearch.enabled
) {
if (props.config.llms.enableSagemakerModels) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be checking crossEncodingEnabled and not enableSageMakerModels

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, but won't that be confusing that crossEncodingEnabled is driving the Sagemaker models instead of the config props.config.llms.enableSagemakerModels which is specific for sagemaker models?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right and props.config.llms.enableSagemakerModels is better

sageMakerRagModels = new SageMakerRagModels(this, "SageMaker", {
shared: props.shared,
config: props.config,
Expand Down
30 changes: 16 additions & 14 deletions lib/rag-engines/sagemaker-rag-models/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,22 @@ export class SageMakerRagModels extends Construct {
.filter((c) => c.provider === "sagemaker")
.map((c) => c.name);

const model = new SageMakerModel(this, "Model", {
vpc: props.shared.vpc,
region: cdk.Aws.REGION,
model: {
type: DeploymentType.CustomInferenceScript,
modelId: [
...sageMakerEmbeddingsModelIds,
...sageMakerCrossEncoderModelIds,
],
codeFolder: path.join(__dirname, "./model"),
instanceType: "ml.g4dn.xlarge",
},
});
if (sageMakerEmbeddingsModelIds?.length > 0 || sageMakerCrossEncoderModelIds?.length > 0) {
const model = new SageMakerModel(this, "Model", {
vpc: props.shared.vpc,
region: cdk.Aws.REGION,
model: {
type: DeploymentType.CustomInferenceScript,
modelId: [
...sageMakerEmbeddingsModelIds,
...sageMakerCrossEncoderModelIds,
],
codeFolder: path.join(__dirname, "./model"),
instanceType: "ml.g4dn.xlarge",
},
});

this.model = model;
this.model = model;
}
}
}
Loading