Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature graph compiler #292

Draft
wants to merge 5 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/dishka/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
):
self.registry = registry
self.child_registries = child_registries
self._context = {DependencyKey(type(self), DEFAULT_COMPONENT): self}
self._context = {CONTAINER_KEY: self}
if context:
for key, value in context.items():
if not isinstance(key, DependencyKey):
Expand Down Expand Up @@ -252,3 +252,6 @@ def make_container(
close_parent=True,
)
return container


CONTAINER_KEY = DependencyKey(Container, DEFAULT_COMPONENT)
256 changes: 256 additions & 0 deletions src/dishka/graph_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
import linecache
import re
from textwrap import indent
from typing import Any, Sequence, Mapping

from .container_objects import Exit
from .entities.factory_type import FactoryType, FactoryData
from .entities.key import DependencyKey
from .entities.scope import BaseScope
from .exceptions import NoContextValueError, UnsupportedFactoryError
from .text_rendering import get_name


MAX_DEPTH = 5 # max code depth, otherwise we get too big file


class Node(FactoryData):
    """A vertex of the dependency graph handed to the graph compiler.

    Extends ``FactoryData`` with the already-built child nodes (positional
    and keyword dependencies) and the factory's caching flag.
    """

    __slots__ = (
        "dependencies",
        "kw_dependencies",
        "cache",
    )

    def __init__(
            self,
            *,
            dependencies: Sequence["Node"],
            kw_dependencies: Mapping[str, "Node"],
            source: Any,
            provides: DependencyKey,
            scope: BaseScope,
            type_: FactoryType | None,
            cache: bool,
    ) -> None:
        # Common factory attributes are stored by the base class.
        super().__init__(
            provides=provides,
            source=source,
            scope=scope,
            type_=type_,
        )
        # Sub-graphs for positional and keyword arguments, plus whether the
        # resolved value should be cached in the container context.
        self.cache = cache
        self.dependencies = dependencies
        self.kw_dependencies = kw_dependencies


def make_args(args: list[str], kwargs: dict[str, str]) -> str:
    """Render positional and keyword argument names as a call-site string.

    Args:
        args: variable names passed positionally.
        kwargs: mapping of parameter name -> variable name, passed by keyword.

    Returns:
        A comma-separated argument list such as ``"a, b, x=v"`` (empty string
        when both inputs are empty).
    """
    res = ", ".join(args)
    if not kwargs:
        return res
    if res:
        res += ", "
    # Bug fix: iterate items() — iterating the dict directly yields only the
    # keys and fails to unpack into (arg, var) pairs.
    res += ", ".join(
        f"{arg}={var}"
        for arg, var in kwargs.items()
    )
    return res


# ---------------------------------------------------------------------------
# Code templates for the generated provider functions.  Each is expanded via
# str.format() with:
#   {source} - global name bound to the factory callable / value
#   {var}    - local variable that receives the resolved dependency
#   {args}   - rendered call arguments (see make_args)
#   {key}    - global name bound to the DependencyKey object
# NOTE(review): leading indentation inside these triple-quoted templates
# appears to have been stripped in this view (e.g. the body lines of
# FUNC_TEMPLATE must be indented for the generated ``def`` to parse) —
# verify the exact strings against the repository.
# ---------------------------------------------------------------------------
# Sync generator factory: take the first yielded value, defer cleanup.
GENERATOR = """
generator = {source}({args})
{var} = next(generator)
exits.append(Exit(factory_type, generator))
"""
# Async generator factory: first value via anext(), cleanup deferred.
ASYNC_GENERATOR = """
generator = {source}({args})
{var} = await anext(generator)
exits.append(Exit(factory_type, generator))
"""
# Plain factory call.
FACTORY = """
{var} = {source}({args})
"""
# Awaited factory call.
ASYNC_FACTORY = """
{var} = await {source}({args})
"""
# Constant value provided as-is.
VALUE = """
{var} = {source}
"""
# Alias: reuse the (single) dependency variable.
ALIAS = """
{var} = {args}
"""
# Context value that was not supplied: raise at resolution time.
CONTEXT = """
raise NoContextValueError({key})
"""
# Fallback for factory types the compiler does not support.
INVALID = """
raise UnsupportedFactoryError(
f"Unsupported factory type {{factory_type}}.",
)
"""
# Delegate resolution to the parent container via ``getter``.
GO_PARENT = """
{var} = getter({key})
"""
GO_PARENT_ASYNC = """
{var} = await getter({key})
"""

# Body template per factory type; the None entry delegates to the parent.
ASYNC_BODIES = {
    FactoryType.ASYNC_FACTORY: ASYNC_FACTORY,
    FactoryType.FACTORY: FACTORY,
    FactoryType.ASYNC_GENERATOR: ASYNC_GENERATOR,
    FactoryType.GENERATOR: GENERATOR,
    FactoryType.VALUE: VALUE,
    FactoryType.CONTEXT: CONTEXT,
    FactoryType.ALIAS: ALIAS,
    None: GO_PARENT_ASYNC,
}
# Sync variant: async factory types are absent and fall through to INVALID.
SYNC_BODIES = {
    FactoryType.FACTORY: FACTORY,
    FactoryType.GENERATOR: GENERATOR,
    FactoryType.VALUE: VALUE,
    FactoryType.CONTEXT: CONTEXT,
    FactoryType.ALIAS: ALIAS,
}
# Skeleton of the generated provider function.
FUNC_TEMPLATE = """
{async_}def {func_name}(getter, exits, context):
cache_getter = context.get
{body}
return {var}
"""

