Skip to content

Commit

Permalink
Add ability to wait on or get results of groups of futures (#14234)
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle authored Jun 22, 2024
1 parent 90b3fd9 commit 883fcb0
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 34 deletions.
7 changes: 3 additions & 4 deletions src/integrations/prefect-dask/prefect_dask/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def count_to(highest_number):
Coroutine,
Dict,
Iterable,
List,
Optional,
Set,
TypeVar,
Expand All @@ -91,7 +90,7 @@ def count_to(highest_number):
from typing_extensions import ParamSpec

from prefect.client.schemas.objects import State, TaskRunInput
from prefect.futures import PrefectFuture, PrefectWrappedFuture
from prefect.futures import PrefectFuture, PrefectFutureList, PrefectWrappedFuture
from prefect.logging.loggers import get_logger, get_run_logger
from prefect.task_runners import TaskRunner
from prefect.tasks import Task
Expand Down Expand Up @@ -366,7 +365,7 @@ def map(
task: "Task[P, Coroutine[Any, Any, R]]",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
) -> List[PrefectDaskFuture[R]]:
) -> PrefectFutureList[PrefectDaskFuture[R]]:
...

@overload
Expand All @@ -375,7 +374,7 @@ def map(
task: "Task[Any, R]",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
) -> List[PrefectDaskFuture[R]]:
) -> PrefectFutureList[PrefectDaskFuture[R]]:
...

def map(
Expand Down
7 changes: 3 additions & 4 deletions src/integrations/prefect-ray/prefect_ray/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def count_to(highest_number):
Coroutine,
Dict,
Iterable,
List,
Optional,
Set,
TypeVar,
Expand All @@ -92,7 +91,7 @@ def count_to(highest_number):

from prefect.client.schemas.objects import TaskRunInput
from prefect.context import serialize_context
from prefect.futures import PrefectFuture, PrefectWrappedFuture
from prefect.futures import PrefectFuture, PrefectFutureList, PrefectWrappedFuture
from prefect.logging.loggers import get_logger, get_run_logger
from prefect.states import State, exception_to_crashed_state
from prefect.task_engine import run_task_async, run_task_sync
Expand Down Expand Up @@ -291,7 +290,7 @@ def map(
task: "Task[P, Coroutine[Any, Any, R]]",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
) -> List[PrefectRayFuture[R]]:
) -> PrefectFutureList[PrefectRayFuture[R]]:
...

@overload
Expand All @@ -300,7 +299,7 @@ def map(
task: "Task[Any, R]",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
) -> List[PrefectRayFuture[R]]:
) -> PrefectFutureList[PrefectRayFuture[R]]:
...

def map(
Expand Down
67 changes: 63 additions & 4 deletions src/prefect/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import concurrent.futures
import inspect
import uuid
from collections.abc import Iterator
from functools import partial
from typing import Any, Generic, Optional, Set, Union, cast
from typing import Any, Generic, List, Optional, Set, Union, cast

from typing_extensions import TypeVar

Expand All @@ -16,6 +17,7 @@
from prefect.utilities.annotations import quote
from prefect.utilities.asyncutils import run_coro_as_sync
from prefect.utilities.collections import StopVisiting, visit_collection
from prefect.utilities.timeout import timeout as timeout_context

F = TypeVar("F")
R = TypeVar("R")
Expand Down Expand Up @@ -62,7 +64,7 @@ def wait(self, timeout: Optional[float] = None) -> None:
If the task run has already completed, this method will return immediately.
Args:
- timeout: The maximum number of seconds to wait for the task run to complete.
timeout: The maximum number of seconds to wait for the task run to complete.
If the task run has not completed after the timeout has elapsed, this method will return.
"""

Expand All @@ -79,9 +81,9 @@ def result(
If the task run has not completed, this method will wait for the task run to complete.
Args:
- timeout: The maximum number of seconds to wait for the task run to complete.
timeout: The maximum number of seconds to wait for the task run to complete.
If the task run has not completed after the timeout has elapsed, this method will return.
- raise_on_failure: If `True`, an exception will be raised if the task run fails.
raise_on_failure: If `True`, an exception will be raised if the task run fails.
Returns:
The result of the task run.
Expand Down Expand Up @@ -233,6 +235,63 @@ def __eq__(self, other):
return self.task_run_id == other.task_run_id


class PrefectFutureList(list, Iterator, Generic[F]):
"""
A list of Prefect futures.
This class provides methods to wait for all futures
in the list to complete and to retrieve the results of all task runs.
"""

def wait(self, timeout: Optional[float] = None) -> None:
"""
Wait for all futures in the list to complete.
Args:
timeout: The maximum number of seconds to wait for all futures to
complete. This method will not raise if the timeout is reached.
"""
try:
with timeout_context(timeout):
for future in self:
future.wait()
except TimeoutError:
logger.debug("Timed out waiting for all futures to complete.")
return

def result(
self,
timeout: Optional[float] = None,
raise_on_failure: bool = True,
) -> List:
"""
Get the results of all task runs associated with the futures in the list.
Args:
timeout: The maximum number of seconds to wait for all futures to
complete.
raise_on_failure: If `True`, an exception will be raised if any task run fails.
Returns:
A list of results of the task runs.
Raises:
TimeoutError: If the timeout is reached before all futures complete.
"""
try:
with timeout_context(timeout):
return [
future.result(raise_on_failure=raise_on_failure) for future in self
]
except TimeoutError as exc:
# timeout came from inside the task
if "Scope timed out after {timeout} second(s)." not in str(exc):
raise
raise TimeoutError(
f"Timed out waiting for all futures to complete within {timeout} seconds"
) from exc


def resolve_futures_to_states(
expr: Union[PrefectFuture, Any],
) -> Union[State, Any]:
Expand Down
15 changes: 8 additions & 7 deletions src/prefect/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
PrefectConcurrentFuture,
PrefectDistributedFuture,
PrefectFuture,
PrefectFutureList,
)
from prefect.logging.loggers import get_logger, get_run_logger
from prefect.utilities.annotations import allow_failure, quote, unmapped
Expand Down Expand Up @@ -97,7 +98,7 @@ def map(
task: "Task",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
) -> List[F]:
) -> PrefectFutureList[F]:
"""
Submit multiple tasks to the task run engine.
Expand Down Expand Up @@ -169,7 +170,7 @@ def map(

map_length = list(lengths)[0]

futures = []
futures: List[PrefectFuture] = []
for i in range(map_length):
call_parameters = {
key: value[i] for key, value in iterable_parameters.items()
Expand Down Expand Up @@ -199,7 +200,7 @@ def map(
)
)

return futures
return PrefectFutureList(futures)

def __enter__(self):
if self._started:
Expand Down Expand Up @@ -316,7 +317,7 @@ def map(
task: "Task[P, Coroutine[Any, Any, R]]",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
) -> List[PrefectConcurrentFuture[R]]:
) -> PrefectFutureList[PrefectConcurrentFuture[R]]:
...

@overload
Expand All @@ -325,7 +326,7 @@ def map(
task: "Task[Any, R]",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
) -> List[PrefectConcurrentFuture[R]]:
) -> PrefectFutureList[PrefectConcurrentFuture[R]]:
...

