Skip to content

Commit

Permalink
type: make code respect mypy strict mode
Browse files Browse the repository at this point in the history
Fixes: #258
  • Loading branch information
ssbarnea committed Aug 28, 2024
1 parent 05bd6ed commit 16379f1
Show file tree
Hide file tree
Showing 25 changed files with 72 additions and 48 deletions.
13 changes: 7 additions & 6 deletions extensions/eda/plugins/event_filter/json_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,22 @@
from __future__ import annotations

import fnmatch
from typing import Any, Optional


def _matches_include_keys(include_keys: list, string: str) -> bool:
def _matches_include_keys(include_keys: list[str], string: str) -> bool:
return any(fnmatch.fnmatch(string, pattern) for pattern in include_keys)


def _matches_exclude_keys(exclude_keys: list, string: str) -> bool:
def _matches_exclude_keys(exclude_keys: list[str], string: str) -> bool:
return any(fnmatch.fnmatch(string, pattern) for pattern in exclude_keys)


def main(
event: dict,
exclude_keys: list | None = None,
include_keys: list | None = None,
) -> dict:
event: dict[str, Any],
exclude_keys: Optional[list[str]] = None,
include_keys: Optional[list[str]] = None,
) -> dict[str, Any]:
"""Filter keys out of events."""
if exclude_keys is None:
exclude_keys = []
Expand Down
4 changes: 3 additions & 1 deletion extensions/eda/plugins/event_filter/noop.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""noop.py: An event filter that does nothing to the input."""

from typing import Any

def main(event: dict) -> dict:

def main(event: dict[str, Any]) -> dict[str, Any]:
"""Return the input."""
return event
2 changes: 1 addition & 1 deletion extensions/eda/plugins/event_source/alertmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def clean_host(host: str) -> str:
return host


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
"""Receive events via alertmanager webhook."""
app = web.Application()
app["queue"] = queue
Expand Down
17 changes: 13 additions & 4 deletions extensions/eda/plugins/event_source/aws_cloudtrail.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,19 @@ def _get_events(events: list[dict], last_event_ids: list[str]) -> list:
return [result, event_time, event_ids]


async def _get_cloudtrail_events(client: BaseClient, params: dict) -> list[dict]:
async def _get_cloudtrail_events(
client: BaseClient, params: dict[str, Any]
) -> list[dict]:
paginator = client.get_paginator("lookup_events")
results = await paginator.paginate(**params).build_full_result()
return results.get("Events", [])
events = results.get("Events", [])
# type guards:
if not isinstance(events, list):
raise ValueError("Events is not a list")
for event in events:
if not isinstance(event, dict):
raise ValueError("Event is not a dictionary")
return events


ARGS_MAPPING = {
Expand All @@ -75,7 +84,7 @@ async def _get_cloudtrail_events(client: BaseClient, params: dict) -> list[dict]
}


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
"""Receive events via AWS CloudTrail."""
delay = int(args.get("delay_seconds", 10))

Expand Down Expand Up @@ -131,7 +140,7 @@ def connection_args(args: dict[str, Any]) -> dict[str, Any]:
class MockQueue(asyncio.Queue[Any]):
"""A fake queue."""

async def put(self: "MockQueue", event: dict) -> None:
async def put(self: "MockQueue", event: dict[str, Any]) -> None:
"""Print the event."""
print(event) # noqa: T201

Expand Down
2 changes: 1 addition & 1 deletion extensions/eda/plugins/event_source/aws_sqs_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


# pylint: disable=too-many-locals
async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
"""Receive events via an AWS SQS queue."""
logger = logging.getLogger()

Expand Down
6 changes: 3 additions & 3 deletions extensions/eda/plugins/event_source/azure_service_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def receive_events(
loop: asyncio.events.AbstractEventLoop,
queue: asyncio.Queue,
queue: asyncio.Queue[Any],
args: dict[str, Any], # pylint: disable=W0621
) -> None:
"""Receive events from service bus."""
Expand All @@ -53,7 +53,7 @@ def receive_events(


async def main(
queue: asyncio.Queue,
queue: asyncio.Queue[Any],
args: dict[str, Any], # pylint: disable=W0621
) -> None:
"""Receive events from service bus in a loop."""
Expand All @@ -69,7 +69,7 @@ async def main(
class MockQueue(asyncio.Queue[Any]):
"""A fake queue."""

def put_nowait(self: "MockQueue", event: dict) -> None:
def put_nowait(self: "MockQueue", event: dict[str, Any]) -> None:
"""Print the event."""
print(event) # noqa: T201

Expand Down
6 changes: 3 additions & 3 deletions extensions/eda/plugins/event_source/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from watchdog.observers import Observer


def send_facts(queue: Queue, filename: Union[str, bytes]) -> None:
def send_facts(queue: Queue[Any], filename: Union[str, bytes]) -> None:
"""Send facts to the queue."""
if isinstance(filename, bytes):
filename = str(filename, "utf-8")
Expand All @@ -50,7 +50,7 @@ def send_facts(queue: Queue, filename: Union[str, bytes]) -> None:
coroutine = queue.put(item) # noqa: F841


def main(queue: Queue, args: dict) -> None:
def main(queue: Queue[Any], args: dict) -> None:
"""Load facts from YAML files initially and when the file changes."""
files = [pathlib.Path(f).resolve().as_posix() for f in args.get("files", [])]

Expand All @@ -62,7 +62,7 @@ def main(queue: Queue, args: dict) -> None:
_observe_files(queue, files)


def _observe_files(queue: Queue, files: list[str]) -> None:
def _observe_files(queue: Queue[Any], files: list[str]) -> None:
class Handler(RegexMatchingEventHandler):
"""A handler for file events."""

Expand Down
4 changes: 2 additions & 2 deletions extensions/eda/plugins/event_source/file_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

def watch(
loop: asyncio.events.AbstractEventLoop,
queue: asyncio.Queue,
queue: asyncio.Queue[Any],
args: dict,
) -> None:
"""Watch for changes and put events on the queue."""
Expand Down Expand Up @@ -96,7 +96,7 @@ def on_moved(self: "Handler", event: FileSystemEvent) -> None:
observer.join()


async def main(queue: asyncio.Queue, args: dict) -> None:
async def main(queue: asyncio.Queue[Any], args: dict) -> None:
"""Watch for changes to a file and put events on the queue."""
loop = asyncio.get_event_loop()

Expand Down
6 changes: 4 additions & 2 deletions extensions/eda/plugins/event_source/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ class DelayArgs:
class Generic:
"""Generic source plugin to generate different events."""

def __init__(self: Generic, queue: asyncio.Queue, args: dict[str, Any]) -> None:
def __init__(
self: Generic, queue: asyncio.Queue[Any], args: dict[str, Any]
) -> None:
"""Insert event data into the queue."""
self.queue = queue
field_names = [f.name for f in fields(Args)]
Expand Down Expand Up @@ -206,7 +208,7 @@ def _create_data(


async def main( # pylint: disable=R0914
queue: asyncio.Queue,
queue: asyncio.Queue[Any],
args: dict[str, Any],
) -> None:
"""Call the Generic Source Plugin."""
Expand Down
2 changes: 1 addition & 1 deletion extensions/eda/plugins/event_source/journald.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from systemd import journal # type: ignore


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: # noqa: D417
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: # noqa: D417
"""Read journal entries and add them to the provided queue.
Args:
Expand Down
5 changes: 3 additions & 2 deletions extensions/eda/plugins/event_source/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import asyncio
import json
import logging
from re import A
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
from typing import Any