# Cache-aware wrapper: resolve only when the context cache misses
# (``...`` is the miss sentinel passed to ``dict.get``).
IF_TEMPLATE = """
if ({var} := cache_getter({key}, ...)) is ...:
{deps}
{body}
{cache}
"""
# Statement that stores a freshly created value in the context cache.
CACHE = "context[{key}] = {var}"

builtins = {getattr(__builtins__, name): name for name in dir(__builtins__)}
def make_name(obj: Any, ns: dict[Any, str]) -> str:
if obj in builtins:
return builtins[obj]
if isinstance(obj, DependencyKey):
key = get_name(obj.type_hint, include_module=False) +"_"+ obj.component
else:
key = get_name(obj, include_module=False)
key = re.sub(r"\W", "_", key)
if key in ns:
key += f"_{len(ns)}"
return key


def make_globals(node: "Node", ns: dict[Any, str]) -> None:
    """Walk the graph rooted at *node* and register a generated global name
    for every provides-key and factory source in *ns* (object -> name)."""
    for obj in (node.provides, node.source):
        if obj not in ns:
            ns[obj] = make_name(obj, ns)
    for child in node.dependencies:
        make_globals(child, ns)
    for child in node.kw_dependencies.values():
        make_globals(child, ns)


def make_var(node: "Node", ns: dict[Any, str]) -> str:
    """Return the local-variable name holding the resolved value of *node*,
    derived from the global name of its provides-key."""
    return f"value_{ns[node.provides].lower()}"


def make_if(
    node: "Node", node_var: str, ns: dict[Any, str],
    is_async: bool,
    depth: int,
) -> str:
    """Render the code resolving *node* and, recursively, its dependencies.

    Returns a snippet that leaves the resolved value in *node_var*.  Cached
    factories are wrapped in an ``if`` that consults the context cache first.

    Args:
        node: the graph node to render.
        node_var: local variable name that receives the value.
        ns: mapping of object -> generated global name.
        is_async: emit ``await``-based code when True.
        depth: current recursion depth; beyond MAX_DEPTH resolution is
            delegated to the parent container.
    """
    node_key = ns[node.provides]
    if depth > MAX_DEPTH or node.type is None:
        # Unknown factory or graph too deep: delegate to the parent getter.
        # Bug fix: the async branch previously formatted the sync GO_PARENT
        # template too, losing the required ``await``.
        template = GO_PARENT_ASYNC if is_async else GO_PARENT
        return template.format(var=node_var, key=node_key)
    # Resolved only after the early return: placeholder nodes have
    # source=None, which need not be looked up.
    node_source = ns[node.source]

    deps = "".join(
        make_if(dep, make_var(dep, ns), ns, is_async, depth + 1)
        for dep in node.dependencies
    )
    deps += "".join(
        make_if(dep, make_var(dep, ns), ns, is_async, depth + 1)
        for dep in node.kw_dependencies.values()
    )

    args = [make_var(dep, ns) for dep in node.dependencies]
    kwargs = {
        # renamed from ``key`` to avoid confusion with node_key above
        name: make_var(dep, ns)
        for name, dep in node.kw_dependencies.items()
    }

    if is_async:
        body_template = ASYNC_BODIES.get(node.type, INVALID)
    else:
        body_template = SYNC_BODIES.get(node.type, INVALID)

    args_str = make_args(args, kwargs)
    body_str = body_template.format(
        source=node_source,
        key=node_key,
        var=node_var,
        args=args_str,
    )

    if node.cache:
        # Everything inside the IF_TEMPLATE body is indented one level.
        # Bug fix: the cache-store line was previously emitted unindented
        # (outside the ``if``) and thus executed unconditionally; deps were
        # also indented on the non-cache path, producing inconsistent
        # indentation in the generated source.
        cache = indent(CACHE.format(var=node_var, key=node_key), " ")
        return IF_TEMPLATE.format(
            var=node_var,
            key=node_key,
            deps=indent(deps, " "),
            body=indent(body_str, " "),
            cache=cache,
        )
    return "\n".join([deps, body_str])


def make_func(
    node: "Node", ns: dict[Any, str], func_name: str, is_async: bool,
) -> str:
    """Render the complete source of the provider function for *node*."""
    result_var = make_var(node, ns)
    rendered = make_if(node, result_var, ns, is_async, 0)
    return FUNC_TEMPLATE.format(
        async_="async " if is_async else "",
        func_name=func_name,
        body=indent(rendered, "    "),
        var=result_var,
    )