def map(
Expand Down Expand Up @@ -427,7 +428,7 @@ def map(
task: "Task[P, Coroutine[Any, Any, R]]",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
) -> List[PrefectDistributedFuture[R]]:
) -> PrefectFutureList[PrefectDistributedFuture[R]]:
...

@overload
Expand All @@ -436,7 +437,7 @@ def map(
task: "Task[Any, R]",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
) -> List[PrefectDistributedFuture[R]]:
) -> PrefectFutureList[PrefectDistributedFuture[R]]:
...

def map(
Expand Down
29 changes: 15 additions & 14 deletions src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
TaskRunContext,
serialize_context,
)
from prefect.futures import PrefectDistributedFuture, PrefectFuture
from prefect.futures import PrefectDistributedFuture, PrefectFuture, PrefectFutureList
from prefect.logging.loggers import get_logger
from prefect.results import ResultFactory, ResultSerializer, ResultStorage
from prefect.settings import (
Expand Down Expand Up @@ -996,23 +996,23 @@ def map(
self: "Task[P, NoReturn]",
*args: P.args,
**kwargs: P.kwargs,
) -> List[PrefectFuture[NoReturn]]:
) -> PrefectFutureList[PrefectFuture[NoReturn]]:
...

@overload
def map(
self: "Task[P, Coroutine[Any, Any, T]]",
*args: P.args,
**kwargs: P.kwargs,
) -> List[PrefectFuture[T]]:
) -> PrefectFutureList[PrefectFuture[T]]:
...

@overload
def map(
self: "Task[P, T]",
*args: P.args,
**kwargs: P.kwargs,
) -> List[PrefectFuture[T]]:
) -> PrefectFutureList[PrefectFuture[T]]:
...

@overload
Expand All @@ -1021,7 +1021,7 @@ def map(
*args: P.args,
return_state: Literal[True],
**kwargs: P.kwargs,
) -> List[State[T]]:
) -> PrefectFutureList[State[T]]:
...

@overload
Expand All @@ -1030,7 +1030,7 @@ def map(
*args: P.args,
return_state: Literal[True],
**kwargs: P.kwargs,
) -> List[State[T]]:
) -> PrefectFutureList[State[T]]:
...

def map(
Expand All @@ -1044,8 +1044,9 @@ def map(
"""
Submit a mapped run of the task to a worker.
Must be called within a flow function. If writing an async task, this
call must be awaited.
Must be called within a flow run context. Will return a list of futures
that should be waited on before exiting the flow context to ensure all
mapped tasks have completed.
Must be called with at least one iterable and all iterables must be
the same length. Any arguments that are not iterable will be treated as
Expand Down Expand Up @@ -1083,15 +1084,14 @@ def map(
>>> from prefect import flow
>>> @flow
>>> def my_flow():
>>> my_task.map([1, 2, 3])
>>> return my_task.map([1, 2, 3])
Wait for all mapped tasks to finish
>>> @flow
>>> def my_flow():
>>> futures = my_task.map([1, 2, 3])
>>> for future in futures:
>>> future.wait()
>>> futures.wait():
>>> # Now all of the mapped tasks have finished
>>> my_task(10)
Expand All @@ -1100,8 +1100,8 @@ def map(
>>> @flow
>>> def my_flow():
>>> futures = my_task.map([1, 2, 3])
>>> for future in futures:
>>> print(future.result())
>>> for x in futures.result():
>>> print(x)
>>> my_flow()
2
3
Expand All @@ -1122,6 +1122,7 @@ def map(
>>>
>>> # task 2 will wait for task_1 to complete
>>> y = task_2.map([1, 2, 3], wait_for=[x])
>>> return y
Use a non-iterable input as a constant across mapped tasks
>>> @task
Expand All @@ -1130,7 +1131,7 @@ def map(
>>>
>>> @flow
>>> def my_flow():
>>> display.map("Check it out: ", [1, 2, 3])
>>> return display.map("Check it out: ", [1, 2, 3])
>>>
>>> my_flow()
Check it out: 1
Expand Down
Loading

0 comments on commit 883fcb0

Please sign in to comment.