Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature request: Option to disable cross encoder models #286

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
da037d1
Currently cross encoder models are used to rank the search results bu…
azaylamba Dec 23, 2023
1c3b8ce
Enhancement: Add user feedback for responses
azaylamba Dec 24, 2023
c8dc554
Revert "Enhancement: Add user feedback for responses"
azaylamba Dec 24, 2023
550d2d0
Merge branch 'main' into main
azaylamba Jan 17, 2024
8dd11d8
Merge branch 'aws-samples:main' into main
azaylamba Jan 25, 2024
42c6edd
Merge branch 'main' of https://github.com/azaylamba/aws-genai-llm-cha…
azaylamba Feb 4, 2024
efb1a99
Addressed review comments related to cross encoding.
azaylamba Feb 4, 2024
b58737d
Removed prompt for selecting embedding models as it is not required now.
azaylamba Feb 4, 2024
cb8793d
Resolving merge conflicts
azaylamba Feb 9, 2024
cf0dfc1
Resolving merge conflicts
azaylamba Feb 9, 2024
13ce71e
Derived value of crossEncodingEnabled based on enableEmbeddingModelsV…
azaylamba Feb 9, 2024
2522839
Reverted unwanted change
azaylamba Feb 9, 2024
4669419
Merge branch 'main' into main
bigadsoleiman Feb 13, 2024
1667e9c
Merge branch 'main' into main
azaylamba Feb 24, 2024
1102491
Default embeddings model prompt was not set
azaylamba Feb 24, 2024
2047641
Merge branch 'main' into main
bigadsoleiman Mar 8, 2024
a09713e
Merge branch 'main' into main
azaylamba Apr 13, 2024
dca47d0
Corrected the NagSuppression conditions
azaylamba Apr 20, 2024
c2eabf4
Merge branch 'main' into main
azaylamba Jul 13, 2024
6a7c92b
Addressed review comments
azaylamba Jul 13, 2024
494f3b1
Added default value for cross encoder models
azaylamba Jul 15, 2024
efa9fa8
Merge branch 'main' into main
azaylamba Jul 18, 2024
61b73d2
Used enableSagemakerModels config for SM models
azaylamba Jul 18, 2024
feb5752
Merge branch 'main' of https://github.com/azaylamba/aws-genai-llm-cha…
azaylamba Jul 18, 2024
6850a9a
Merge branch 'main' into main
azaylamba Aug 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions bin/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ export function getConfig(): SystemConfig {
},
llms: {
// sagemaker: [SupportedSageMakerModels.FalconLite]
enableSagemakerModels: false,
sagemaker: [],
},
rag: {
enabled: false,
enableEmbeddingModelsViaSagemaker: false,
engines: {
aurora: {
enabled: false,
Expand All @@ -42,11 +44,13 @@ export function getConfig(): SystemConfig {
provider: "sagemaker",
name: "intfloat/multilingual-e5-large",
dimensions: 1024,
default: false,
},
{
provider: "sagemaker",
name: "sentence-transformers/all-MiniLM-L6-v2",
dimensions: 384,
default: false,
},
{
provider: "bedrock",
Expand All @@ -58,8 +62,10 @@ export function getConfig(): SystemConfig {
provider: "openai",
name: "text-embedding-ada-002",
dimensions: 1536,
default: false,
},
],
crossEncodingEnabled: false,
crossEncoderModels: [
{
provider: "sagemaker",
Expand Down
60 changes: 46 additions & 14 deletions cli/magic-config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,25 @@ const embeddingModels = [
provider: "sagemaker",
name: "intfloat/multilingual-e5-large",
dimensions: 1024,
default: false,
},
{
provider: "sagemaker",
name: "sentence-transformers/all-MiniLM-L6-v2",
dimensions: 384,
default: false,
},
{
provider: "bedrock",
name: "amazon.titan-embed-text-v1",
dimensions: 1536,
default: false,
},
{
provider: "openai",
name: "text-embedding-ada-002",
dimensions: 1536,
default: false,
},
];

Expand Down Expand Up @@ -203,6 +207,15 @@ async function processCreateOptions(options: any): Promise<void> {
message: "Do you want to enable RAG",
initial: options.enableRag || false,
},
{
type: "confirm",
name: "enableEmbeddingModelsViaSagemaker",
message: "Do you want to enable embedding models via SageMaker?",
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
initial: options.enableEmbeddingModelsViaSagemaker || false,
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
skip(): boolean {
return !(this as any).state.answers.enableRag;
},
},
{
type: "multiselect",
name: "ragsToEnable",
Expand Down Expand Up @@ -328,7 +341,7 @@ async function processCreateOptions(options: any): Promise<void> {
choices: embeddingModels.map((m) => ({ name: m.name, value: m })),
initial: options.defaultEmbedding || undefined,
skip(): boolean {
return !(this as any).state.answers.enableRag;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @azaylamba, are you able to address this comment?

return !answers.enableRag;
},
},
];
Expand All @@ -349,10 +362,13 @@ async function processCreateOptions(options: any): Promise<void> {
}
: undefined,
llms: {
enableSagemakerModels: answers.enableSagemakerModels,
sagemaker: answers.sagemakerModels,
},
rag: {
enabled: answers.enableRag,
enableEmbeddingModelsViaSagemaker:
answers.enableEmbeddingModelsViaSagemaker,
engines: {
aurora: {
enabled: answers.ragsToEnable.includes("aurora"),
Expand All @@ -367,28 +383,44 @@ async function processCreateOptions(options: any): Promise<void> {
enterprise: false,
},
},
crossEncodingEnabled: answers.enableEmbeddingModelsViaSagemaker,
embeddingsModels: [{}],
crossEncoderModels: [{}],
},
};

if (
answers.enableEmbeddingModelsViaSagemaker &&
answers.enableSagemakerModels
) {
config.rag.crossEncoderModels[0] = {
provider: "sagemaker",
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
};
} else {
config.rag.crossEncoderModels[0] = {
provider: "None",
name: "None",
default: true,
};
}
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
if (!config.rag.enableEmbeddingModelsViaSagemaker) {
config.rag.embeddingsModels = embeddingModels.filter(model => model.provider !== "sagemaker");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic should also be applied to the list of models shown in the UI when selecting the default embedding model.

} else {
config.rag.embeddingsModels = embeddingModels;
}
// If we have not enabled rag the default embedding is set to the first model
if (!answers.enableRag) {
models.defaultEmbedding = embeddingModels[0].name;
(config.rag.embeddingsModels[0] as any).default = true;
} else {
config.rag.embeddingsModels.forEach((m: any) => {
if (m.name === models.defaultEmbedding) {
m.default = true;
}
});
}

config.rag.crossEncoderModels[0] = {
provider: "sagemaker",
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
};
config.rag.embeddingsModels = embeddingModels;
config.rag.embeddingsModels.forEach((m: any) => {
if (m.name === models.defaultEmbedding) {
m.default = true;
}
});

config.rag.engines.kendra.createIndex =
answers.ragsToEnable.includes("kendra");
config.rag.engines.kendra.enabled =
Expand Down
11 changes: 3 additions & 8 deletions lib/aws-genai-llm-chatbot-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,8 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
identityPool: authentication.identityPool,
api: chatBotApi,
chatbotFilesBucket: chatBotApi.filesBucket,
crossEncodersEnabled:
typeof ragEngines?.sageMakerRagModels?.model !== "undefined",
sagemakerEmbeddingsEnabled:
typeof ragEngines?.sageMakerRagModels?.model !== "undefined",
crossEncodersEnabled: props.config.rag.crossEncodingEnabled,
sagemakerEmbeddingsEnabled: props.config.rag.enableEmbeddingModelsViaSagemaker,
});

/**
Expand Down Expand Up @@ -283,10 +281,7 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
]
);

if (
props.config.rag.engines.aurora.enabled ||
props.config.rag.engines.opensearch.enabled
) {
if (props.config.llms.enableSagemakerModels) {
NagSuppressions.addResourceSuppressionsByPath(
this,
[
Expand Down
27 changes: 16 additions & 11 deletions lib/chatbot-api/functions/api-handler/routes/cross_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,19 @@ def models():
@tracer.capture_method
def cross_encoders(input: dict):
request = CrossEncodersRequest(**input)
selected_model = genai_core.cross_encoder.get_cross_encoder_model(
request.provider, request.model
)

if selected_model is None:
raise genai_core.types.CommonError("Model not found")

ret_value = genai_core.cross_encoder.rank_passages(
selected_model, request.reference, request.passages
)
return [{"score": v, "passage": p} for v, p in zip(ret_value, request.passages)]
config = genai_core.parameters.get_config()
crossEncodingEnabled = config["rag"]["crossEncodingEnabled"]
if (crossEncodingEnabled):
selected_model = genai_core.cross_encoder.get_cross_encoder_model(
request.provider, request.model
)

if selected_model is None:
raise genai_core.types.CommonError("Model not found")

ret_value = genai_core.cross_encoder.rank_passages(
selected_model, request.reference, request.passages
)
return [{"score": v, "passage": p} for v, p in zip(ret_value, request.passages)]

return [{"score": 0, "passage": p} for p in request.passages]
2 changes: 1 addition & 1 deletion lib/rag-engines/data-import/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ export class DataImport extends Construct {
processingBucket,
auroraDatabase: props.auroraDatabase,
ragDynamoDBTables: props.ragDynamoDBTables,
sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model.endpoint,
sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model?.endpoint,
openSearchVector: props.openSearchVector,
}
);
Expand Down
5 changes: 1 addition & 4 deletions lib/rag-engines/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@ export class RagEngines extends Construct {
const tables = new RagDynamoDBTables(this, "RagDynamoDBTables");

let sageMakerRagModels: SageMakerRagModels | null = null;
if (
props.config.rag.engines.aurora.enabled ||
props.config.rag.engines.opensearch.enabled
) {
if (props.config.llms.enableSagemakerModels) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be checking crossEncodingEnabled, not enableSagemakerModels.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, but won't it be confusing if crossEncodingEnabled drives the creation of the SageMaker models instead of props.config.llms.enableSagemakerModels, which is the config flag specific to SageMaker models?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right; props.config.llms.enableSagemakerModels is the better choice here.

sageMakerRagModels = new SageMakerRagModels(this, "SageMaker", {
shared: props.shared,
config: props.config,
Expand Down
30 changes: 16 additions & 14 deletions lib/rag-engines/sagemaker-rag-models/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,22 @@ export class SageMakerRagModels extends Construct {
.filter((c) => c.provider === "sagemaker")
.map((c) => c.name);

const model = new SageMakerModel(this, "Model", {
vpc: props.shared.vpc,
region: cdk.Aws.REGION,
model: {
type: DeploymentType.CustomInferenceScript,
modelId: [
...sageMakerEmbeddingsModelIds,
...sageMakerCrossEncoderModelIds,
],
codeFolder: path.join(__dirname, "./model"),
instanceType: "ml.g4dn.xlarge",
},
});
if (sageMakerEmbeddingsModelIds?.length > 0 || sageMakerCrossEncoderModelIds?.length > 0) {
const model = new SageMakerModel(this, "Model", {
vpc: props.shared.vpc,
region: cdk.Aws.REGION,
model: {
type: DeploymentType.CustomInferenceScript,
modelId: [
...sageMakerEmbeddingsModelIds,
...sageMakerCrossEncoderModelIds,
],
codeFolder: path.join(__dirname, "./model"),
instanceType: "ml.g4dn.xlarge",
},
});

this.model = model;
this.model = model;
}
}
}
55 changes: 30 additions & 25 deletions lib/shared/layers/python-sdk/python/genai_core/aurora/query.py
azaylamba marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def query_workspace_aurora(
full_response: bool,
threshold: int = 0,
):
config = genai_core.parameters.get_config()
table_name = sql.Identifier(workspace_id.replace("-", ""))
embeddings_model_provider = workspace["embeddings_model_provider"]
embeddings_model_name = workspace["embeddings_model_name"]
Expand All @@ -37,13 +38,6 @@ def query_workspace_aurora(
if selected_model is None:
raise genai_core.types.CommonError("Embeddings model not found")

cross_encoder_model = genai_core.cross_encoder.get_cross_encoder_model(
cross_encoder_model_provider, cross_encoder_model_name
)

if cross_encoder_model is None:
raise genai_core.types.CommonError("Cross encoder model not found")

query_embeddings = genai_core.embeddings.generate_embeddings(
selected_model, [query]
)[0]
Expand Down Expand Up @@ -185,24 +179,33 @@ def query_workspace_aurora(
item["keyword_search_score"] = current["keyword_search_score"]

unique_items = list(unique_items.values())
score_dict = dict({})
if len(unique_items) > 0:
passages = [record["content"] for record in unique_items]
passage_scores = genai_core.cross_encoder.rank_passages(
cross_encoder_model, query, passages

if (config["rag"]["crossEncodingEnabled"]):
massi-ang marked this conversation as resolved.
Show resolved Hide resolved
cross_encoder_model = genai_core.cross_encoder.get_cross_encoder_model(
cross_encoder_model_provider, cross_encoder_model_name
)

for i in range(len(unique_items)):
score = passage_scores[i]
unique_items[i]["score"] = score
score_dict[unique_items[i]["chunk_id"]] = score
if cross_encoder_model is None:
raise genai_core.types.CommonError("Cross encoder model not found")

score_dict = dict({})
if len(unique_items) > 0:
passages = [record["content"] for record in unique_items]
passage_scores = genai_core.cross_encoder.rank_passages(
cross_encoder_model, query, passages
)

unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True)
for i in range(len(unique_items)):
score = passage_scores[i]
unique_items[i]["score"] = score
score_dict[unique_items[i]["chunk_id"]] = score

for record in vector_search_records:
record["score"] = score_dict[record["chunk_id"]]
for record in keyword_search_records:
record["score"] = score_dict[record["chunk_id"]]
unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True)

for record in vector_search_records:
record["score"] = score_dict[record["chunk_id"]]
for record in keyword_search_records:
record["score"] = score_dict[record["chunk_id"]]

if full_response:
unique_items = unique_items[:limit]
Expand All @@ -217,9 +220,11 @@ def query_workspace_aurora(
"keyword_search_items": convert_types(keyword_search_records),
}
else:
ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))[
:limit
]
if config["rag"]["crossEncodingEnabled"]:
ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))[:limit]
else:
ret_items = unique_items[:limit]

if len(ret_items) < limit:
azaylamba marked this conversation as resolved.
Show resolved Hide resolved
# inner product metric is negative hence we sort ascending
if metric == "inner":
Expand Down Expand Up @@ -295,4 +300,4 @@ def _convert_records(source: str, records: List[dict]):

converted_records.append(converted)

return converted_records
return converted_records
Loading