From 66696c6b0b12b1aff0d72c21e871871dd427f2ee Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Wed, 23 Oct 2024 21:26:03 +0200 Subject: [PATCH] Fix forwardref by making ProvideMultiple generic --- src/dishka/dependency_source/make_factory.py | 16 ++++----- .../dependency_source/unpack_provides.py | 13 +++---- src/dishka/entities/provides_marker.py | 34 ++++++++++++------- src/dishka/entities/with_parents.py | 2 +- tests/unit/container/test_alias.py | 2 +- 5 files changed, 39 insertions(+), 28 deletions(-) diff --git a/src/dishka/dependency_source/make_factory.py b/src/dishka/dependency_source/make_factory.py index e3fa0aa4..46988a4b 100644 --- a/src/dishka/dependency_source/make_factory.py +++ b/src/dishka/dependency_source/make_factory.py @@ -103,10 +103,10 @@ def _type_repr(hint: Any) -> str: def _async_generator_result(hint: Any) -> Any: - if isinstance(hint, ProvideMultiple): - return ProvideMultiple([ - _async_generator_result(x) for x in hint.items - ]) + if get_origin(hint) is ProvideMultiple: + return ProvideMultiple[tuple( + _async_generator_result(x) for x in get_args(hint) + )] origin = get_origin(hint) if origin is AsyncIterable: return get_args(hint)[0] @@ -136,10 +136,10 @@ def _async_generator_result(hint: Any) -> Any: def _generator_result(hint: Any) -> Any: - if isinstance(hint, ProvideMultiple): - return ProvideMultiple([ - _generator_result(x) for x in hint.items - ]) + if get_origin(hint) is ProvideMultiple: + return ProvideMultiple[tuple( + _generator_result(x) for x in get_args(hint) + )] origin = get_origin(hint) if origin is Iterable: return get_args(hint)[0] diff --git a/src/dishka/dependency_source/unpack_provides.py b/src/dishka/dependency_source/unpack_provides.py index 173f51e3..ed1b62b8 100644 --- a/src/dishka/dependency_source/unpack_provides.py +++ b/src/dishka/dependency_source/unpack_provides.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from typing import get_args, get_origin from dishka.entities.key import DependencyKey from dishka.entities.provides_marker import ProvideMultiple @@ -9,10 +10,10 @@ def unpack_factory(factory: Factory) -> Sequence[DependencySource]: - if not isinstance(factory.provides.type_hint, ProvideMultiple): + if get_origin(factory.provides.type_hint) is not ProvideMultiple: return [factory] - provides_first, *provides_others = factory.provides.type_hint.items + provides_first, *provides_others = get_args(factory.provides.type_hint) res: list[DependencySource] = [ Alias( @@ -43,7 +44,7 @@ def unpack_factory(factory: Factory) -> Sequence[DependencySource]: def unpack_decorator(decorator: Decorator) -> Sequence[DependencySource]: - if not isinstance(decorator.provides.type_hint, ProvideMultiple): + if get_origin(decorator.provides.type_hint) is not ProvideMultiple: return [decorator] return [ @@ -51,12 +52,12 @@ def unpack_decorator(decorator: Decorator) -> Sequence[DependencySource]: factory=decorator.factory, provides=DependencyKey(provides, decorator.provides.component), ) - for provides in decorator.provides.type_hint.items + for provides in get_args(decorator.provides.type_hint) ] def unpack_alias(alias: Alias) -> Sequence[DependencySource]: - if not isinstance(alias.provides.type_hint, ProvideMultiple): + if get_origin(alias.provides.type_hint) is not ProvideMultiple: return [alias] return [ @@ -66,5 +67,5 @@ def unpack_alias(alias: Alias) -> Sequence[DependencySource]: cache=alias.cache, override=alias.override, ) - for provides in alias.provides.type_hint.items + for provides in get_args(alias.provides.type_hint) ] diff --git a/src/dishka/entities/provides_marker.py b/src/dishka/entities/provides_marker.py index d7864264..21cbab2f 100644 --- a/src/dishka/entities/provides_marker.py +++ b/src/dishka/entities/provides_marker.py @@ -1,20 +1,30 @@ from __future__ import annotations -from collections.abc import Sequence -from typing import TYPE_CHECKING, Any +import sys +import threading +from typing import TYPE_CHECKING, Generic, TypeVar __all__ = ["AnyOf", "ProvideMultiple"] -if TYPE_CHECKING: - from typing import Union as AnyOf + +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack + + Variants = TypeVarTuple("Variants") + class ProvideMultiple(Generic[Unpack[Variants]]): + pass else: - class AnyOf: - def __class_getitem__(cls, item: Any) -> Any: - if isinstance(item, tuple): - return ProvideMultiple(item) - return item + Variants = TypeVar("Variants") + provides_lock = threading.Lock() + class ProvideMultiple(Generic[Variants]): + def __class_getitem__(cls, item): + with provides_lock: + cls.__parameters__ = [Variants]*len(item) + return super().__class_getitem__(item) -class ProvideMultiple: - def __init__(self, items: Sequence[Any]) -> None: - self.items = items + +if TYPE_CHECKING: + from typing import Union as AnyOf +else: + AnyOf = ProvideMultiple diff --git a/src/dishka/entities/with_parents.py b/src/dishka/entities/with_parents.py index 1d1b7f14..3452037a 100644 --- a/src/dishka/entities/with_parents.py +++ b/src/dishka/entities/with_parents.py @@ -174,5 +174,5 @@ class WithParents: def __class_getitem__(cls, item: TypeHint) -> TypeHint: parents = ParentsResolver().get_parents(item) if len(parents) > 1: - return ProvideMultiple(parents) + return ProvideMultiple[tuple(parents)] return parents[0] diff --git a/tests/unit/container/test_alias.py b/tests/unit/container/test_alias.py index c079c44f..1c69f614 100644 --- a/tests/unit/container/test_alias.py +++ b/tests/unit/container/test_alias.py @@ -100,7 +100,7 @@ class MyProvider(Provider): value = 0 @provide(scope=Scope.APP) - async def foo(self) -> AsyncIterable[AnyOf[float, int]]: + async def foo(self) -> AsyncIterable[AnyOf[float, "int"]]: self.value += 1 yield self.value