From 1db74ecdcaf488f06b622acea43cabd8263cccb0 Mon Sep 17 00:00:00 2001 From: Sorin Sbarnea Date: Thu, 15 Aug 2024 18:25:52 +0100 Subject: [PATCH] Address more typing errors Related: #258 --- .../eda/plugins/event_source/aws_cloudtrail.py | 4 ++-- .../eda/plugins/event_source/aws_sqs_queue.py | 13 +++++++------ extensions/eda/plugins/event_source/generic.py | 2 +- extensions/eda/plugins/event_source/kafka.py | 6 ++++-- extensions/eda/plugins/event_source/pg_listener.py | 4 ++-- pyproject.toml | 2 -- .../event_source_url_check/test_url_check_source.py | 3 ++- 7 files changed, 18 insertions(+), 16 deletions(-) diff --git a/extensions/eda/plugins/event_source/aws_cloudtrail.py b/extensions/eda/plugins/event_source/aws_cloudtrail.py index b2719c05..5f9bcb00 100644 --- a/extensions/eda/plugins/event_source/aws_cloudtrail.py +++ b/extensions/eda/plugins/event_source/aws_cloudtrail.py @@ -46,7 +46,7 @@ def _cloudtrail_event_to_dict(event: dict) -> dict: return event -def _get_events(events: list[dict], last_event_ids: list) -> list: +def _get_events(events: list[dict], last_event_ids: list[str]) -> list: event_time = None event_ids = [] result = [] @@ -89,7 +89,7 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: async with session.create_client("cloudtrail", **connection_args(args)) as client: event_time = None - event_ids = [] + event_ids: list[str] = [] while True: if event_time is not None: params["StartTime"] = event_time diff --git a/extensions/eda/plugins/event_source/aws_sqs_queue.py b/extensions/eda/plugins/event_source/aws_sqs_queue.py index ca68f715..66c85d7c 100644 --- a/extensions/eda/plugins/event_source/aws_sqs_queue.py +++ b/extensions/eda/plugins/event_source/aws_sqs_queue.py @@ -30,6 +30,7 @@ from aiobotocore.session import get_session +# pylint: disable=too-many-locals async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: """Receive events via an AWS SQS queue.""" logger = logging.getLogger() @@ -64,19 +65,19 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: ) if "Messages" in response_msg: - for msg in response_msg["Messages"]: + for entry in response_msg["Messages"]: if ( - not isinstance(msg, dict) or "MessageId" not in msg + not isinstance(entry, dict) or "MessageId" not in entry ): # pragma: no cover err_msg = ( f"Unexpected response {response_msg}, missing MessageId." ) raise ValueError(err_msg) - meta = {"MessageId": msg["MessageId"]} + meta = {"MessageId": entry["MessageId"]} try: - msg_body = json.loads(msg["Body"]) + msg_body = json.loads(entry["Body"]) except json.JSONDecodeError: - msg_body = msg["Body"] + msg_body = entry["Body"] await queue.put({"body": msg_body, "meta": meta}) await asyncio.sleep(0) @@ -84,7 +85,7 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: # Need to remove msg from queue or else it'll reappear await client.delete_message( QueueUrl=queue_url, - ReceiptHandle=msg["ReceiptHandle"], + ReceiptHandle=entry["ReceiptHandle"], ) else: logger.debug("No messages in queue") diff --git a/extensions/eda/plugins/event_source/generic.py b/extensions/eda/plugins/event_source/generic.py index d524b373..43228e3d 100644 --- a/extensions/eda/plugins/event_source/generic.py +++ b/extensions/eda/plugins/event_source/generic.py @@ -190,7 +190,7 @@ def _create_data( self: Generic, index: int, ) -> dict: - data = {} + data: dict[str, str | int] = {} if self.my_args.create_index: data[self.my_args.create_index] = index if self.blob: diff --git a/extensions/eda/plugins/event_source/kafka.py b/extensions/eda/plugins/event_source/kafka.py index 09220cb4..21282457 100644 --- a/extensions/eda/plugins/event_source/kafka.py +++ b/extensions/eda/plugins/event_source/kafka.py @@ -124,11 +124,13 @@ async def receive_msg( logger = logging.getLogger() async for msg in kafka_consumer: - event = {} + event: dict[str, Any] = {} # Process headers try: - headers = {header[0]: header[1].decode(encoding) for header in msg.headers} + headers: dict[str, str] = { + header[0]: header[1].decode(encoding) for header in msg.headers + } event["meta"] = {} event["meta"]["headers"] = headers except UnicodeError: diff --git a/extensions/eda/plugins/event_source/pg_listener.py b/extensions/eda/plugins/event_source/pg_listener.py index 01f67803..2cd66aa9 100644 --- a/extensions/eda/plugins/event_source/pg_listener.py +++ b/extensions/eda/plugins/event_source/pg_listener.py @@ -99,7 +99,7 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: conninfo=args["dsn"], autocommit=True, ) as conn: - chunked_cache = {} + chunked_cache: dict[str, Any] = {} cursor = conn.cursor() for channel in args["channels"]: await cursor.execute(f"LISTEN {channel};") @@ -118,7 +118,7 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: async def _handle_chunked_message( - data: dict, + data: dict[str, Any], chunked_cache: dict, queue: asyncio.Queue, ) -> None: diff --git a/pyproject.toml b/pyproject.toml index 00ab4ad2..0e59c5e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,10 +41,8 @@ error_summary = true # TODO: Remove temporary skips and close https://github.com/ansible/event-driven-ansible/issues/258 disable_error_code = [ - "assignment", "attr-defined", "override", - "var-annotated", ] # strict = true # disallow_untyped_calls = true diff --git a/tests/integration/event_source_url_check/test_url_check_source.py b/tests/integration/event_source_url_check/test_url_check_source.py index d304299c..9331b72a 100644 --- a/tests/integration/event_source_url_check/test_url_check_source.py +++ b/tests/integration/event_source_url_check/test_url_check_source.py @@ -1,6 +1,7 @@ import http.server import os import threading +from typing import Any, Generator import pytest @@ -20,7 +21,7 @@ def log_message(self, format, *args): @pytest.fixture(scope="function") -def init_webserver(): +def init_webserver() -> Generator[Any, Any, Any]: handler = HttpHandler port: int = 8000 httpd = http.server.HTTPServer(("", port), handler)