Expand All @@ -45,7 +46,7 @@


async def main( # pylint: disable=R0914
queue: asyncio.Queue,
queue: asyncio.Queue[Any],
args: dict[str, Any],
) -> None:
"""Receive events via a kafka topic."""
Expand Down Expand Up @@ -116,7 +117,7 @@ async def main( # pylint: disable=R0914


async def receive_msg(
queue: asyncio.Queue,
queue: asyncio.Queue[Any],
kafka_consumer: AIOKafkaConsumer,
encoding: str,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions extensions/eda/plugins/event_source/pg_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _validate_chunked_payload(payload: dict) -> None:
raise MissingChunkKeyError(key)


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
"""Listen for events from a channel."""
for key in REQUIRED_KEYS:
if key not in args:
Expand Down Expand Up @@ -120,7 +120,7 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
async def _handle_chunked_message(
data: dict[str, Any],
chunked_cache: dict,
queue: asyncio.Queue,
queue: asyncio.Queue[Any],
) -> None:
message_uuid = data[MESSAGE_CHUNKED_UUID]
number_of_chunks = data[MESSAGE_CHUNK_COUNT]
Expand Down
2 changes: 1 addition & 1 deletion extensions/eda/plugins/event_source/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Any


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
"""Generate events with an increasing index i with a limit."""
delay = args.get("delay", 0)

Expand Down
2 changes: 1 addition & 1 deletion extensions/eda/plugins/event_source/tick.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
"""Generate events with an increasing index i and a time between ticks."""
delay = args.get("delay", 1)

Expand Down
2 changes: 1 addition & 1 deletion extensions/eda/plugins/event_source/url_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
OK = 200


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
"""Poll a set of URLs and send events with status."""
urls = args.get("urls", [])
delay = int(args.get("delay", 1))
Expand Down
2 changes: 1 addition & 1 deletion extensions/eda/plugins/event_source/webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _get_ssl_context(args: dict[str, Any]) -> ssl.SSLContext | None:
return context


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
"""Receive events via webhook."""
if "port" not in args:
msg = "Missing required argument: port"
Expand Down
11 changes: 9 additions & 2 deletions plugins/module_utils/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def get_one_or_many(
endpoint: str,
name: Optional[str] = None,
**kwargs: Any,
) -> List[Any]:
) -> List[dict[str, Any]]:
new_kwargs = kwargs.copy()

if name:
Expand All @@ -140,7 +140,14 @@ def get_one_or_many(
if response.json["count"] == 0:
return []

return response.json["results"]
# type safeguard
results = response.json["results"]
if not isinstance(results, list):
raise EDAError("The endpoint did not provide a list of dictionaries")
for result in results:
if not isinstance(result, dict):
raise EDAError("The endpoint did not provide a list of dictionaries")
return results

def create_if_needed(
self,
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ color_output = true
error_summary = true

# TODO: Remove temporary skips and close https://github.com/ansible/event-driven-ansible/issues/258
# strict = true
# disallow_untyped_calls = true
strict = true
disallow_untyped_calls = true
disallow_untyped_defs = true
# disallow_any_generics = true
# disallow_any_unimported = True
# disallow_any_unimported = true
# warn_redundant_casts = True
# warn_return_any = True
# warn_unused_configs = True
warn_unused_configs = true

# site-packages is here to help vscode mypy integration getting confused
exclude = "(build|dist|test/local-content|site-packages|~/.pyenv|examples/playbooks/collections|plugins/modules)"
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@pytest.fixture(scope="function")
def subprocess_teardown() -> Iterator[Callable]:
def subprocess_teardown() -> Iterator[Callable[[Popen[bytes]], None]]:
processes: list[Popen[bytes]] = []

def _teardown(process: Popen[bytes]) -> None:
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/event_source_kafka/test_kafka_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import pytest
from kafka import KafkaProducer

from ..utils import TESTS_PATH, CLIRunner
from .. import TESTS_PATH
from ..utils import CLIRunner


@pytest.fixture(scope="session")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import pytest

from ..utils import DEFAULT_TEST_TIMEOUT, TESTS_PATH, CLIRunner
from ..utils import DEFAULT_TEST_TIMEOUT, CLIRunner
from .. import TESTS_PATH

EVENT_SOURCE_DIR = os.path.dirname(__file__)

Expand Down
8 changes: 4 additions & 4 deletions tests/integration/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import subprocess
from dataclasses import dataclass
from typing import List, Optional
from typing import Any, List, Optional

from . import TESTS_PATH

Expand All @@ -25,7 +25,7 @@ class CLIRunner:
verbose: bool = False
debug: bool = False
timeout: float = 10.0
env: Optional[dict] = None
env: Optional[dict[str, str]] = None

def __post_init__(self) -> None:
self.env = os.environ.copy() if self.env is None else self.env
Expand Down Expand Up @@ -54,7 +54,7 @@ def _process_args(self) -> List[str]:

return args

def run(self) -> subprocess.CompletedProcess:
def run(self) -> subprocess.CompletedProcess[Any]:
args = self._process_args()
print("Running command: ", " ".join(args))
return subprocess.run(
Expand All @@ -66,7 +66,7 @@ def run(self) -> subprocess.CompletedProcess:
env=self.env,
)

def run_in_background(self) -> subprocess.Popen:
def run_in_background(self) -> subprocess.Popen[bytes]:
args = self._process_args()
print("Running command: ", " ".join(args))
return subprocess.Popen(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/event_source/test_alertmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from extensions.eda.plugins.event_source.alertmanager import main as alert_main


async def start_server(queue: asyncio.Queue, args: dict[str, Any]) -> None:
async def start_server(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
await alert_main(queue, args)


Expand Down
Loading

0 comments on commit 16379f1

Please sign in to comment.