Feature request: Option to disable cross encoder models #286

Closed
wants to merge 25 commits into from
25 commits
da037d1
Currently cross encoder models are used to rank the search results bu…
azaylamba Dec 23, 2023
1c3b8ce
Enhancement: Add user feedback for responses
azaylamba Dec 24, 2023
c8dc554
Revert "Enhancement: Add user feedback for responses"
azaylamba Dec 24, 2023
550d2d0
Merge branch 'main' into main
azaylamba Jan 17, 2024
8dd11d8
Merge branch 'aws-samples:main' into main
azaylamba Jan 25, 2024
42c6edd
Merge branch 'main' of https://github.com/azaylamba/aws-genai-llm-cha…
azaylamba Feb 4, 2024
efb1a99
Addressed review comments related to cross encoding.
azaylamba Feb 4, 2024
b58737d
Removed prompt for selecting embedding models as it is not required now.
azaylamba Feb 4, 2024
cb8793d
Resolving merge conflicts
azaylamba Feb 9, 2024
cf0dfc1
Resolving merge conflicts
azaylamba Feb 9, 2024
13ce71e
Derived value of crossEncodingEnabled based on enableEmbeddingModelsV…
azaylamba Feb 9, 2024
2522839
Reverted unwanted change
azaylamba Feb 9, 2024
4669419
Merge branch 'main' into main
bigadsoleiman Feb 13, 2024
1667e9c
Merge branch 'main' into main
azaylamba Feb 24, 2024
1102491
Default embeddings model prompt was not set
azaylamba Feb 24, 2024
2047641
Merge branch 'main' into main
bigadsoleiman Mar 8, 2024
a09713e
Merge branch 'main' into main
azaylamba Apr 13, 2024
dca47d0
Corrected the NagSuppression conditions
azaylamba Apr 20, 2024
c2eabf4
Merge branch 'main' into main
azaylamba Jul 13, 2024
6a7c92b
Addressed review comments
azaylamba Jul 13, 2024
494f3b1
Added default value for cross encoder models
azaylamba Jul 15, 2024
efa9fa8
Merge branch 'main' into main
azaylamba Jul 18, 2024
61b73d2
Used enableSagemakerModels config for SM models
azaylamba Jul 18, 2024
feb5752
Merge branch 'main' of https://github.com/azaylamba/aws-genai-llm-cha…
azaylamba Jul 18, 2024
6850a9a
Merge branch 'main' into main
azaylamba Aug 3, 2024
15 changes: 8 additions & 7 deletions cli/magic-config.ts
Expand Up @@ -32,7 +32,6 @@ function getTimeZonesWithCurrentTime(): { message: string; name: string }[] {
function getCountryCodesAndNames(): { message: string; name: string }[] {
// Use country-list to get an array of countries with their codes and names
const countries = getData();

// Map the country data to match the desired output structure
const countryInfo = countries.map(({ code, name }) => {
return { message: `${name} (${code})`, name: code };
Expand Down Expand Up @@ -177,6 +176,8 @@ const embeddingModels = [
options.startScheduleEndDate =
config.llms?.sagemakerSchedule?.startScheduleEndDate;
options.enableRag = config.rag.enabled;
+ options.enableEmbeddingModelsViaSagemaker =
+ config.rag.enableEmbeddingModelsViaSagemaker;
options.ragsToEnable = Object.keys(config.rag.engines ?? {}).filter(
(v: string) => (config.rag.engines as any)[v].enabled
);
Expand Down Expand Up @@ -575,7 +576,8 @@ async function processCreateOptions(options: any): Promise<void> {
{
type: "confirm",
name: "enableEmbeddingModelsViaSagemaker",
- message: "Do you want to enable embedding models via SageMaker?",
+ message:
+ "Do you want to enable embedding and cross-encoder models via SageMaker?",
initial: options.enableEmbeddingModelsViaSagemaker || false,
skip(): boolean {
return !(this as any).state.answers.enableRag;
Expand Down Expand Up @@ -1091,10 +1093,7 @@ async function processCreateOptions(options: any): Promise<void> {
},
};

- if (
- answers.enableEmbeddingModelsViaSagemaker &&
- answers.enableSagemakerModels
- ) {
+ if (config.rag.crossEncodingEnabled) {
config.rag.crossEncoderModels[0] = {
provider: "sagemaker",
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
Expand All @@ -1108,7 +1107,9 @@ async function processCreateOptions(options: any): Promise<void> {
};
}
if (!config.rag.enableEmbeddingModelsViaSagemaker) {
- config.rag.embeddingsModels = embeddingModels.filter(model => model.provider !== "sagemaker");
+ config.rag.embeddingsModels = embeddingModels.filter(
+ (model) => model.provider !== "sagemaker"
+ );
} else {
config.rag.embeddingsModels = embeddingModels;
}
Expand Down
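The last hunk above decides which embedding models end up in the generated config. A minimal sketch of that selection, assuming simplified shapes: `ModelEntry`, `selectEmbeddingModels`, and the model names are hypothetical and are not part of the PR.

```typescript
// Hypothetical, simplified model shape; the real types live in the repository.
interface ModelEntry {
  provider: string;
  name: string;
}

// Illustrative catalogue only, not the project's actual model list.
const exampleModels: ModelEntry[] = [
  { provider: "sagemaker", name: "example-sagemaker-model" },
  { provider: "bedrock", name: "example-bedrock-model" },
];

// Keep every model when SageMaker embeddings are enabled;
// otherwise drop the entries whose provider is "sagemaker".
function selectEmbeddingModels(
  enableViaSagemaker: boolean,
  all: ModelEntry[]
): ModelEntry[] {
  return enableViaSagemaker
    ? all
    : all.filter((model) => model.provider !== "sagemaker");
}

const withoutSagemaker = selectEmbeddingModels(false, exampleModels);
const withSagemaker = selectEmbeddingModels(true, exampleModels);
```

With the flag off, only the non-SageMaker entries remain; with it on, the list is passed through unchanged.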
2 changes: 1 addition & 1 deletion lib/rag-engines/index.ts
Expand Up @@ -42,7 +42,7 @@ export class RagEngines extends Construct {
const tables = new RagDynamoDBTables(this, "RagDynamoDBTables");

let sageMakerRagModels: SageMakerRagModels | null = null;
- if (props.config.llms.enableSagemakerModels) {
+ if (props.config.rag.crossEncodingEnabled) {
Collaborator

You were right, props.config.llms.enableSagemakerModels is better.

Contributor Author

Updated the condition.

sageMakerRagModels = new SageMakerRagModels(this, "SageMaker", {
shared: props.shared,
config: props.config,
Expand Down
30 changes: 13 additions & 17 deletions lib/user-interface/index.ts
Expand Up @@ -15,8 +15,8 @@ import { Shared } from "../shared";
import { SystemConfig } from "../shared/types";
import { Utils } from "../shared/utils";
import { ChatBotApi } from "../chatbot-api";
- import { PrivateWebsite } from "./private-website"
- import { PublicWebsite } from "./public-website"
+ import { PrivateWebsite } from "./private-website";
+ import { PublicWebsite } from "./public-website";
import { NagSuppressions } from "cdk-nag";

export interface UserInterfaceProps {
Expand Down Expand Up @@ -52,13 +52,13 @@ export class UserInterface extends Construct {
removalPolicy: cdk.RemovalPolicy.DESTROY,
blockPublicAccess: s3.BlockPublicAccess.BLOCK_ALL,
autoDeleteObjects: true,
- bucketName: props.config.privateWebsite ? props.config.domain : undefined,
+ bucketName: props.config.privateWebsite ? props.config.domain : undefined,
websiteIndexDocument: "index.html",
websiteErrorDocument: "index.html",
enforceSSL: true,
serverAccessLogsBucket: uploadLogsBucket,
});

// Deploy either Private (only accessible within VPC) or Public facing website
let apiEndpoint: string;
let websocketEndpoint: string;
Expand All @@ -74,8 +74,6 @@ export class UserInterface extends Construct {
this.publishedDomain = distribution.distributionDomainName;
}



const exportsAsset = s3deploy.Source.jsonData("aws-exports.json", {
aws_project_region: cdk.Aws.REGION,
aws_cognito_region: cdk.Aws.REGION,
Expand Down Expand Up @@ -118,6 +116,8 @@ export class UserInterface extends Construct {
rag_enabled: props.config.rag.enabled,
cross_encoders_enabled: props.crossEncodersEnabled,
sagemaker_embeddings_enabled: props.sagemakerEmbeddingsEnabled,
+ enable_embedding_models_via_sagemaker:
+ props.config.rag.enableEmbeddingModelsViaSagemaker,
default_embeddings_model: Utils.getDefaultEmbeddingsModel(props.config),
default_cross_encoder_model: Utils.getDefaultCrossEncoderModel(
props.config
Expand Down Expand Up @@ -221,21 +221,17 @@ export class UserInterface extends Construct {
prune: false,
sources: [asset, exportsAsset],
destinationBucket: websiteBucket,
- distribution: props.config.privateWebsite ? undefined : distribution
+ distribution: props.config.privateWebsite ? undefined : distribution,
});


/**
* CDK NAG suppression
*/
- NagSuppressions.addResourceSuppressions(
- uploadLogsBucket,
- [
- {
- id: "AwsSolutions-S1",
- reason: "Bucket is the server access logs bucket for websiteBucket.",
- },
- ]
- );
+ NagSuppressions.addResourceSuppressions(uploadLogsBucket, [
+ {
+ id: "AwsSolutions-S1",
+ reason: "Bucket is the server access logs bucket for websiteBucket.",
+ },
+ ]);
}
}
@@ -1,5 +1,6 @@
import { SelectProps } from "@cloudscape-design/components";
import { EmbeddingModel } from "../../API";
+ import { AppConfig } from "../types";

export abstract class EmbeddingsModelHelper {
static getSelectOption(model?: string): SelectProps.Option | null {
Expand Down Expand Up @@ -32,9 +33,18 @@ export abstract class EmbeddingsModelHelper {
};
}

- static getSelectOptions(embeddingsModels: EmbeddingModel[]) {
+ static getSelectOptions(
+ appContext: AppConfig | null,
+ embeddingsModels: EmbeddingModel[]
+ ) {
const modelsMap = new Map<string, EmbeddingModel[]>();
embeddingsModels.forEach((model) => {
+ if (
+ model.provider === "sagemaker" &&
+ !appContext?.config.enable_embedding_models_via_sagemaker
+ ) {
+ return;
+ }
let items = modelsMap.get(model.provider);
if (!items) {
items = [];
Expand Down
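The change to `getSelectOptions` above skips SageMaker-hosted models whenever the app config does not enable them, including when no config has loaded yet. A minimal sketch of that grouping step, assuming simplified stand-in types: `EmbeddingModelLike`, `AppContextLike`, and `groupVisibleModels` are hypothetical names, and the real helper goes on to build Cloudscape select options rather than returning the map.

```typescript
// Hypothetical stand-ins for the EmbeddingModel and AppConfig types used in the PR.
interface EmbeddingModelLike {
  provider: string;
  name: string;
}

interface AppContextLike {
  config: { enable_embedding_models_via_sagemaker: boolean };
}

// Group models by provider, leaving out SageMaker models unless the flag is set.
// A null context is treated the same as the flag being off.
function groupVisibleModels(
  appContext: AppContextLike | null,
  models: EmbeddingModelLike[]
): Map<string, EmbeddingModelLike[]> {
  const modelsMap = new Map<string, EmbeddingModelLike[]>();
  for (const model of models) {
    if (
      model.provider === "sagemaker" &&
      !appContext?.config.enable_embedding_models_via_sagemaker
    ) {
      continue;
    }
    const items = modelsMap.get(model.provider) ?? [];
    items.push(model);
    modelsMap.set(model.provider, items);
  }
  return modelsMap;
}

// Illustrative input only; these are not real model names.
const sampleModels: EmbeddingModelLike[] = [
  { provider: "sagemaker", name: "example-sagemaker-model" },
  { provider: "bedrock", name: "example-bedrock-model" },
];

const hiddenWhenNoContext = groupVisibleModels(null, sampleModels);
const shownWhenEnabled = groupVisibleModels(
  { config: { enable_embedding_models_via_sagemaker: true } },
  sampleModels
);
```

Passing the context into the helper keeps the filtering in one place, so both call sites in the PR only need to forward `appContext`.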
1 change: 1 addition & 0 deletions lib/user-interface/react-app/src/common/types.ts
Expand Up @@ -25,6 +25,7 @@ export interface AppConfig {
rag_enabled: boolean;
cross_encoders_enabled: boolean;
sagemaker_embeddings_enabled: boolean;
+ enable_embedding_models_via_sagemaker: boolean;
api_endpoint: string;
websocket_endpoint: string;
default_embeddings_model: string;
Expand Down
Expand Up @@ -112,22 +112,20 @@ function AuroraFooter(props: {
Create an index
</Toggle>
</FormField>
- {props.data.crossEncodingEnabled && (
- <>
- <HybridSearchField
- submitting={props.submitting}
- errors={props.errors}
- checked={props.data.hybridSearch}
- onChange={props.onChange}
- />
- <CrossEncoderSelectorField
- errors={props.errors}
- submitting={props.submitting}
- selectedModel={props.data.crossEncoderModel}
- onChange={props.onChange}
- />
- </>
- )}
+ <HybridSearchField
+ submitting={props.submitting}
+ disabled={!props.data.crossEncodingEnabled}
+ errors={props.errors}
+ checked={props.data.hybridSearch}
+ onChange={props.onChange}
+ />
+ <CrossEncoderSelectorField
+ errors={props.errors}
+ submitting={props.submitting}
+ disabled={!props.data.crossEncodingEnabled}
+ selectedModel={props.data.crossEncoderModel}
+ onChange={props.onChange}
+ />
<ChunkSelectorField
submitting={props.submitting}
onChange={props.onChange}
Expand Down
Expand Up @@ -9,6 +9,7 @@ import { Utils } from "../../../common/utils";

interface CrossEncoderSelectorProps {
submitting: boolean;
+ disabled: boolean;
onChange: (data: Partial<{ crossEncoderModel: SelectProps.Option }>) => void;
selectedModel: SelectProps.Option | null;
errors: Record<string, string | string[]>;
Expand Down Expand Up @@ -45,7 +46,7 @@ export function CrossEncoderSelectorField(props: CrossEncoderSelectorProps) {

return (
<Select
- disabled={props.submitting}
+ disabled={props.submitting || props.disabled}
selectedAriaLabel="Selected"
placeholder="Choose a cross-encoder model"
statusType={crossEncoderModelsStatus}
Expand Down
Expand Up @@ -40,7 +40,7 @@ export default function EmbeddingSelector(props: EmbeddingsSelectionProps) {
}, [appContext]);

const embeddingsModelOptions =
- EmbeddingsModelHelper.getSelectOptions(embeddingsModels);
+ EmbeddingsModelHelper.getSelectOptions(appContext, embeddingsModels);

return (
<FormField
Expand Down
Expand Up @@ -2,6 +2,7 @@ import { FormField, Toggle } from "@cloudscape-design/components";

interface HybridSearchProps {
submitting: boolean;
+ disabled: boolean;
onChange: (data: Partial<{ hybridSearch: boolean }>) => void;
checked: boolean;
errors: Record<string, string | string[]>;
Expand All @@ -15,7 +16,7 @@ export function HybridSearchField(props: HybridSearchProps) {
errorText={props.errors.hybridSearch}
>
<Toggle
- disabled={props.submitting}
+ disabled={props.submitting || props.disabled}
checked={props.checked}
onChange={({ detail: { checked } }) =>
props.onChange({ hybridSearch: checked })
Expand Down
Expand Up @@ -70,22 +70,20 @@ function OpenSearchFooter(props: {
return (
<ExpandableSection headerText="Additional settings" variant="footer">
<SpaceBetween size="l">
- {props.data.crossEncodingEnabled && (
- <>
- <HybridSearchField
- submitting={props.submitting}
- errors={props.errors}
- checked={props.data.hybridSearch}
- onChange={props.onChange}
- />
- <CrossEncoderSelectorField
- errors={props.errors}
- submitting={props.submitting}
- selectedModel={props.data.crossEncoderModel}
- onChange={props.onChange}
- />
- </>
- )}
+ <HybridSearchField
+ submitting={props.submitting}
+ disabled={!props.data.crossEncodingEnabled}
+ errors={props.errors}
+ checked={props.data.hybridSearch}
+ onChange={props.onChange}
+ />
+ <CrossEncoderSelectorField
+ errors={props.errors}
+ submitting={props.submitting}
+ disabled={!props.data.crossEncodingEnabled}
+ selectedModel={props.data.crossEncoderModel}
+ onChange={props.onChange}
+ />
<ChunkSelectorField
submitting={props.submitting}
onChange={props.onChange}
Expand Down
Expand Up @@ -186,6 +186,7 @@ export default function Embeddings() {
};

const embeddingsModelOptions = EmbeddingsModelHelper.getSelectOptions(
+ appContext,
embeddingsModelsResults
);

Expand Down