Skip to content

Commit

Permalink
Fix forwardref by making ProvideMultiple generic
Browse files Browse the repository at this point in the history
  • Loading branch information
Tishka17 committed Oct 23, 2024
1 parent 8fd67f4 commit 66696c6
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 28 deletions.
16 changes: 8 additions & 8 deletions src/dishka/dependency_source/make_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ def _type_repr(hint: Any) -> str:


def _async_generator_result(hint: Any) -> Any:
if isinstance(hint, ProvideMultiple):
return ProvideMultiple([
_async_generator_result(x) for x in hint.items
])
if get_origin(hint) is ProvideMultiple:
return ProvideMultiple[tuple(
_async_generator_result(x) for x in get_args(hint)
)]
origin = get_origin(hint)
if origin is AsyncIterable:
return get_args(hint)[0]
Expand Down Expand Up @@ -136,10 +136,10 @@ def _async_generator_result(hint: Any) -> Any:


def _generator_result(hint: Any) -> Any:
if isinstance(hint, ProvideMultiple):
return ProvideMultiple([
_generator_result(x) for x in hint.items
])
if get_origin(hint) is ProvideMultiple:
return ProvideMultiple[tuple(
_generator_result(x) for x in get_args(hint)
)]
origin = get_origin(hint)
if origin is Iterable:
return get_args(hint)[0]
Expand Down
13 changes: 7 additions & 6 deletions src/dishka/dependency_source/unpack_provides.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Sequence
from typing import get_args, get_origin

from dishka.entities.key import DependencyKey
from dishka.entities.provides_marker import ProvideMultiple
Expand All @@ -9,10 +10,10 @@


def unpack_factory(factory: Factory) -> Sequence[DependencySource]:
if not isinstance(factory.provides.type_hint, ProvideMultiple):
if get_origin(factory.provides.type_hint) is not ProvideMultiple:
return [factory]

provides_first, *provides_others = factory.provides.type_hint.items
provides_first, *provides_others = get_args(factory.provides.type_hint)

res: list[DependencySource] = [
Alias(
Expand Down Expand Up @@ -43,20 +44,20 @@ def unpack_factory(factory: Factory) -> Sequence[DependencySource]:


def unpack_decorator(decorator: Decorator) -> Sequence[DependencySource]:
if not isinstance(decorator.provides.type_hint, ProvideMultiple):
if get_origin(decorator.provides.type_hint) is not ProvideMultiple:
return [decorator]

return [
Decorator(
factory=decorator.factory,
provides=DependencyKey(provides, decorator.provides.component),
)
for provides in decorator.provides.type_hint.items
for provides in get_args(decorator.provides.type_hint)
]


def unpack_alias(alias: Alias) -> Sequence[DependencySource]:
if not isinstance(alias.provides.type_hint, ProvideMultiple):
if get_origin(alias.provides.type_hint) is not ProvideMultiple:
return [alias]

return [
Expand All @@ -66,5 +67,5 @@ def unpack_alias(alias: Alias) -> Sequence[DependencySource]:
cache=alias.cache,
override=alias.override,
)
for provides in alias.provides.type_hint.items
for provides in get_args(alias.provides.type_hint)
]
34 changes: 22 additions & 12 deletions src/dishka/entities/provides_marker.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
import sys
import threading
from typing import TYPE_CHECKING, Generic, TypeVar

__all__ = ["AnyOf", "ProvideMultiple"]

if TYPE_CHECKING:
from typing import Union as AnyOf

if sys.version_info >= (3, 11):
from typing import TypeVarTuple, Unpack

Variants = TypeVarTuple("Variants")
class ProvideMultiple(Generic[Unpack[Variants]]):
pass
else:
class AnyOf:
def __class_getitem__(cls, item: Any) -> Any:
if isinstance(item, tuple):
return ProvideMultiple(item)
return item
Variants = TypeVar("Variants")
provides_lock = threading.Lock()

class ProvideMultiple(Generic[Variants]):
def __class_getitem__(cls, item):
with provides_lock:
cls.__parameters__ = [Variants]*len(item)
return super().__class_getitem__(item)

class ProvideMultiple:
def __init__(self, items: Sequence[Any]) -> None:
self.items = items

if TYPE_CHECKING:
from typing import Union as AnyOf
else:
AnyOf = ProvideMultiple
2 changes: 1 addition & 1 deletion src/dishka/entities/with_parents.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,5 +174,5 @@ class WithParents:
def __class_getitem__(cls, item: TypeHint) -> TypeHint:
parents = ParentsResolver().get_parents(item)
if len(parents) > 1:
return ProvideMultiple(parents)
return ProvideMultiple[tuple(parents)]
return parents[0]
2 changes: 1 addition & 1 deletion tests/unit/container/test_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class MyProvider(Provider):
value = 0

@provide(scope=Scope.APP)
async def foo(self) -> AsyncIterable[AnyOf[float, int]]:
async def foo(self) -> AsyncIterable[AnyOf[float, "int"]]:
self.value += 1
yield self.value

Expand Down

0 comments on commit 66696c6

Please sign in to comment.