From 088d6ab57e025574a2f069931b139d5e63e0cc67 Mon Sep 17 00:00:00 2001 From: GREG BONE Date: Thu, 10 Oct 2024 20:28:24 -0600 Subject: [PATCH 1/5] feat: new sessions table --- .../chatbot-dynamodb-tables/index.ts | 14 +- .../functions/api-handler/routes/sessions.py | 6 +- lib/chatbot-api/index.ts | 3 - lib/chatbot-api/rest-api.ts | 2 - lib/model-interfaces/idefics/index.ts | 2 - lib/model-interfaces/langchain/index.ts | 2 - .../langchain/chat_message_history.py | 167 ++++++++++++------ .../python-sdk/python/genai_core/sessions.py | 167 ++++++++++++++---- .../src/components/chatbot/sessions.tsx | 31 +++- 9 files changed, 285 insertions(+), 109 deletions(-) diff --git a/lib/chatbot-api/chatbot-dynamodb-tables/index.ts b/lib/chatbot-api/chatbot-dynamodb-tables/index.ts index a7e54455e..920a60102 100644 --- a/lib/chatbot-api/chatbot-dynamodb-tables/index.ts +++ b/lib/chatbot-api/chatbot-dynamodb-tables/index.ts @@ -15,13 +15,16 @@ export class ChatBotDynamoDBTables extends Construct { constructor(scope: Construct, id: string, props: ChatBotDynamoDBTablesProps) { super(scope, id); - const sessionsTable = new dynamodb.Table(this, "SessionsTable", { + // Create the sessions table with a partition key of USER# + // and a sort key of SK of SESSION#> + // No need to the global secondary index for this table + const sessionsTable = new dynamodb.Table(this, "SessionTable", { partitionKey: { - name: "SessionId", + name: "PK", type: dynamodb.AttributeType.STRING, }, sortKey: { - name: "UserId", + name: "SK", type: dynamodb.AttributeType.STRING, }, billingMode: dynamodb.BillingMode.PAY_PER_REQUEST, @@ -36,11 +39,6 @@ export class ChatBotDynamoDBTables extends Construct { pointInTimeRecovery: true, }); - sessionsTable.addGlobalSecondaryIndex({ - indexName: this.byUserIdIndex, - partitionKey: { name: "UserId", type: dynamodb.AttributeType.STRING }, - }); - this.sessionsTable = sessionsTable; } } diff --git a/lib/chatbot-api/functions/api-handler/routes/sessions.py b/lib/chatbot-api/functions/api-handler/routes/sessions.py index c98bd60fc..01f04d6f0 100644 --- a/lib/chatbot-api/functions/api-handler/routes/sessions.py +++ b/lib/chatbot-api/functions/api-handler/routes/sessions.py @@ -44,9 +44,7 @@ def get_sessions(): return [ { "id": session.get("SessionId"), - "title": session.get("History", [{}])[0] - .get("data", {}) - .get("content", ""), + "title": session.get("Title", ""), "startTime": f'{session.get("StartTime")}Z', } for session in sessions @@ -76,7 +74,7 @@ def get_session(id: str): "type": item.get("type"), "content": item.get("data", {}).get("content"), "metadata": json.dumps( - item.get("data", {}).get("additional_kwargs"), + item.get("data", {}).get("additional_kwargs", {}), cls=genai_core.utils.json.CustomEncoder, ), } diff --git a/lib/chatbot-api/index.ts b/lib/chatbot-api/index.ts index 88adda71f..a495ce4c4 100644 --- a/lib/chatbot-api/index.ts +++ b/lib/chatbot-api/index.ts @@ -34,7 +34,6 @@ export class ChatBotApi extends Construct { public readonly messagesTopic: sns.Topic; public readonly outBoundQueue: sqs.Queue; public readonly sessionsTable: dynamodb.Table; - public readonly byUserIdIndex: string; public readonly filesBucket: s3.Bucket; public readonly userFeedbackBucket: s3.Bucket; public readonly graphqlApi: appsync.GraphqlApi; @@ -120,7 +119,6 @@ export class ChatBotApi extends Construct { const apiResolvers = new ApiResolvers(this, "RestApi", { ...props, sessionsTable: chatTables.sessionsTable, - byUserIdIndex: chatTables.byUserIdIndex, api, userFeedbackBucket: chatBuckets.userFeedbackBucket, filesBucket: chatBuckets.filesBucket, @@ -158,7 +156,6 @@ export class ChatBotApi extends Construct { this.messagesTopic = realtimeBackend.messagesTopic; this.outBoundQueue = realtimeBackend.queue; this.sessionsTable = chatTables.sessionsTable; - this.byUserIdIndex = chatTables.byUserIdIndex; this.userFeedbackBucket = chatBuckets.userFeedbackBucket; this.filesBucket = chatBuckets.filesBucket; this.graphqlApi = api; diff --git a/lib/chatbot-api/rest-api.ts b/lib/chatbot-api/rest-api.ts index 21349316e..18620d3a7 100644 --- a/lib/chatbot-api/rest-api.ts +++ b/lib/chatbot-api/rest-api.ts @@ -23,7 +23,6 @@ export interface ApiResolversProps { readonly ragEngines?: RagEngines; readonly userPool: cognito.UserPool; readonly sessionsTable: dynamodb.Table; - readonly byUserIdIndex: string; readonly filesBucket: s3.Bucket; readonly userFeedbackBucket: s3.Bucket; readonly modelsParameter: ssm.StringParameter; @@ -70,7 +69,6 @@ export class ApiResolvers extends Construct { props.shared.xOriginVerifySecret.secretArn, API_KEYS_SECRETS_ARN: props.shared.apiKeysSecret.secretArn, SESSIONS_TABLE_NAME: props.sessionsTable.tableName, - SESSIONS_BY_USER_ID_INDEX_NAME: props.byUserIdIndex, USER_FEEDBACK_BUCKET_NAME: props.userFeedbackBucket?.bucketName ?? "", UPLOAD_BUCKET_NAME: props.ragEngines?.uploadBucket?.bucketName ?? "", CHATBOT_FILES_BUCKET_NAME: props.filesBucket.bucketName, diff --git a/lib/model-interfaces/idefics/index.ts b/lib/model-interfaces/idefics/index.ts index 8e52bc3b4..4e457f9cc 100644 --- a/lib/model-interfaces/idefics/index.ts +++ b/lib/model-interfaces/idefics/index.ts @@ -21,7 +21,6 @@ interface IdeficsInterfaceProps { readonly config: SystemConfig; readonly messagesTopic: sns.Topic; readonly sessionsTable: dynamodb.Table; - readonly byUserIdIndex: string; readonly chatbotFilesBucket: s3.Bucket; readonly createPrivateGateway: boolean; } @@ -68,7 +67,6 @@ export class IdeficsInterface extends Construct { ...props.shared.defaultEnvironmentVariables, CONFIG_PARAMETER_NAME: props.shared.configParameter.parameterName, SESSIONS_TABLE_NAME: props.sessionsTable.tableName, - SESSIONS_BY_USER_ID_INDEX_NAME: props.byUserIdIndex, MESSAGES_TOPIC_ARN: props.messagesTopic.topicArn, CHATBOT_FILES_BUCKET_NAME: props.chatbotFilesBucket.bucketName, CHATBOT_FILES_PRIVATE_API: api?.url ?? "", diff --git a/lib/model-interfaces/langchain/index.ts b/lib/model-interfaces/langchain/index.ts index 31c1f730e..24b5583b3 100644 --- a/lib/model-interfaces/langchain/index.ts +++ b/lib/model-interfaces/langchain/index.ts @@ -20,7 +20,6 @@ interface LangChainInterfaceProps { readonly ragEngines?: RagEngines; readonly messagesTopic: sns.Topic; readonly sessionsTable: dynamodb.Table; - readonly byUserIdIndex: string; } export class LangChainInterface extends Construct { @@ -51,7 +50,6 @@ export class LangChainInterface extends Construct { ...props.shared.defaultEnvironmentVariables, CONFIG_PARAMETER_NAME: props.shared.configParameter.parameterName, SESSIONS_TABLE_NAME: props.sessionsTable.tableName, - SESSIONS_BY_USER_ID_INDEX_NAME: props.byUserIdIndex, API_KEYS_SECRETS_ARN: props.shared.apiKeysSecret.secretArn, MESSAGES_TOPIC_ARN: props.messagesTopic.topicArn, WORKSPACES_TABLE_NAME: diff --git a/lib/shared/layers/python-sdk/python/genai_core/langchain/chat_message_history.py b/lib/shared/layers/python-sdk/python/genai_core/langchain/chat_message_history.py index aa686ef56..228d2ceb9 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/langchain/chat_message_history.py +++ b/lib/shared/layers/python-sdk/python/genai_core/langchain/chat_message_history.py @@ -5,18 +5,19 @@ from decimal import Decimal from datetime import datetime from botocore.exceptions import ClientError +from langchain_core.messages.ai import AIMessage, AIMessageChunk +from operator import itemgetter from langchain.schema import BaseChatMessageHistory from langchain.schema.messages import ( BaseMessage, _message_to_dict, - messages_from_dict, - messages_to_dict, + _message_from_dict, ) -from langchain_core.messages.ai import AIMessage, AIMessageChunk +from genai_core.sessions import delete_session client = boto3.resource("dynamodb") -logger = Logger() +logger = Logger(level="DEBUG") class DynamoDBChatMessageHistory(BaseChatMessageHistory): @@ -25,55 +26,111 @@ def __init__( table_name: str, session_id: str, user_id: str, + max_messages: int = None, # Added max_messages parameter ): self.table = client.Table(table_name) self.session_id = session_id self.user_id = user_id + self.max_messages = max_messages # Store max_messages + + + def _get_full_history(self) -> List[BaseMessage]: + """Query all messages from DynamoDB for the current session""" + messages: List[BaseMessage] = [] + response = self.table.query( + KeyConditionExpression="#pk = :user_id AND begins_with(#sk, :session_prefix)", + FilterExpression="#itemType = :itemType", + ExpressionAttributeNames={ + "#pk": "PK", + "#sk": "SK", + "#itemType": "ItemType" + }, + ScanIndexForward=True, + ExpressionAttributeValues={ + ":user_id": f"USER#{self.user_id}", + ":session_prefix": f"SESSION#{self.session_id}", + ":itemType": "message" + } + ) + items = response.get('Items', []) + + return items @property def messages(self) -> List[BaseMessage]: - """Retrieve the messages from DynamoDB""" - response = None - try: - response = self.table.get_item( - Key={"SessionId": self.session_id, "UserId": self.user_id} - ) - except ClientError as error: - if error.response["Error"]["Code"] == "ResourceNotFoundException": - logger.warning("No record found with session id: %s", self.session_id) - else: - logger.exception(error) + """Get the last max_messages from the full history""" + full_history_items = self._get_full_history() - if response and "Item" in response: - items = response["Item"]["History"] - else: - items = [] + # Hande case where max_messages is None + if self.max_messages is None: + self.max.messages = len(full_history_items) + + # Slice before processing + relevant_items = full_history_items[-self.max_messages:] - messages = messages_from_dict(items) - return messages + # Use itemgetter and list comprehension + get_history_data = itemgetter('History') + return [_message_from_dict(get_history_data(item) or '') for item in relevant_items] def add_message(self, message: BaseMessage) -> None: """Append the message to the record in DynamoDB""" - messages = messages_to_dict(self.messages) - if isinstance(message, AIMessageChunk): - # When streaming with RunnableWithMessageHistory, - # it would add a chunk to the history but it expects a text as content. - ai_message = "" - for c in message.content: - if "text" in c: - ai_message = ai_message + c.get("text") - _message = _message_to_dict(AIMessage(ai_message)) - else: - _message = _message_to_dict(message) - messages.append(_message) - try: + current_time = datetime.now().isoformat() + + # messages = messages_to_dict(self.messages) + if isinstance(message, AIMessageChunk): + # When streaming with RunnableWithMessageHistory, + # it would add a chunk to the history but it expects a text as content. + ai_message = "" + for c in message.content: + if "text" in c: + ai_message = ai_message + c.get("text") + _message = _message_to_dict(AIMessage(ai_message)) + else: + _message = _message_to_dict(message) + + try: + self.table.update_item( + Key={ + "PK": f"USER#{self.user_id}", + "SK": f"SESSION#{self.session_id}" + }, + UpdateExpression="SET LastUpdateTime = :time", + ConditionExpression="attribute_exists(PK)", + ExpressionAttributeValues={ + ":time": current_time + } + ) + except ClientError as err: + if err.response['Error']['Code'] == 'ConditionalCheckFailedException': + # Session doesn't exist, so create a new one + self.table.put_item( + Item={ + "PK": f"USER#{self.user_id}", + "SK": f"SESSION#{self.session_id}", + "Title": _message_to_dict(message) + .get("data", {}) + .get("content", ""), + "StartTime": current_time, + "ItemType": "session", + "SessionId": self.session_id, + "LastUpdateTime": current_time, + } + ) + else: + # If some other error occurs, re-raise the exception + raise + + + self.table.put_item( Item={ - "SessionId": self.session_id, - "UserId": self.user_id, - "StartTime": datetime.now().isoformat(), - "History": messages, + "PK": f"USER#{self.user_id}", + "SK": f"SESSION#{self.session_id}#{current_time}", + "StartTime": current_time, + "History": _message, # Store full history in DynamoDB + "ItemType": "message", + "Role": _message.get("type"), } ) except ClientError as err: @@ -81,31 +138,41 @@ def add_message(self, message: BaseMessage) -> None: def add_metadata(self, metadata: dict) -> None: """Add additional metadata to the last message""" - messages = messages_to_dict(self.messages) - if not messages: + full_history_items = self._get_full_history() + if not full_history_items: return metadata = json.loads(json.dumps(metadata), parse_float=Decimal) - messages[-1]["data"]["additional_kwargs"] = metadata + + most_recent_history = full_history_items[-1] + + most_recent_history["History"]["data"]["additional_kwargs"] = metadata try: - self.table.put_item( - Item={ - "SessionId": self.session_id, - "UserId": self.user_id, - "StartTime": datetime.now().isoformat(), - "History": messages, + + # Perform the update operation + self.table.update_item( + Key={ + "PK": f"USER#{self.user_id}", + "SK": f"SESSION#{self.session_id}#{most_recent_history['StartTime']}" + }, + UpdateExpression="SET #data = :data", + ExpressionAttributeNames={ + "#data": "History" + }, + ExpressionAttributeValues={ + ":data": most_recent_history["History"] } ) except Exception as err: logger.exception(err) + logger.exception(f"Failed to update metadata: {err}") def clear(self) -> None: """Clear session memory from DynamoDB""" try: - self.table.delete_item( - Key={"SessionId": self.session_id, "UserId": self.user_id} - ) + delete_session(self.session_id, self.user_id) + except ClientError as err: logger.exception(err) diff --git a/lib/shared/layers/python-sdk/python/genai_core/sessions.py b/lib/shared/layers/python-sdk/python/genai_core/sessions.py index 9f2ea5c96..5269932a5 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/sessions.py +++ b/lib/shared/layers/python-sdk/python/genai_core/sessions.py @@ -1,68 +1,164 @@ import os from aws_lambda_powertools import Logger import boto3 +import json from botocore.exceptions import ClientError +from boto3.dynamodb.conditions import Key, Attr +from typing import List, Dict, Any AWS_REGION = os.environ["AWS_REGION"] SESSIONS_TABLE_NAME = os.environ["SESSIONS_TABLE_NAME"] -SESSIONS_BY_USER_ID_INDEX_NAME = os.environ["SESSIONS_BY_USER_ID_INDEX_NAME"] dynamodb = boto3.resource("dynamodb", region_name=AWS_REGION) table = dynamodb.Table(SESSIONS_TABLE_NAME) logger = Logger() +def _get_messages_by_session_id(session_id, user_id): + items = [] + try: + response = table.query( + KeyConditionExpression="#pk = :user_id AND begins_with(#sk, :session_prefix)", + FilterExpression="#item_type = :session_type", + ExpressionAttributeNames={ + "#pk": "PK", + "#sk": "SK", + "#item_type": "ItemType" + }, + ExpressionAttributeValues={ + ':user_id': f'USER#{user_id}', + ':session_prefix': f'SESSION#{session_id}', + ':session_type': 'message' + }, + ScanIndexForward=True + ) + + items = response.get('Items', []) + + # If there are more items, continue querying + while 'LastEvaluatedKey' in response: + response = table.query( + KeyConditionExpression="#pk = :user_id AND begins_with(#sk, :session_prefix)", + ExpressionAttributeNames={ + "#pk": "PK", + "#sk": "SK" + }, + ExpressionAttributeValues={ + ':user_id': f'USER#{user_id}', + ':session_prefix': f'SESSION#{session_id}' + }, + ScanIndexForward=True + ) + items.extend(response.get('Items', [])) + + except ClientError as error: + if error.response["Error"]["Code"] == "ResourceNotFoundException": + logger.warning("No record found with session id: %s", session_id) + else: + logger.exception(error) + + return items def get_session(session_id, user_id): - response = {} try: - response = table.get_item(Key={"SessionId": session_id, "UserId": user_id}) + items = _get_messages_by_session_id(session_id, user_id) + + # Build data structure so that it can be returned to the client + returnItem = { + "SessionId": session_id, + "UserId": user_id, + "History": [], + "StartTime": None, + } + + for item in items: + if 'ItemType' in item: + if item['ItemType'] == 'message': + returnItem['History'].append(item['History']) + returnItem['StartTime']= item['StartTime'] + + except ClientError as error: if error.response["Error"]["Code"] == "ResourceNotFoundException": logger.warning("No record found with session id: %s", session_id) else: logger.exception(error) - return response.get("Item", {}) + return returnItem +def list_sessions_by_user_id(user_id: str) -> List[Dict[str, Any]]: + """ + List all sessions for a given user ID. -def list_sessions_by_user_id(user_id): - items = [] + Args: + user_id (str): The ID of the user. + + Returns: + List[Dict[str, Any]]: A list of session items. + """ + session_items = [] try: last_evaluated_key = None while True: + query_params = { + "KeyConditionExpression": "#pk = :user_id AND begins_with(#sk, :session_prefix)", + "FilterExpression": "#item_type = :session_type", + "ExpressionAttributeNames": { + "#pk": "PK", + "#sk": "SK", + "#item_type": "ItemType" + }, + "ExpressionAttributeValues": { + ":user_id": f"USER#{user_id}", + ":session_prefix": "SESSION#", + ":session_type": "session" + } + } + if last_evaluated_key: - response = table.query( - KeyConditionExpression="UserId = :user_id", - ExpressionAttributeValues={":user_id": user_id}, - IndexName=SESSIONS_BY_USER_ID_INDEX_NAME, - ExclusiveStartKey=last_evaluated_key, - ) - else: - response = table.query( - KeyConditionExpression="UserId = :user_id", - ExpressionAttributeValues={":user_id": user_id}, - IndexName=SESSIONS_BY_USER_ID_INDEX_NAME, - ) - - items.extend(response.get("Items", [])) + query_params["ExclusiveStartKey"] = last_evaluated_key + + response = table.query(**query_params) + + session_items.extend(response.get("Items", [])) last_evaluated_key = response.get("LastEvaluatedKey") if not last_evaluated_key: break + logger.info(f"Retrieved {len(session_items)} sessions for user {user_id}") except ClientError as error: if error.response["Error"]["Code"] == "ResourceNotFoundException": - logger.warning("No record found for user id: %s", user_id) + logger.warning(f"No records found for user id: {user_id}") else: - logger.exception(error) - - return items + logger.exception(f"Error retrieving sessions for user {user_id}: {error}") + return session_items def delete_session(session_id, user_id): try: - table.delete_item(Key={"SessionId": session_id, "UserId": user_id}) + session_history = _get_messages_by_session_id(session_id, user_id) + + if not session_history: + return {"id": session_id, "deleted": False} + + # Delete messages in session history + for item in session_history: + table.delete_item( + Key={ + "PK": item["PK"], + "SK": item["SK"], + } + ) + + # Delete the session item + table.delete_item( + Key={ + "PK": f"USER#{user_id}", + "SK": f"SESSION#{session_id}", + } + ) + except ClientError as error: if error.response["Error"]["Code"] == "ResourceNotFoundException": logger.warning("No record found with session id: %s", session_id) @@ -74,12 +170,23 @@ def delete_session(session_id, user_id): return {"id": session_id, "deleted": True} + def delete_user_sessions(user_id): - sessions = list_sessions_by_user_id(user_id) - ret_value = [] + try: + sessions = list_sessions_by_user_id(user_id) # Get all sessions for the user + ret_value = [] - for session in sessions: - result = delete_session(session["SessionId"], user_id) - ret_value.append({"id": session["SessionId"], "deleted": result["deleted"]}) + for session in sessions: + # Extract the session ID from the SK (assuming SK is in the format 'SESSION#') + session_id = session["SK"].split("#")[1] # Extracting session ID from 'SESSION#' + + # Delete each session + result = delete_session(session_id, user_id) + ret_value.append({"id": session_id, "deleted": result["deleted"]}) + except ClientError as error: + if error.response["Error"]["Code"] == "ResourceNotFoundException": + logger.warning("No record found for user id: %s", user_id) + else: + logger.exception(error) return ret_value diff --git a/lib/user-interface/react-app/src/components/chatbot/sessions.tsx b/lib/user-interface/react-app/src/components/chatbot/sessions.tsx index 1cc756aac..a189a103f 100644 --- a/lib/user-interface/react-app/src/components/chatbot/sessions.tsx +++ b/lib/user-interface/react-app/src/components/chatbot/sessions.tsx @@ -87,11 +87,19 @@ export default function Sessions(props: SessionsProps) { setIsLoading(true); const apiClient = new ApiClient(appContext); - await Promise.all( - selectedItems.map((s) => apiClient.sessions.deleteSession(s.id)) - ); - await getSessions(); - setIsLoading(false); + try { + await Promise.all( + selectedItems.map((s) => apiClient.sessions.deleteSession(s.id)) + ); + await getSessions(); + setSelectedItems([]); // Clear selected items + } catch (error) { + console.error("Error deleting sessions:", error); + setGlobalError("Failed to delete selected sessions. Please try again."); + } finally { + setIsLoading(false); + setShowModalDelete(false); // Close the modal regardless of success or failure + } }; const deleteUserSessions = async () => { @@ -99,9 +107,16 @@ export default function Sessions(props: SessionsProps) { setIsLoading(true); const apiClient = new ApiClient(appContext); - await apiClient.sessions.deleteSessions(); - await getSessions(); - setIsLoading(false); + try { + await apiClient.sessions.deleteSessions(); + await getSessions(); + } catch (error) { + console.error("Error deleting all sessions:", error); + setGlobalError("Failed to delete all sessions. Please try again."); + } finally { + setIsLoading(false); + setDeleteAllSessions(false); // Close the modal regardless of success or failure + } }; return ( From 22ccda28290773a69742d347cde369108f96c246 Mon Sep 17 00:00:00 2001 From: GREG BONE Date: Thu, 10 Oct 2024 20:38:33 -0600 Subject: [PATCH 2/5] chore: fix snapshot test failures --- lib/aws-genai-llm-chatbot-stack.ts | 2 - .../chatbot-dynamodb-tables/index.ts | 2 +- tests/__snapshots__/cdk-app.test.ts.snap | 64 ++----------------- .../chatbot-api-construct.test.ts.snap | 36 ++--------- 4 files changed, 13 insertions(+), 91 deletions(-) diff --git a/lib/aws-genai-llm-chatbot-stack.ts b/lib/aws-genai-llm-chatbot-stack.ts index c73837895..7563b4297 100644 --- a/lib/aws-genai-llm-chatbot-stack.ts +++ b/lib/aws-genai-llm-chatbot-stack.ts @@ -77,7 +77,6 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack { ragEngines, messagesTopic: chatBotApi.messagesTopic, sessionsTable: chatBotApi.sessionsTable, - byUserIdIndex: chatBotApi.byUserIdIndex, }); // Route all incoming messages targeted to langchain to the langchain model interface queue @@ -120,7 +119,6 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack { config: props.config, messagesTopic: chatBotApi.messagesTopic, sessionsTable: chatBotApi.sessionsTable, - byUserIdIndex: chatBotApi.byUserIdIndex, chatbotFilesBucket: chatBotApi.filesBucket, createPrivateGateway: ideficsModels.length > 0, }); diff --git a/lib/chatbot-api/chatbot-dynamodb-tables/index.ts b/lib/chatbot-api/chatbot-dynamodb-tables/index.ts index 920a60102..3e0ae1147 100644 --- a/lib/chatbot-api/chatbot-dynamodb-tables/index.ts +++ b/lib/chatbot-api/chatbot-dynamodb-tables/index.ts @@ -18,7 +18,7 @@ export class ChatBotDynamoDBTables extends Construct { // Create the sessions table with a partition key of USER# // and a sort key of SK of SESSION#> // No need to the global secondary index for this table - const sessionsTable = new dynamodb.Table(this, "SessionTable", { + const sessionsTable = new dynamodb.Table(this, "SessionsTable", { partitionKey: { name: "PK", type: dynamodb.AttributeType.STRING, diff --git a/tests/__snapshots__/cdk-app.test.ts.snap b/tests/__snapshots__/cdk-app.test.ts.snap index ba4a1dc45..cc4c8875a 100644 --- a/tests/__snapshots__/cdk-app.test.ts.snap +++ b/tests/__snapshots__/cdk-app.test.ts.snap @@ -975,36 +975,22 @@ def submit_response(event: dict, context, response_status: str, error_message: s "Properties": { "AttributeDefinitions": [ { - "AttributeName": "SessionId", + "AttributeName": "PK", "AttributeType": "S", }, { - "AttributeName": "UserId", + "AttributeName": "SK", "AttributeType": "S", }, ], "BillingMode": "PAY_PER_REQUEST", - "GlobalSecondaryIndexes": [ - { - "IndexName": "byUserId", - "KeySchema": [ - { - "AttributeName": "UserId", - "KeyType": "HASH", - }, - ], - "Projection": { - "ProjectionType": "ALL", - }, - }, - ], "KeySchema": [ { - "AttributeName": "SessionId", + "AttributeName": "PK", "KeyType": "HASH", }, { - "AttributeName": "UserId", + "AttributeName": "SK", "KeyType": "RANGE", }, ], @@ -3273,7 +3259,6 @@ schema { "EndpointName", ], }, - "SESSIONS_BY_USER_ID_INDEX_NAME": "byUserId", "SESSIONS_TABLE_NAME": { "Ref": "ChatBotApiChatDynamoDBTablesSessionsTable92B891E3", }, @@ -3872,18 +3857,7 @@ schema { ], }, { - "Fn::Join": [ - "", - [ - { - "Fn::GetAtt": [ - "ChatBotApiChatDynamoDBTablesSessionsTable92B891E3", - "Arn", - ], - }, - "/index/*", - ], - ], + "Ref": "AWS::NoValue", }, ], }, @@ -5036,7 +5010,6 @@ schema { "EndpointName", ], }, - "SESSIONS_BY_USER_ID_INDEX_NAME": "byUserId", "SESSIONS_TABLE_NAME": { "Ref": "ChatBotApiChatDynamoDBTablesSessionsTable92B891E3", }, @@ -5329,18 +5302,7 @@ schema { ], }, { - "Fn::Join": [ - "", - [ - { - "Fn::GetAtt": [ - "ChatBotApiChatDynamoDBTablesSessionsTable92B891E3", - "Arn", - ], - }, - "/index/*", - ], - ], + "Ref": "AWS::NoValue", }, ], }, @@ -5838,7 +5800,6 @@ schema { "EndpointName", ], }, - "SESSIONS_BY_USER_ID_INDEX_NAME": "byUserId", "SESSIONS_TABLE_NAME": { "Ref": "ChatBotApiChatDynamoDBTablesSessionsTable92B891E3", }, @@ -6274,18 +6235,7 @@ schema { ], }, { - "Fn::Join": [ - "", - [ - { - "Fn::GetAtt": [ - "ChatBotApiChatDynamoDBTablesSessionsTable92B891E3", - "Arn", - ], - }, - "/index/*", - ], - ], + "Ref": "AWS::NoValue", }, ], }, diff --git a/tests/chatbot-api/__snapshots__/chatbot-api-construct.test.ts.snap b/tests/chatbot-api/__snapshots__/chatbot-api-construct.test.ts.snap index 7bb87d8fd..cb9012e41 100644 --- a/tests/chatbot-api/__snapshots__/chatbot-api-construct.test.ts.snap +++ b/tests/chatbot-api/__snapshots__/chatbot-api-construct.test.ts.snap @@ -1195,36 +1195,22 @@ def submit_response(event: dict, context, response_status: str, error_message: s "Properties": { "AttributeDefinitions": [ { - "AttributeName": "SessionId", + "AttributeName": "PK", "AttributeType": "S", }, { - "AttributeName": "UserId", + "AttributeName": "SK", "AttributeType": "S", }, ], "BillingMode": "PAY_PER_REQUEST", - "GlobalSecondaryIndexes": [ - { - "IndexName": "byUserId", - "KeySchema": [ - { - "AttributeName": "UserId", - "KeyType": "HASH", - }, - ], - "Projection": { - "ProjectionType": "ALL", - }, - }, - ], "KeySchema": [ { - "AttributeName": "SessionId", + "AttributeName": "PK", "KeyType": "HASH", }, { - "AttributeName": "UserId", + "AttributeName": "SK", "KeyType": "RANGE", }, ], @@ -3364,7 +3350,6 @@ schema { "EndpointName", ], }, - "SESSIONS_BY_USER_ID_INDEX_NAME": "byUserId", "SESSIONS_TABLE_NAME": { "Ref": "ChatBotApiConstructChatDynamoDBTablesSessionsTableD81EF9A7", }, @@ -3934,18 +3919,7 @@ schema { ], }, { - "Fn::Join": [ - "", - [ - { - "Fn::GetAtt": [ - "ChatBotApiConstructChatDynamoDBTablesSessionsTableD81EF9A7", - "Arn", - ], - }, - "/index/*", - ], - ], + "Ref": "AWS::NoValue", }, ], }, From 6df8aca753ddf2acc98cc9d1af6ed6d5b86ed392 Mon Sep 17 00:00:00 2001 From: GREG BONE Date: Fri, 11 Oct 2024 13:53:58 -0600 Subject: [PATCH 3/5] chore: fix pytests for sessions --- tests/chatbot-api/functions/api-handler/routes/sessions_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/chatbot-api/functions/api-handler/routes/sessions_test.py b/tests/chatbot-api/functions/api-handler/routes/sessions_test.py index 8bcd278d5..6d2b27024 100644 --- a/tests/chatbot-api/functions/api-handler/routes/sessions_test.py +++ b/tests/chatbot-api/functions/api-handler/routes/sessions_test.py @@ -31,7 +31,7 @@ def test_get_sessions(mocker): expected = [ { "id": session.get("SessionId"), - "title": "content", + "title": "", "startTime": session.get("StartTime") + "Z", } ] From f8e5bd262b2330c5f3e350890344cf58bc88e8ec Mon Sep 17 00:00:00 2001 From: GREG BONE Date: Fri, 11 Oct 2024 14:09:22 -0600 Subject: [PATCH 4/5] chore: run eslint --- .../langchain/chat_message_history.py | 43 +++++++--------- .../python-sdk/python/genai_core/sessions.py | 51 ++++++++++--------- 2 files changed, 44 insertions(+), 50 deletions(-) diff --git a/lib/shared/layers/python-sdk/python/genai_core/langchain/chat_message_history.py b/lib/shared/layers/python-sdk/python/genai_core/langchain/chat_message_history.py index 228d2ceb9..e10148b19 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/langchain/chat_message_history.py +++ b/lib/shared/layers/python-sdk/python/genai_core/langchain/chat_message_history.py @@ -33,26 +33,25 @@ def __init__( self.user_id = user_id self.max_messages = max_messages # Store max_messages - def _get_full_history(self) -> List[BaseMessage]: """Query all messages from DynamoDB for the current session""" messages: List[BaseMessage] = [] - response = self.table.query( + response = self.table.query( KeyConditionExpression="#pk = :user_id AND begins_with(#sk, :session_prefix)", FilterExpression="#itemType = :itemType", ExpressionAttributeNames={ "#pk": "PK", "#sk": "SK", - "#itemType": "ItemType" + "#itemType": "ItemType", }, ScanIndexForward=True, ExpressionAttributeValues={ ":user_id": f"USER#{self.user_id}", ":session_prefix": f"SESSION#{self.session_id}", - ":itemType": "message" - } + ":itemType": "message", + }, ) - items = response.get('Items', []) + items = response.get("Items", []) return items @@ -64,13 +63,15 @@ def messages(self) -> List[BaseMessage]: # Hande case where max_messages is None if self.max_messages is None: self.max.messages = len(full_history_items) - + # Slice before processing - relevant_items = full_history_items[-self.max_messages:] + relevant_items = full_history_items[-self.max_messages :] # Use itemgetter and list comprehension - get_history_data = itemgetter('History') - return [_message_from_dict(get_history_data(item) or '') for item in relevant_items] + get_history_data = itemgetter("History") + return [ + _message_from_dict(get_history_data(item) or "") for item in relevant_items + ] def add_message(self, message: BaseMessage) -> None: """Append the message to the record in DynamoDB""" @@ -93,22 +94,20 @@ def add_message(self, message: BaseMessage) -> None: self.table.update_item( Key={ "PK": f"USER#{self.user_id}", - "SK": f"SESSION#{self.session_id}" + "SK": f"SESSION#{self.session_id}", }, UpdateExpression="SET LastUpdateTime = :time", ConditionExpression="attribute_exists(PK)", - ExpressionAttributeValues={ - ":time": current_time - } + ExpressionAttributeValues={":time": current_time}, ) except ClientError as err: - if err.response['Error']['Code'] == 'ConditionalCheckFailedException': + if err.response["Error"]["Code"] == "ConditionalCheckFailedException": # Session doesn't exist, so create a new one self.table.put_item( Item={ "PK": f"USER#{self.user_id}", "SK": f"SESSION#{self.session_id}", - "Title": _message_to_dict(message) + "Title": _message_to_dict(message) .get("data", {}) .get("content", ""), "StartTime": current_time, @@ -121,8 +120,6 @@ def add_message(self, message: BaseMessage) -> None: # If some other error occurs, re-raise the exception raise - - self.table.put_item( Item={ "PK": f"USER#{self.user_id}", @@ -154,15 +151,11 @@ def add_metadata(self, metadata: dict) -> None: self.table.update_item( Key={ "PK": f"USER#{self.user_id}", - "SK": f"SESSION#{self.session_id}#{most_recent_history['StartTime']}" + "SK": f"SESSION#{self.session_id}#{most_recent_history['StartTime']}", }, UpdateExpression="SET #data = :data", - ExpressionAttributeNames={ - "#data": "History" - }, - ExpressionAttributeValues={ - ":data": most_recent_history["History"] - } + ExpressionAttributeNames={"#data": "History"}, + ExpressionAttributeValues={":data": most_recent_history["History"]}, ) except Exception as err: diff --git a/lib/shared/layers/python-sdk/python/genai_core/sessions.py b/lib/shared/layers/python-sdk/python/genai_core/sessions.py index 5269932a5..ccaec1f45 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/sessions.py +++ b/lib/shared/layers/python-sdk/python/genai_core/sessions.py @@ -14,6 +14,7 @@ table = dynamodb.Table(SESSIONS_TABLE_NAME) logger = Logger() + def _get_messages_by_session_id(session_id, user_id): items = [] try: @@ -23,33 +24,30 @@ def _get_messages_by_session_id(session_id, user_id): ExpressionAttributeNames={ "#pk": "PK", "#sk": "SK", - "#item_type": "ItemType" + "#item_type": "ItemType", }, ExpressionAttributeValues={ - ':user_id': f'USER#{user_id}', - ':session_prefix': f'SESSION#{session_id}', - ':session_type': 'message' + ":user_id": f"USER#{user_id}", + ":session_prefix": f"SESSION#{session_id}", + ":session_type": "message", }, - ScanIndexForward=True + ScanIndexForward=True, ) - items = response.get('Items', []) + items = response.get("Items", []) # If there are more items, continue querying - while 'LastEvaluatedKey' in response: + while "LastEvaluatedKey" in response: response = table.query( KeyConditionExpression="#pk = :user_id AND begins_with(#sk, :session_prefix)", - ExpressionAttributeNames={ - "#pk": "PK", - "#sk": "SK" - }, + ExpressionAttributeNames={"#pk": "PK", "#sk": "SK"}, ExpressionAttributeValues={ - ':user_id': f'USER#{user_id}', - ':session_prefix': f'SESSION#{session_id}' + ":user_id": f"USER#{user_id}", + ":session_prefix": f"SESSION#{session_id}", }, - ScanIndexForward=True + ScanIndexForward=True, ) - items.extend(response.get('Items', [])) + items.extend(response.get("Items", [])) except ClientError as error: if error.response["Error"]["Code"] == "ResourceNotFoundException": @@ -59,6 +57,7 @@ def _get_messages_by_session_id(session_id, user_id): return items + def get_session(session_id, user_id): try: items = _get_messages_by_session_id(session_id, user_id) @@ -72,11 +71,10 @@ def get_session(session_id, user_id): } for item in items: - if 'ItemType' in item: - if item['ItemType'] == 'message': - returnItem['History'].append(item['History']) - returnItem['StartTime']= item['StartTime'] - + if "ItemType" in item: + if item["ItemType"] == "message": + returnItem["History"].append(item["History"]) + returnItem["StartTime"] = item["StartTime"] except ClientError as error: if error.response["Error"]["Code"] == "ResourceNotFoundException": @@ -86,6 +84,7 @@ def get_session(session_id, user_id): return returnItem + def list_sessions_by_user_id(user_id: str) -> List[Dict[str, Any]]: """ List all sessions for a given user ID. @@ -106,13 +105,13 @@ def list_sessions_by_user_id(user_id: str) -> List[Dict[str, Any]]: "ExpressionAttributeNames": { "#pk": "PK", "#sk": "SK", - "#item_type": "ItemType" + "#item_type": "ItemType", }, "ExpressionAttributeValues": { ":user_id": f"USER#{user_id}", ":session_prefix": "SESSION#", - ":session_type": "session" - } + ":session_type": "session", + }, } if last_evaluated_key: @@ -135,6 +134,7 @@ def list_sessions_by_user_id(user_id: str) -> List[Dict[str, Any]]: return session_items + def delete_session(session_id, user_id): try: session_history = _get_messages_by_session_id(session_id, user_id) @@ -170,7 +170,6 @@ def delete_session(session_id, user_id): return {"id": session_id, "deleted": True} - def delete_user_sessions(user_id): try: sessions = list_sessions_by_user_id(user_id) # Get all sessions for the user @@ -178,7 +177,9 @@ def delete_user_sessions(user_id): for session in sessions: # Extract the session ID from the SK (assuming SK is in the format 'SESSION#') - session_id = session["SK"].split("#")[1] # Extracting session ID from 'SESSION#' + session_id = session["SK"].split("#")[ + 1 + ] # Extracting session ID from 'SESSION#' # Delete each session result = delete_session(session_id, user_id) From 37e883b4c200836191087cf68e08a0ec1d884327 Mon Sep 17 00:00:00 2001 From: GREG BONE Date: Tue, 15 Oct 2024 15:28:30 -0500 Subject: [PATCH 5/5] chore(ci): fix linting issues --- .../langchain/chat_message_history.py | 18 +++++++++++++----- .../python-sdk/python/genai_core/sessions.py | 17 +++++++++++------ 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/lib/shared/layers/python-sdk/python/genai_core/langchain/chat_message_history.py b/lib/shared/layers/python-sdk/python/genai_core/langchain/chat_message_history.py index e10148b19..2232a24a4 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/langchain/chat_message_history.py +++ b/lib/shared/layers/python-sdk/python/genai_core/langchain/chat_message_history.py @@ -35,9 +35,10 @@ def __init__( def _get_full_history(self) -> List[BaseMessage]: """Query all messages from DynamoDB for the current session""" - messages: List[BaseMessage] = [] response = self.table.query( - KeyConditionExpression="#pk = :user_id AND begins_with(#sk, :session_prefix)", + KeyConditionExpression=( + "#pk = :user_id AND begins_with(#sk, :session_prefix)" + ), FilterExpression="#itemType = :itemType", ExpressionAttributeNames={ "#pk": "PK", @@ -151,11 +152,18 @@ def add_metadata(self, metadata: dict) -> None: self.table.update_item( Key={ "PK": f"USER#{self.user_id}", - "SK": f"SESSION#{self.session_id}#{most_recent_history['StartTime']}", + "SK": ( + f"SESSION#{self.session_id}" + f"#{most_recent_history['StartTime']}" + ), }, UpdateExpression="SET #data = :data", - ExpressionAttributeNames={"#data": "History"}, - ExpressionAttributeValues={":data": most_recent_history["History"]}, + ExpressionAttributeNames={ + "#data": "History" + }, + ExpressionAttributeValues={ + ":data": most_recent_history["History"] + }, ) except Exception as err: diff --git a/lib/shared/layers/python-sdk/python/genai_core/sessions.py b/lib/shared/layers/python-sdk/python/genai_core/sessions.py index ccaec1f45..361d64629 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/sessions.py +++ b/lib/shared/layers/python-sdk/python/genai_core/sessions.py @@ -1,9 +1,7 @@ import os from aws_lambda_powertools import Logger import boto3 -import json from botocore.exceptions import ClientError -from boto3.dynamodb.conditions import Key, Attr from typing import List, Dict, Any AWS_REGION = os.environ["AWS_REGION"] @@ -19,7 +17,9 @@ def _get_messages_by_session_id(session_id, user_id): items = [] try: response = table.query( - KeyConditionExpression="#pk = :user_id AND begins_with(#sk, :session_prefix)", + KeyConditionExpression=( + "#pk = :user_id AND begins_with(#sk, :session_prefix)" + ), FilterExpression="#item_type = :session_type", ExpressionAttributeNames={ "#pk": "PK", @@ -39,7 +39,9 @@ def _get_messages_by_session_id(session_id, user_id): # If there are more items, continue querying while "LastEvaluatedKey" in response: response = table.query( - KeyConditionExpression="#pk = :user_id AND begins_with(#sk, :session_prefix)", + KeyConditionExpression=( + "#pk = :user_id AND begins_with(#sk, :session_prefix)" + ), ExpressionAttributeNames={"#pk": "PK", "#sk": "SK"}, ExpressionAttributeValues={ ":user_id": f"USER#{user_id}", @@ -100,7 +102,9 @@ def list_sessions_by_user_id(user_id: str) -> List[Dict[str, Any]]: last_evaluated_key = None while True: query_params = { - "KeyConditionExpression": "#pk = :user_id AND begins_with(#sk, :session_prefix)", + "KeyConditionExpression": ( + "#pk = :user_id AND begins_with(#sk, :session_prefix)" + ), "FilterExpression": "#item_type = :session_type", "ExpressionAttributeNames": { "#pk": "PK", @@ -176,7 +180,8 @@ def delete_user_sessions(user_id): ret_value = [] for session in sessions: - # Extract the session ID from the SK (assuming SK is in the format 'SESSION#') + # Extract the session ID from the SK + # (assuming SK is in the format 'SESSION#') session_id = session["SK"].split("#")[ 1 ] # Extracting session ID from 'SESSION#'