def compile_graph(node: "Node", is_async: bool):
    """Generate, register and compile a provider function for *node*.

    The dependency graph is rendered to Python source, the source is stored
    in ``linecache`` under a synthetic file name (so tracebacks through the
    compiled code show real lines), and the result of ``exec`` is returned.

    Args:
        node: root of the dependency graph to compile.
        is_async: whether to emit an ``async def`` provider.

    Returns:
        The compiled function with signature ``(getter, exits, context)``.
    """
    ns: dict[Any, str] = {
        node.type: "factory_type",
        Exit: "Exit",
        NoContextValueError: "NoContextValueError",
        UnsupportedFactoryError: "UnsupportedFactoryError",
    }
    make_globals(node, ns)
    func_name = f"get_{ns[node.provides].lower()}"
    src = make_func(node, ns, func_name, is_async=is_async)
    # Drop blank lines left over by template expansion.
    src = "\n".join(line for line in src.splitlines() if line.strip())
    # (leftover debug ``print`` calls removed)

    source_file_name = f"__dishka_factory_{id(node.provides)}"
    if is_async:
        source_file_name += "_async"
    lines = src.splitlines(keepends=True)
    linecache.cache[source_file_name] = (
        len(src), None, lines, source_file_name,
    )
    # ns maps object -> generated name; invert it to form the globals
    # namespace the generated code runs in.
    global_ns = {value: key for key, value in ns.items()}
    exec(src, global_ns)  # noqa: S102 — source is generated locally above
    return global_ns[func_name]
50 changes: 46 additions & 4 deletions src/dishka/registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import time
from collections.abc import Callable
from linecache import cache
from typing import Any, TypeVar, get_args, get_origin


from ._adaptix.type_tools.fundamentals import get_type_vars
from .container_objects import CompiledFactory
from .dependency_source import (
Expand All @@ -10,7 +13,7 @@
from .entities.factory_type import FactoryType
from .entities.key import DependencyKey
from .entities.scope import BaseScope
from .factory_compiler import compile_factory
from .graph_compiler import Node, compile_graph


class Registry:
Expand All @@ -25,7 +28,7 @@ def __init__(self, scope: BaseScope):
def add_factory(
self,
factory: Factory,
provides: DependencyKey| None = None,
provides: DependencyKey | None = None,
) -> None:
if provides is None:
provides = factory.provides
Expand All @@ -40,7 +43,9 @@ def get_compiled(
factory = self.get_factory(dependency)
if not factory:
return None
compiled = compile_factory(factory=factory, is_async=False)
node = make_node(self, dependency)
compiled = compile_graph(node=node, is_async=False)
# compiled = compile_factory(factory=factory, is_async=False)
self.compiled[dependency] = compiled
return compiled

Expand All @@ -53,7 +58,9 @@ def get_compiled_async(
factory = self.get_factory(dependency)
if not factory:
return None
compiled = compile_factory(factory=factory, is_async=True)
node = make_node(self, dependency)
compiled = compile_graph(node=node, is_async=True)
# compiled = compile_factory(factory=factory, is_async=True)
self.compiled[dependency] = compiled
return compiled

Expand Down Expand Up @@ -144,3 +151,38 @@ def _specialize_generic(
cache=factory.cache,
override=factory.override,
)

# Maximum inlining depth when building the graph.
# NOTE(review): graph_compiler.MAX_DEPTH is 5 while this is 4 — confirm the
# two limits are intentionally different.
MAX_DEPTH = 4


def make_node(
    registry: Registry,
    key: DependencyKey,
    cache: dict | None = None,
    depth: int = 0,
) -> Node:
    """Build the dependency-graph ``Node`` for *key* from *registry*.

    Unknown keys and nodes beyond MAX_DEPTH become placeholder nodes
    (``type_=None``) that the compiler resolves via the parent getter.

    Args:
        registry: registry to look factories up in.
        key: the dependency to build a node for.
        cache: memo of already-built nodes, keyed by DependencyKey.
        depth: current recursion depth.
    """
    if cache is None:
        cache = {}
    # Bug fix: the cache was filled but never consulted, so shared
    # dependencies were rebuilt (and re-rendered) for every occurrence.
    # NOTE(review): a node truncated at MAX_DEPTH may be reused at a
    # shallower position — confirm this is acceptable.
    if key in cache:
        return cache[key]
    factory = registry.get_factory(key)
    if not factory or depth > MAX_DEPTH:
        node = Node(
            provides=key,
            scope=registry.scope,
            type_=None,
            dependencies=[],
            kw_dependencies={},
            cache=False,
            source=None,
        )
    else:
        node = Node(
            provides=factory.provides,
            scope=factory.scope,
            source=factory.source,
            type_=factory.type,
            cache=factory.cache,
            dependencies=[
                make_node(registry, dep, cache, depth + 1)
                for dep in factory.dependencies
            ],
            kw_dependencies={
                # renamed from ``key`` to avoid shadowing the parameter
                name: make_node(registry, dep, cache, depth + 1)
                for name, dep in factory.kw_dependencies.items()
            },
        )
    cache[key] = node
    return node
Loading