diff --git a/bin/config.ts b/bin/config.ts index ca8b5d021..bd00646ac 100644 --- a/bin/config.ts +++ b/bin/config.ts @@ -23,10 +23,12 @@ export function getConfig(): SystemConfig { }, llms: { // sagemaker: [SupportedSageMakerModels.FalconLite] + enableSagemakerModels: false, sagemaker: [], }, rag: { enabled: false, + enableEmbeddingModelsViaSagemaker: false, engines: { aurora: { enabled: false, @@ -45,11 +47,13 @@ export function getConfig(): SystemConfig { provider: "sagemaker", name: "intfloat/multilingual-e5-large", dimensions: 1024, + default: false, }, { provider: "sagemaker", name: "sentence-transformers/all-MiniLM-L6-v2", dimensions: 384, + default: false, }, { provider: "bedrock", @@ -77,8 +81,10 @@ export function getConfig(): SystemConfig { provider: "openai", name: "text-embedding-ada-002", dimensions: 1536, + default: false, }, ], + crossEncodingEnabled: false, crossEncoderModels: [ { provider: "sagemaker", diff --git a/cli/magic-config.ts b/cli/magic-config.ts index 4584199e1..864bcaa03 100644 --- a/cli/magic-config.ts +++ b/cli/magic-config.ts @@ -34,7 +34,6 @@ function getTimeZonesWithCurrentTime(): { message: string; name: string }[] { function getCountryCodesAndNames(): { message: string; name: string }[] { // Use country-list to get an array of countries with their codes and names const countries = getData(); - // Map the country data to match the desired output structure const countryInfo = countries.map(({ code, name }) => { return { message: `${name} (${code})`, name: code }; @@ -93,16 +92,19 @@ const embeddingModels = [ provider: "sagemaker", name: "intfloat/multilingual-e5-large", dimensions: 1024, + default: false, }, { provider: "sagemaker", name: "sentence-transformers/all-MiniLM-L6-v2", dimensions: 384, + default: false, }, { provider: "bedrock", name: "amazon.titan-embed-text-v1", dimensions: 1536, + default: false, }, //Support for inputImage is not yet implemented for amazon.titan-embed-image-v1 { @@ -124,6 +126,7 @@ const embeddingModels = [ provider: "openai", name: "text-embedding-ada-002", dimensions: 1536, + default: false, }, ]; @@ -175,6 +178,8 @@ const embeddingModels = [ options.startScheduleEndDate = config.llms?.sagemakerSchedule?.startScheduleEndDate; options.enableRag = config.rag.enabled; + options.enableEmbeddingModelsViaSagemaker = + config.rag.enableEmbeddingModelsViaSagemaker; options.ragsToEnable = Object.keys(config.rag.engines ?? {}).filter( (v: string) => ( @@ -577,6 +582,16 @@ async function processCreateOptions(options: any): Promise { message: "Do you want to enable RAG", initial: options.enableRag || false, }, + { + type: "confirm", + name: "enableEmbeddingModelsViaSagemaker", + message: + "Do you want to enable embedding and cross-encoder models via SageMaker?", + initial: options.enableEmbeddingModelsViaSagemaker || false, + skip(): boolean { + return !(this as any).state.answers.enableRag; + }, + }, { type: "multiselect", name: "ragsToEnable", @@ -705,7 +720,6 @@ async function processCreateOptions(options: any): Promise { if ((this as any).state.answers.enableRag) { return value ? true : "Select a default embedding model"; } - return true; }, skip() { @@ -1046,6 +1060,7 @@ async function processCreateOptions(options: any): Promise { } : undefined, llms: { + enableSagemakerModels: answers.enableSagemakerModels, sagemaker: answers.sagemakerModels, huggingfaceApiSecretArn: answers.huggingfaceApiSecretArn, sagemakerSchedule: answers.enableSagemakerModelsSchedule @@ -1065,6 +1080,8 @@ async function processCreateOptions(options: any): Promise { }, rag: { enabled: answers.enableRag, + enableEmbeddingModelsViaSagemaker: + answers.enableEmbeddingModelsViaSagemaker, engines: { aurora: { enabled: answers.ragsToEnable.includes("aurora"), @@ -1079,27 +1096,35 @@ async function processCreateOptions(options: any): Promise { enterprise: false, }, }, + crossEncodingEnabled: answers.enableEmbeddingModelsViaSagemaker, embeddingsModels: [{}], crossEncoderModels: [{}], }, }; - // If we have not enabled rag the default embedding is set to the first model - if (!answers.enableRag) { - models.defaultEmbedding = embeddingModels[0].name; - } - config.rag.crossEncoderModels[0] = { provider: "sagemaker", name: "cross-encoder/ms-marco-MiniLM-L-12-v2", default: true, }; - config.rag.embeddingsModels = embeddingModels; - config.rag.embeddingsModels.forEach((m: any) => { - if (m.name === models.defaultEmbedding) { - m.default = true; - } - }); + + if (!config.rag.enableEmbeddingModelsViaSagemaker) { + config.rag.embeddingsModels = embeddingModels.filter( + (model) => model.provider !== "sagemaker" + ); + } else { + config.rag.embeddingsModels = embeddingModels; + } + // If we have not enabled rag the default embedding is set to the first model + if (!answers.enableRag) { + (config.rag.embeddingsModels[0] as any).default = true; + } else { + config.rag.embeddingsModels.forEach((m: any) => { + if (m.name === models.defaultEmbedding) { + m.default = true; + } + }); + } config.rag.engines.kendra.createIndex = answers.ragsToEnable.includes("kendra"); diff --git a/lib/aws-genai-llm-chatbot-stack.ts b/lib/aws-genai-llm-chatbot-stack.ts index ce2469d15..5bef9f8cb 100644 --- a/lib/aws-genai-llm-chatbot-stack.ts +++ b/lib/aws-genai-llm-chatbot-stack.ts @@ -157,10 +157,9 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack { identityPool: authentication.identityPool, api: chatBotApi, chatbotFilesBucket: chatBotApi.filesBucket, - crossEncodersEnabled: - typeof ragEngines?.sageMakerRagModels?.model !== "undefined", + crossEncodersEnabled: props.config.rag.crossEncodingEnabled, sagemakerEmbeddingsEnabled: - typeof ragEngines?.sageMakerRagModels?.model !== "undefined", + props.config.rag.enableEmbeddingModelsViaSagemaker, }); if (props.config.cognitoFederation?.enabled) { @@ -331,10 +330,7 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack { ] ); - if ( - props.config.rag.engines.aurora.enabled || - props.config.rag.engines.opensearch.enabled - ) { + if (props.config.llms.enableSagemakerModels) { NagSuppressions.addResourceSuppressionsByPath( this, [ @@ -370,59 +366,59 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack { }, ] ); - if (props.config.rag.engines.aurora.enabled) { - NagSuppressions.addResourceSuppressionsByPath( - this, - `/${this.stackName}/RagEngines/AuroraPgVector/AuroraDatabase/Secret/Resource`, - [ - { - id: "AwsSolutions-SMG4", - reason: "Secret created implicitly by CDK.", - }, - ] - ); - NagSuppressions.addResourceSuppressionsByPath( - this, - [ - `/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupFunction/ServiceRole/Resource`, - `/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupProvider/framework-onEvent/ServiceRole/Resource`, - `/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupProvider/framework-onEvent/ServiceRole/DefaultPolicy/Resource`, - `/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspaceFunction/ServiceRole/Resource`, - `/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspaceFunction/ServiceRole/DefaultPolicy/Resource`, - `/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspace/Role/DefaultPolicy/Resource`, - ], - [ - { - id: "AwsSolutions-IAM4", - reason: "IAM role implicitly created by CDK.", - }, - { - id: "AwsSolutions-IAM5", - reason: "IAM role implicitly created by CDK.", - }, - ] - ); - } - if (props.config.rag.engines.opensearch.enabled) { - NagSuppressions.addResourceSuppressionsByPath( - this, - [ - `/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspaceFunction/ServiceRole/Resource`, - `/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspaceFunction/ServiceRole/DefaultPolicy/Resource`, - `/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspace/Role/DefaultPolicy/Resource`, - ], - [ - { - id: "AwsSolutions-IAM4", - reason: "IAM role implicitly created by CDK.", - }, - { - id: "AwsSolutions-IAM5", - reason: "IAM role implicitly created by CDK.", - }, - ] - ); - } + } + if (props.config.rag.engines.aurora.enabled) { + NagSuppressions.addResourceSuppressionsByPath( + this, + `/${this.stackName}/RagEngines/AuroraPgVector/AuroraDatabase/Secret/Resource`, + [ + { + id: "AwsSolutions-SMG4", + reason: "Secret created implicitly by CDK.", + }, + ] + ); + NagSuppressions.addResourceSuppressionsByPath( + this, + [ + `/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupFunction/ServiceRole/Resource`, + `/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupProvider/framework-onEvent/ServiceRole/Resource`, + `/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupProvider/framework-onEvent/ServiceRole/DefaultPolicy/Resource`, + `/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspaceFunction/ServiceRole/Resource`, + `/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspaceFunction/ServiceRole/DefaultPolicy/Resource`, + `/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspace/Role/DefaultPolicy/Resource`, + ], + [ + { + id: "AwsSolutions-IAM4", + reason: "IAM role implicitly created by CDK.", + }, + { + id: "AwsSolutions-IAM5", + reason: "IAM role implicitly created by CDK.", + }, + ] + ); + } + if (props.config.rag.engines.opensearch.enabled) { + NagSuppressions.addResourceSuppressionsByPath( + this, + [ + `/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspaceFunction/ServiceRole/Resource`, + `/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspaceFunction/ServiceRole/DefaultPolicy/Resource`, + `/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspace/Role/DefaultPolicy/Resource`, + ], + [ + { + id: "AwsSolutions-IAM4", + reason: "IAM role implicitly created by CDK.", + }, + { + id: "AwsSolutions-IAM5", + reason: "IAM role implicitly created by CDK.", + }, + ] + ); } if (props.config.rag.engines.kendra.enabled) { NagSuppressions.addResourceSuppressionsByPath( diff --git a/lib/chatbot-api/functions/api-handler/routes/cross_encoders.py b/lib/chatbot-api/functions/api-handler/routes/cross_encoders.py index 8fcebbf16..effa8b3e1 100644 --- a/lib/chatbot-api/functions/api-handler/routes/cross_encoders.py +++ b/lib/chatbot-api/functions/api-handler/routes/cross_encoders.py @@ -30,14 +30,19 @@ def models(): @tracer.capture_method def cross_encoders(input: dict): request = CrossEncodersRequest(**input) - selected_model = genai_core.cross_encoder.get_cross_encoder_model( - request.provider, request.model - ) - - if selected_model is None: - raise genai_core.types.CommonError("Model not found") - - ret_value = genai_core.cross_encoder.rank_passages( - selected_model, request.reference, request.passages - ) - return [{"score": v, "passage": p} for v, p in zip(ret_value, request.passages)] + config = genai_core.parameters.get_config() + crossEncodingEnabled = config["rag"]["crossEncodingEnabled"] + if (crossEncodingEnabled): + selected_model = genai_core.cross_encoder.get_cross_encoder_model( + request.provider, request.model + ) + + if selected_model is None: + raise genai_core.types.CommonError("Model not found") + + ret_value = genai_core.cross_encoder.rank_passages( + selected_model, request.reference, request.passages + ) + return [{"score": v, "passage": p} for v, p in zip(ret_value, request.passages)] + + return [{"score": 0, "passage": p} for p in request.passages] diff --git a/lib/rag-engines/data-import/index.ts b/lib/rag-engines/data-import/index.ts index 65372b040..da4253975 100644 --- a/lib/rag-engines/data-import/index.ts +++ b/lib/rag-engines/data-import/index.ts @@ -132,7 +132,7 @@ export class DataImport extends Construct { processingBucket, auroraDatabase: props.auroraDatabase, ragDynamoDBTables: props.ragDynamoDBTables, - sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model.endpoint, + sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model?.endpoint, openSearchVector: props.openSearchVector, } ); diff --git a/lib/rag-engines/index.ts b/lib/rag-engines/index.ts index 2171caa0a..21c1be28f 100644 --- a/lib/rag-engines/index.ts +++ b/lib/rag-engines/index.ts @@ -41,10 +41,7 @@ export class RagEngines extends Construct { const tables = new RagDynamoDBTables(this, "RagDynamoDBTables"); let sageMakerRagModels: SageMakerRagModels | null = null; - if ( - props.config.rag.engines.aurora.enabled || - props.config.rag.engines.opensearch.enabled - ) { + if (props.config.llms.enableSagemakerModels) { sageMakerRagModels = new SageMakerRagModels(this, "SageMaker", { shared: props.shared, config: props.config, diff --git a/lib/rag-engines/sagemaker-rag-models/index.ts b/lib/rag-engines/sagemaker-rag-models/index.ts index b4a920ce8..840610fdd 100644 --- a/lib/rag-engines/sagemaker-rag-models/index.ts +++ b/lib/rag-engines/sagemaker-rag-models/index.ts @@ -24,20 +24,22 @@ export class SageMakerRagModels extends Construct { .filter((c) => c.provider === "sagemaker") .map((c) => c.name); - const model = new SageMakerModel(this, "Model", { - vpc: props.shared.vpc, - region: cdk.Aws.REGION, - model: { - type: DeploymentType.CustomInferenceScript, - modelId: [ - ...sageMakerEmbeddingsModelIds, - ...sageMakerCrossEncoderModelIds, - ], - codeFolder: path.join(__dirname, "./model"), - instanceType: "ml.g4dn.xlarge", - }, - }); + if (sageMakerEmbeddingsModelIds?.length > 0 || sageMakerCrossEncoderModelIds?.length > 0) { + const model = new SageMakerModel(this, "Model", { + vpc: props.shared.vpc, + region: cdk.Aws.REGION, + model: { + type: DeploymentType.CustomInferenceScript, + modelId: [ + ...sageMakerEmbeddingsModelIds, + ...sageMakerCrossEncoderModelIds, + ], + codeFolder: path.join(__dirname, "./model"), + instanceType: "ml.g4dn.xlarge", + }, + }); - this.model = model; + this.model = model; + } } } diff --git a/lib/shared/layers/python-sdk/python/genai_core/aurora/query.py b/lib/shared/layers/python-sdk/python/genai_core/aurora/query.py index 32b95540b..f7f1f1c83 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/aurora/query.py +++ b/lib/shared/layers/python-sdk/python/genai_core/aurora/query.py @@ -20,6 +20,7 @@ def query_workspace_aurora( full_response: bool, threshold: int = 0, ): + config = genai_core.parameters.get_config() table_name = sql.Identifier(workspace_id.replace("-", "")) embeddings_model_provider = workspace["embeddings_model_provider"] embeddings_model_name = workspace["embeddings_model_name"] @@ -38,13 +39,6 @@ def query_workspace_aurora( if selected_model is None: raise CommonError("Embeddings model not found") - cross_encoder_model = genai_core.cross_encoder.get_cross_encoder_model( - cross_encoder_model_provider, cross_encoder_model_name - ) - - if cross_encoder_model is None: - raise CommonError("Cross encoder model not found") - query_embeddings = genai_core.embeddings.generate_embeddings( selected_model, [query], Task.RETRIEVE )[0] @@ -186,24 +180,33 @@ def query_workspace_aurora( item["keyword_search_score"] = current["keyword_search_score"] unique_items = list(unique_items.values()) - score_dict = dict({}) - if len(unique_items) > 0: - passages = [record["content"] for record in unique_items] - passage_scores = genai_core.cross_encoder.rank_passages( - cross_encoder_model, query, passages + + if (config["rag"]["crossEncodingEnabled"]): + cross_encoder_model = genai_core.cross_encoder.get_cross_encoder_model( + cross_encoder_model_provider, cross_encoder_model_name ) - for i in range(len(unique_items)): - score = passage_scores[i] - unique_items[i]["score"] = score - score_dict[unique_items[i]["chunk_id"]] = score + if cross_encoder_model is None: + raise genai_core.types.CommonError("Cross encoder model not found") + + score_dict = dict({}) + if len(unique_items) > 0: + passages = [record["content"] for record in unique_items] + passage_scores = genai_core.cross_encoder.rank_passages( + cross_encoder_model, query, passages + ) - unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True) + for i in range(len(unique_items)): + score = passage_scores[i] + unique_items[i]["score"] = score + score_dict[unique_items[i]["chunk_id"]] = score - for record in vector_search_records: - record["score"] = score_dict[record["chunk_id"]] - for record in keyword_search_records: - record["score"] = score_dict[record["chunk_id"]] + unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True) + + for record in vector_search_records: + record["score"] = score_dict[record["chunk_id"]] + for record in keyword_search_records: + record["score"] = score_dict[record["chunk_id"]] if full_response: unique_items = unique_items[:limit] @@ -218,9 +221,11 @@ def query_workspace_aurora( "keyword_search_items": convert_types(keyword_search_records), } else: - ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))[ - :limit - ] + if config["rag"]["crossEncodingEnabled"]: + ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))[:limit] + else: + ret_items = unique_items[:limit] + if len(ret_items) < limit: # inner product metric is negative hence we sort ascending if metric == "inner": @@ -296,4 +301,4 @@ def _convert_records(source: str, records: List[dict]): converted_records.append(converted) - return converted_records + return converted_records \ No newline at end of file diff --git a/lib/shared/layers/python-sdk/python/genai_core/opensearch/query.py b/lib/shared/layers/python-sdk/python/genai_core/opensearch/query.py index 05d359012..71c3811ca 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/opensearch/query.py +++ b/lib/shared/layers/python-sdk/python/genai_core/opensearch/query.py @@ -18,6 +18,7 @@ def query_workspace_open_search( ): index_name = workspace_id.replace("-", "") + config = genai_core.parameters.get_config() embeddings_model_provider = workspace["embeddings_model_provider"] embeddings_model_name = workspace["embeddings_model_name"] cross_encoder_model_provider = workspace["cross_encoder_model_provider"] @@ -37,13 +38,6 @@ def query_workspace_open_search( if selected_model is None: raise CommonError("Embeddings model not found") - cross_encoder_model = genai_core.cross_encoder.get_cross_encoder_model( - cross_encoder_model_provider, cross_encoder_model_name - ) - - if cross_encoder_model is None: - raise CommonError("Cross encoder model not found") - query_embeddings = genai_core.embeddings.generate_embeddings( selected_model, [query], Task.RETRIEVE )[0] @@ -96,23 +90,32 @@ def query_workspace_open_search( item["keyword_search_score"] = current["keyword_search_score"] unique_items = list(unique_items.values()) - score_dict = dict({}) - if len(unique_items) > 0: - passages = [record["content"] for record in unique_items] - passage_scores = genai_core.cross_encoder.rank_passages( - cross_encoder_model, query, passages + + if (config["rag"]["crossEncodingEnabled"]): + cross_encoder_model = genai_core.cross_encoder.get_cross_encoder_model( + cross_encoder_model_provider, cross_encoder_model_name ) - for i in range(len(unique_items)): - score = passage_scores[i] - unique_items[i]["score"] = score - score_dict[unique_items[i]["chunk_id"]] = score - unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True) + if cross_encoder_model is None: + raise genai_core.types.CommonError("Cross encoder model not found") + + score_dict = dict({}) + if len(unique_items) > 0: + passages = [record["content"] for record in unique_items] + passage_scores = genai_core.cross_encoder.rank_passages( + cross_encoder_model, query, passages + ) + + for i in range(len(unique_items)): + score = passage_scores[i] + unique_items[i]["score"] = score + score_dict[unique_items[i]["chunk_id"]] = score + unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True) - for record in vector_search_records: - record["score"] = score_dict[record["chunk_id"]] - for record in keyword_search_records: - record["score"] = score_dict[record["chunk_id"]] + for record in vector_search_records: + record["score"] = score_dict[record["chunk_id"]] + for record in keyword_search_records: + record["score"] = score_dict[record["chunk_id"]] if full_response: unique_items = unique_items[:limit] @@ -125,9 +128,11 @@ def query_workspace_open_search( "keyword_search_items": keyword_search_records, } else: - ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))[ - :limit - ] + if config["rag"]["crossEncodingEnabled"]: + ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))[:limit] + else: + ret_items = unique_items[:limit] + if len(ret_items) < limit: unique_items = sorted( unique_items, key=lambda x: x["vector_search_score"] or -1, reverse=True diff --git a/lib/shared/types.ts b/lib/shared/types.ts index c9501ff64..c70d460a1 100644 --- a/lib/shared/types.ts +++ b/lib/shared/types.ts @@ -108,6 +108,7 @@ export interface SystemConfig { }; }; llms: { + enableSagemakerModels: boolean; sagemaker: SupportedSageMakerModels[]; huggingfaceApiSecretArn?: string; sagemakerSchedule?: { @@ -125,6 +126,7 @@ export interface SystemConfig { }; rag: { enabled: boolean; + enableEmbeddingModelsViaSagemaker: boolean; engines: { aurora: { enabled: boolean; @@ -150,6 +152,7 @@ export interface SystemConfig { dimensions: number; default?: boolean; }[]; + crossEncodingEnabled: boolean; crossEncoderModels: { provider: ModelProvider; name: string; diff --git a/lib/user-interface/react-app/src/common/helpers/embeddings-model-helper.ts b/lib/user-interface/react-app/src/common/helpers/embeddings-model-helper.ts index 93902dccd..56e1bd308 100644 --- a/lib/user-interface/react-app/src/common/helpers/embeddings-model-helper.ts +++ b/lib/user-interface/react-app/src/common/helpers/embeddings-model-helper.ts @@ -1,5 +1,6 @@ import { SelectProps } from "@cloudscape-design/components"; import { EmbeddingModel } from "../../API"; +import { AppConfig } from "../types"; export abstract class EmbeddingsModelHelper { static getSelectOption(model?: string): SelectProps.Option | null { @@ -32,9 +33,18 @@ export abstract class EmbeddingsModelHelper { }; } - static getSelectOptions(embeddingsModels: EmbeddingModel[]) { + static getSelectOptions( + appContext: AppConfig | null, + embeddingsModels: EmbeddingModel[] + ) { const modelsMap = new Map(); embeddingsModels.forEach((model) => { + if ( + model.provider === "sagemaker" && + !appContext?.config.sagemaker_embeddings_enabled + ) { + return; + } let items = modelsMap.get(model.provider); if (!items) { items = []; diff --git a/lib/user-interface/react-app/src/common/types.ts b/lib/user-interface/react-app/src/common/types.ts index b67839dd6..acc1eea23 100644 --- a/lib/user-interface/react-app/src/common/types.ts +++ b/lib/user-interface/react-app/src/common/types.ts @@ -71,6 +71,7 @@ export enum DocumentSubscriptionStatus { export interface AuroraWorkspaceCreateInput { name: string; embeddingsModel: SelectProps.Option | null; + crossEncodingEnabled: boolean; crossEncoderModel: SelectProps.Option | null; languages: readonly SelectProps.Option[]; metric: string; @@ -84,6 +85,7 @@ export interface OpenSearchWorkspaceCreateInput { name: string; embeddingsModel: SelectProps.Option | null; languages: readonly SelectProps.Option[]; + crossEncodingEnabled: boolean; crossEncoderModel: SelectProps.Option | null; hybridSearch: boolean; chunkSize: number; diff --git a/lib/user-interface/react-app/src/pages/rag/create-workspace/aurora-form.tsx b/lib/user-interface/react-app/src/pages/rag/create-workspace/aurora-form.tsx index 09533e860..555b57a46 100644 --- a/lib/user-interface/react-app/src/pages/rag/create-workspace/aurora-form.tsx +++ b/lib/user-interface/react-app/src/pages/rag/create-workspace/aurora-form.tsx @@ -114,6 +114,7 @@ function AuroraFooter(props: { diff --git a/lib/user-interface/react-app/src/pages/rag/create-workspace/create-workspace-aurora.tsx b/lib/user-interface/react-app/src/pages/rag/create-workspace/create-workspace-aurora.tsx index 6312501c2..8283e4cac 100644 --- a/lib/user-interface/react-app/src/pages/rag/create-workspace/create-workspace-aurora.tsx +++ b/lib/user-interface/react-app/src/pages/rag/create-workspace/create-workspace-aurora.tsx @@ -30,11 +30,12 @@ const metrics = [ const defaults: AuroraWorkspaceCreateInput = { name: "", embeddingsModel: null, + crossEncodingEnabled: false, crossEncoderModel: null, languages: [{ value: "english", label: "English" }], metric: metrics[0].value, index: true, - hybridSearch: true, + hybridSearch: false, chunkSize: 1000, chunkOverlap: 200, }; @@ -51,6 +52,8 @@ export default function CreateWorkspaceAurora() { embeddingsModel: EmbeddingsModelHelper.getSelectOption( appContext?.config.default_embeddings_model ), + crossEncodingEnabled: appContext?.config.cross_encoders_enabled || false, + hybridSearch: appContext?.config.cross_encoders_enabled || false, crossEncoderModel: OptionsHelper.getSelectOption( appContext?.config.default_cross_encoder_model ), diff --git a/lib/user-interface/react-app/src/pages/rag/create-workspace/create-workspace-opensearch.tsx b/lib/user-interface/react-app/src/pages/rag/create-workspace/create-workspace-opensearch.tsx index a297536d7..5b5c10646 100644 --- a/lib/user-interface/react-app/src/pages/rag/create-workspace/create-workspace-opensearch.tsx +++ b/lib/user-interface/react-app/src/pages/rag/create-workspace/create-workspace-opensearch.tsx @@ -15,9 +15,10 @@ const nameRegex = /^[\w+_-]+$/; const defaults: OpenSearchWorkspaceCreateInput = { name: "", embeddingsModel: null, + crossEncodingEnabled: false, crossEncoderModel: null, languages: [{ value: "english", label: "English" }], - hybridSearch: true, + hybridSearch: false, chunkSize: 1000, chunkOverlap: 200, }; @@ -34,6 +35,8 @@ export default function CreateWorkspaceOpenSearch() { embeddingsModel: EmbeddingsModelHelper.getSelectOption( appContext?.config.default_embeddings_model ), + crossEncodingEnabled: appContext?.config.cross_encoders_enabled || false, + hybridSearch: appContext?.config.cross_encoders_enabled || false, crossEncoderModel: OptionsHelper.getSelectOption( appContext?.config.default_cross_encoder_model ), diff --git a/lib/user-interface/react-app/src/pages/rag/create-workspace/cross-encoder-selector-field.tsx b/lib/user-interface/react-app/src/pages/rag/create-workspace/cross-encoder-selector-field.tsx index 2326dd2ee..014fe6135 100644 --- a/lib/user-interface/react-app/src/pages/rag/create-workspace/cross-encoder-selector-field.tsx +++ b/lib/user-interface/react-app/src/pages/rag/create-workspace/cross-encoder-selector-field.tsx @@ -9,6 +9,7 @@ import { Utils } from "../../../common/utils"; interface CrossEncoderSelectorProps { submitting: boolean; + disabled: boolean; onChange: (data: Partial<{ crossEncoderModel: SelectProps.Option }>) => void; selectedModel: SelectProps.Option | null; errors: Record; @@ -46,7 +47,7 @@ export function CrossEncoderSelectorField(props: CrossEncoderSelectorProps) { return (