Skip to content

Commit

Permalink
feat: Replace the use of identity pool by s3 signed urls. (#568)
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-marion authored Sep 19, 2024
1 parent 713a879 commit 08f899b
Show file tree
Hide file tree
Showing 31 changed files with 400 additions and 613 deletions.
46 changes: 26 additions & 20 deletions integtests/chatbot-api/multi_modal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,30 @@
import os
import time
import uuid
import boto3
from pathlib import Path

import pytest
import requests
from gql.transport.exceptions import TransportQueryError


def test_multi_modal(
client, config, cognito_credentials, default_multimodal_model, default_provider
):
bucket = config.get("Storage").get("AWSS3").get("bucket")
s3 = boto3.resource(
"s3",
# Use identity pool credentials to verify it owrks
aws_access_key_id=cognito_credentials.aws_access_key,
aws_secret_access_key=cognito_credentials.aws_secret_key,
aws_session_token=cognito_credentials.aws_token,
)
def test_multi_modal(client, default_multimodal_model, default_provider):

key = "INTEG_TEST" + str(uuid.uuid4()) + ".jpeg"
object = s3.Object(bucket, "public/" + key)
wrong_object = s3.Object(bucket, "private/notallowed/1.jpg")
result = client.add_file(
input={
"fileName": key,
}
)

fields = result.get("fields")
cleaned_fields = fields.replace("{", "").replace("}", "")
pairs = [pair.strip() for pair in cleaned_fields.split(",")]
fields_dict = dict(pair.split("=", 1) for pair in pairs)
current_dir = os.path.dirname(os.path.realpath(__file__))
object.put(Body=Path(current_dir + "/resources/powered-by-aws.png").read_bytes())
with pytest.raises(Exception, match="AccessDenied"):
wrong_object.put(
Body=Path(current_dir + "/resources/powered-by-aws.png").read_bytes()
)
files = {"file": Path(current_dir + "/resources/powered-by-aws.png").read_bytes()}
response = requests.post(result.get("url"), data=fields_dict, files=files)
assert response.status_code == 204

session_id = str(uuid.uuid4())

Expand Down Expand Up @@ -58,4 +56,12 @@ def test_multi_modal(

assert "powered by" in content
client.delete_session(session_id)
object.delete()

# Verify it can get the file
url = client.get_file_url(key)
assert url.startswith("https://")


def test_unknown_file(client):
with pytest.raises(TransportQueryError, match="File does not exist"):
client.get_file_url("file")
4 changes: 4 additions & 0 deletions integtests/clients/appsync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ def add_file(self, input):
)
return self.client.execute(query).get("getUploadFileURL")

def get_file_url(self, fileName):
query = dsl_gql(DSLQuery(self.schema.Query.getFileURL.args(fileName=fileName)))
return self.client.execute(query).get("getFileURL")

def get_document(self, input):
query = dsl_gql(
DSLQuery(
Expand Down
20 changes: 0 additions & 20 deletions integtests/clients/cognito_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ class Credentials(BaseModel):
id_token: str
email: str
password: str
aws_access_key: str
aws_secret_key: str
aws_token: str

def __repr__(self):
return "Credentials(********)"
Expand Down Expand Up @@ -67,28 +64,11 @@ def get_credentials(self, email: str) -> Credentials:
AuthParameters={"USERNAME": email, "PASSWORD": password},
)

login_key = "cognito-idp." + self.region + ".amazonaws.com/" + self.user_pool_id
identity_response = self.cognito_identity_client.get_id(
IdentityPoolId=self.identity_pool_id,
Logins={login_key: response["AuthenticationResult"]["IdToken"]},
)

aws_credentials_respose = (
self.cognito_identity_client.get_credentials_for_identity(
IdentityId=identity_response["IdentityId"],
Logins={login_key: response["AuthenticationResult"]["IdToken"]},
)
)

return Credentials(
**{
"id_token": response["AuthenticationResult"]["IdToken"],
"email": email,
"password": password,
# Credential with limited permissions (upload images for multi modal)
"aws_access_key": aws_credentials_respose["Credentials"]["AccessKeyId"],
"aws_secret_key": aws_credentials_respose["Credentials"]["SecretKey"],
"aws_token": aws_credentials_respose["Credentials"]["SessionToken"],
}
)

Expand Down
2 changes: 2 additions & 0 deletions integtests/security/unauthorized_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def test_unauthenticated(unauthenticated_client: AppSyncClient):
unauthenticated_client.start_kendra_data_sync("id")
with pytest.raises(TransportQueryError, match=match):
unauthenticated_client.is_kendra_data_synching("id")
with pytest.raises(TransportQueryError, match=match):
unauthenticated_client.get_file_url("file")
with pytest.raises(TransportQueryError, match=match):
unauthenticated_client.list_kendra_indexes()
with pytest.raises(TransportQueryError, match=match):
Expand Down
22 changes: 0 additions & 22 deletions lib/authentication/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import * as cognitoIdentityPool from "@aws-cdk/aws-cognito-identitypool-alpha";
import * as cdk from "aws-cdk-lib";
import { SystemConfig } from "../shared/types";
import * as cognito from "aws-cdk-lib/aws-cognito";
Expand All @@ -12,7 +11,6 @@ import * as logs from "aws-cdk-lib/aws-logs";
export class Authentication extends Construct {
public readonly userPool: cognito.UserPool;
public readonly userPoolClient: cognito.UserPoolClient;
public readonly identityPool: cognitoIdentityPool.IdentityPool;
public readonly cognitoDomain: cognito.UserPoolDomain;
public readonly updateUserPoolClient: lambda.Function;
public readonly customOidcProvider: cognito.UserPoolIdentityProviderOidc;
Expand Down Expand Up @@ -53,21 +51,6 @@ export class Authentication extends Construct {
this.cognitoDomain = userPooldomain;
}

const identityPool = new cognitoIdentityPool.IdentityPool(
this,
"IdentityPool",
{
authenticationProviders: {
userPools: [
new cognitoIdentityPool.UserPoolAuthenticationProvider({
userPool,
userPoolClient,
}),
],
},
}
);

if (config.cognitoFederation?.enabled) {
// Create an IAM Role for the Lambda function
const lambdaRoleUpdateClient = new iam.Role(
Expand Down Expand Up @@ -243,16 +226,11 @@ export class Authentication extends Construct {

this.userPool = userPool;
this.userPoolClient = userPoolClient;
this.identityPool = identityPool;

new cdk.CfnOutput(this, "UserPoolId", {
value: userPool.userPoolId,
});

new cdk.CfnOutput(this, "IdentityPoolId", {
value: identityPool.identityPoolId,
});

new cdk.CfnOutput(this, "UserPoolWebClientId", {
value: userPoolClient.userPoolClientId,
});
Expand Down
5 changes: 1 addition & 4 deletions lib/aws-genai-llm-chatbot-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
userPoolId: authentication.userPool.userPoolId,
userPoolClient: authentication.userPoolClient,
userPoolClientId: authentication.userPoolClient.userPoolClientId,
identityPool: authentication.identityPool,
api: chatBotApi,
chatbotFilesBucket: chatBotApi.filesBucket,
crossEncodersEnabled:
Expand Down Expand Up @@ -296,7 +295,6 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
NagSuppressions.addResourceSuppressionsByPath(
this,
[
`/${this.stackName}/Authentication/IdentityPool/AuthenticatedRole/DefaultPolicy/Resource`,
`/${this.stackName}/Authentication/UserPool/smsRole/Resource`,
`/${this.stackName}/Custom::CDKBucketDeployment8693BB64968944B69AAFB0CC9EB8756C/ServiceRole/DefaultPolicy/Resource`,
`/${this.stackName}/LogRetentionaae0aa3c5b4d4f87b02d85b201efdd8a/ServiceRole/Resource`,
Expand Down Expand Up @@ -351,8 +349,7 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
NagSuppressions.addResourceSuppressionsByPath(
this,
[
`/${this.stackName}/IdeficsInterface/ChatbotFilesPrivateApi/Default/{object}/ANY/Resource`,
`/${this.stackName}/IdeficsInterface/ChatbotFilesPrivateApi/Default/{object}/ANY/Resource`,
`/${this.stackName}/IdeficsInterface/ChatbotFilesPrivateApi/Default/{folder}/{key}/GET/Resource`,
],
[
{ id: "AwsSolutions-APIG4", reason: "Private API within a VPC." },
Expand Down
35 changes: 27 additions & 8 deletions lib/chatbot-api/functions/api-handler/routes/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
SAFE_SHORT_STR_VALIDATION,
)
import genai_core.types
import genai_core.upload
import genai_core.presign
import genai_core.documents
import genai_core.auth
from pydantic import BaseModel, Field
from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.event_handler.appsync import Router
Expand All @@ -20,7 +21,7 @@


class FileUploadRequest(BaseModel):
workspaceId: str = ID_FIELD_VALIDATION
workspaceId: Optional[str] = ID_FIELD_VALIDATION
fileName: str = Field(min_length=1, max_length=500, pattern=SAFE_STR_REGEX)


Expand Down Expand Up @@ -104,7 +105,7 @@ class DocumentSubscriptionStatusRequest(BaseModel):
)


allowed_extensions = set(
allowed_workspace_extensions = set(
[
".csv",
".doc",
Expand All @@ -128,18 +129,36 @@ class DocumentSubscriptionStatusRequest(BaseModel):
]
)

allowed_session_extensions = set(
[
".jpg",
".jpeg",
".png",
]
)


@router.resolver(field_name="getUploadFileURL")
@tracer.capture_method
def file_upload(input: dict):
request = FileUploadRequest(**input)
_, extension = os.path.splitext(request.fileName)
if extension not in allowed_extensions:
raise genai_core.types.CommonError("Invalid file extension")

result = genai_core.upload.generate_presigned_post(
request.workspaceId, request.fileName
)
if "workspaceId" in input:
if extension not in allowed_workspace_extensions:
raise genai_core.types.CommonError("Invalid file extension")

result = genai_core.presign.generate_workspace_presigned_post(
request.workspaceId, request.fileName
)
else:
if extension not in allowed_session_extensions:
raise genai_core.types.CommonError("Invalid file extension")

user_id = genai_core.auth.get_user_id(router)
result = genai_core.presign.generate_user_presigned_post(
user_id, request.fileName
)

logger.info("Generated pre-signed for " + request.fileName)
return result
Expand Down
20 changes: 20 additions & 0 deletions lib/chatbot-api/functions/api-handler/routes/sessions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pydantic import BaseModel, Field
from common.constant import SAFE_STR_REGEX
from common.validation import WorkspaceIdValidation
import genai_core.presign
import genai_core.sessions
import genai_core.types
import genai_core.auth
Expand All @@ -12,6 +15,23 @@
logger = Logger()


class FileURequestValidation(BaseModel):
fileName: str = Field(min_length=1, max_length=500, pattern=SAFE_STR_REGEX)


@router.resolver(field_name="getFileURL")
@tracer.capture_method
def get_file(fileName: str):
FileURequestValidation(**{"fileName": fileName})
user_id = genai_core.auth.get_user_id(router)
result = genai_core.presign.generate_user_presigned_get(
user_id, fileName, expiration=600
)

logger.info("Generated pre-signed for " + fileName)
return result


@router.resolver(field_name="listSessions")
@tracer.capture_method
def get_sessions():
Expand Down
1 change: 1 addition & 0 deletions lib/chatbot-api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ export class ChatBotApi extends Construct {
byUserIdIndex: chatTables.byUserIdIndex,
api,
userFeedbackBucket: chatBuckets.userFeedbackBucket,
filesBucket: chatBuckets.filesBucket,
});

this.resolvers.push(apiResolvers.appSyncLambdaResolver);
Expand Down
3 changes: 3 additions & 0 deletions lib/chatbot-api/rest-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export interface ApiResolversProps {
readonly userPool: cognito.UserPool;
readonly sessionsTable: dynamodb.Table;
readonly byUserIdIndex: string;
readonly filesBucket: s3.Bucket;
readonly userFeedbackBucket: s3.Bucket;
readonly modelsParameter: ssm.StringParameter;
readonly models: SageMakerModelEndpoint[];
Expand Down Expand Up @@ -69,6 +70,7 @@ export class ApiResolvers extends Construct {
SESSIONS_BY_USER_ID_INDEX_NAME: props.byUserIdIndex,
USER_FEEDBACK_BUCKET_NAME: props.userFeedbackBucket?.bucketName ?? "",
UPLOAD_BUCKET_NAME: props.ragEngines?.uploadBucket?.bucketName ?? "",
CHATBOT_FILES_BUCKET_NAME: props.filesBucket.bucketName,
PROCESSING_BUCKET_NAME:
props.ragEngines?.processingBucket?.bucketName ?? "",
AURORA_DB_SECRET_ID: props.ragEngines?.auroraPgVector?.database
Expand Down Expand Up @@ -296,6 +298,7 @@ export class ApiResolvers extends Construct {
props.modelsParameter.grantRead(apiHandler);
props.sessionsTable.grantReadWriteData(apiHandler);
props.userFeedbackBucket.grantReadWrite(apiHandler);
props.filesBucket.grantReadWrite(apiHandler);
props.ragEngines?.uploadBucket.grantReadWrite(apiHandler);
props.ragEngines?.processingBucket.grantReadWrite(apiHandler);

Expand Down
4 changes: 3 additions & 1 deletion lib/chatbot-api/schema/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ type EmbeddingModel @aws_cognito_user_pools {
}

input FileUploadInput {
workspaceId: String!
workspaceId: String
fileName: String!
}

Expand Down Expand Up @@ -363,6 +363,8 @@ type Query {
checkHealth: Boolean @aws_cognito_user_pools
getUploadFileURL(input: FileUploadInput!): FileUploadResult
@aws_cognito_user_pools
getFileURL(fileName: String!): String
@aws_cognito_user_pools
listModels: [Model!]! @aws_cognito_user_pools
listWorkspaces: [Workspace!]! @aws_cognito_user_pools
getWorkspace(workspaceId: String!): Workspace @aws_cognito_user_pools
Expand Down
Loading

0 comments on commit 08f899b

Please sign in to comment.