diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index 4e5a417cbf8b..a4c948fb5bc7 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -185,6 +185,22 @@ def _resolve_parameters(self): self.parameters = resolved_parameters + def _set_custom_task_run_name(self): + from prefect.utilities.engine import _resolve_custom_task_run_name + + # update the task run name if necessary + if not self._task_name_set and self.task.task_run_name: + task_run_name = _resolve_custom_task_run_name( + task=self.task, parameters=self.parameters or {} + ) + + self.logger.extra["task_run_name"] = task_run_name + self.logger.debug( + f"Renamed task run {self.task_run.name!r} to {task_run_name!r}" + ) + self.task_run.name = task_run_name + self._task_name_set = True + def _wait_for_dependencies(self): if not self.wait_for: return @@ -349,6 +365,7 @@ def call_hooks(self, state: Optional[State] = None): def begin_run(self): try: self._resolve_parameters() + self._set_custom_task_run_name() self._wait_for_dependencies() except UpstreamTaskError as upstream_exc: state = self.set_state( @@ -578,7 +595,6 @@ def handle_crash(self, exc: BaseException) -> None: @contextmanager def setup_run_context(self, client: Optional[SyncPrefectClient] = None): from prefect.utilities.engine import ( - _resolve_custom_task_run_name, should_log_prints, ) @@ -610,18 +626,6 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None): self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore - # update the task run name if necessary - if not self._task_name_set and self.task.task_run_name: - task_run_name = _resolve_custom_task_run_name( - task=self.task, parameters=self.parameters - ) - - self.logger.extra["task_run_name"] = task_run_name - self.logger.debug( - f"Renamed task run {self.task_run.name!r} to {task_run_name!r}" - ) - self.task_run.name = task_run_name - self._task_name_set = True yield @contextmanager @@ -870,6 +874,7 @@ async def call_hooks(self, state: Optional[State] = None): async def begin_run(self): try: self._resolve_parameters() + self._set_custom_task_run_name() self._wait_for_dependencies() except UpstreamTaskError as upstream_exc: state = await self.set_state( @@ -1092,7 +1097,6 @@ async def handle_crash(self, exc: BaseException) -> None: @asynccontextmanager async def setup_run_context(self, client: Optional[PrefectClient] = None): from prefect.utilities.engine import ( - _resolve_custom_task_run_name, should_log_prints, ) @@ -1123,16 +1127,6 @@ async def setup_run_context(self, client: Optional[PrefectClient] = None): self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore - if not self._task_name_set and self.task.task_run_name: - task_run_name = _resolve_custom_task_run_name( - task=self.task, parameters=self.parameters - ) - self.logger.extra["task_run_name"] = task_run_name - self.logger.debug( - f"Renamed task run {self.task_run.name!r} to {task_run_name!r}" - ) - self.task_run.name = task_run_name - self._task_name_set = True yield @asynccontextmanager diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index a6e27f05637e..cc271ec226c2 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -310,7 +310,9 @@ def __init__( Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] ] = None, cache_expiration: Optional[datetime.timedelta] = None, - task_run_name: Optional[Union[Callable[[], str], str]] = None, + task_run_name: Optional[ + Union[Callable[[], str], Callable[[Dict[str, Any]], str], str] + ] = None, retries: Optional[int] = None, retry_delay_seconds: Optional[ Union[ @@ -531,7 +533,9 @@ def with_options( cache_key_fn: Optional[ Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] ] = None, - task_run_name: Optional[Union[Callable[[], str], str, Type[NotSet]]] = NotSet, + task_run_name: Optional[ + Union[Callable[[], str], Callable[[Dict[str, Any]], str], str, Type[NotSet]] + ] = NotSet, cache_expiration: Optional[datetime.timedelta] = None, retries: Union[int, Type[NotSet]] = NotSet, retry_delay_seconds: Union[ @@ -1583,7 +1587,9 @@ def task( Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] ] = None, cache_expiration: Optional[datetime.timedelta] = None, - task_run_name: Optional[Union[Callable[[], str], str]] = None, + task_run_name: Optional[ + Union[Callable[[], str], Callable[[Dict[str, Any]], str], str] + ] = None, retries: int = 0, retry_delay_seconds: Union[ float, @@ -1620,7 +1626,9 @@ def task( Callable[["TaskRunContext", Dict[str, Any]], Optional[str]], None ] = None, cache_expiration: Optional[datetime.timedelta] = None, - task_run_name: Optional[Union[Callable[[], str], str]] = None, + task_run_name: Optional[ + Union[Callable[[], str], Callable[[Dict[str, Any]], str], str] + ] = None, retries: Optional[int] = None, retry_delay_seconds: Union[ float, int, List[float], Callable[[int], List[float]], None diff --git a/src/prefect/utilities/engine.py b/src/prefect/utilities/engine.py index 0f13a44a7a74..fd94104d3773 100644 --- a/src/prefect/utilities/engine.py +++ b/src/prefect/utilities/engine.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import inspect import os import signal import time @@ -684,7 +685,15 @@ def _resolve_custom_flow_run_name(flow: Flow, parameters: Dict[str, Any]) -> str def _resolve_custom_task_run_name(task: Task, parameters: Dict[str, Any]) -> str: if callable(task.task_run_name): - task_run_name = task.task_run_name() + sig = inspect.signature(task.task_run_name) + + # If the callable accepts a 'parameters' kwarg, pass the entire parameters dict + if "parameters" in sig.parameters: + task_run_name = task.task_run_name(parameters=parameters) + else: + # If it doesn't expect parameters, call it without arguments + task_run_name = task.task_run_name() + if not isinstance(task_run_name, str): raise TypeError( f"Callable {task.task_run_name} for 'task_run_name' returned type" diff --git a/tests/public/flows/test_flow_with_mapped_tasks.py b/tests/public/flows/test_flow_with_mapped_tasks.py new file mode 100644 index 000000000000..064f52364cbb --- /dev/null +++ b/tests/public/flows/test_flow_with_mapped_tasks.py @@ -0,0 +1,66 @@ +"""This is a regression test for https://github.com/PrefectHQ/prefect/issues/15747""" + +from typing import Any + +from prefect import flow, task +from prefect.context import TaskRunContext +from prefect.runtime import task_run + +names = [] + + +def generate_task_run_name(parameters: dict) -> str: + names.append(f'{task_run.task_name} - input: {parameters["input"]["number"]}') + return names[-1] + + +def alternate_task_run_name() -> str: + names.append("wildcard!") + return names[-1] + + +@task(task_run_name="other {input[number]}") +def other_task(input: dict) -> dict: + names.append(TaskRunContext.get().task_run.name) + return input + + +@task(log_prints=True, task_run_name=generate_task_run_name) +def increment_number(input: dict) -> dict: + input["number"] += 1 + return input + + +@flow +def double_increment_flow() -> list[dict[str, Any]]: + inputs = [ + {"number": 1, "is_even": False}, + {"number": 2, "is_even": True}, + ] + + first_increment = increment_number.map(input=inputs) + second_increment = increment_number.with_options( + task_run_name=alternate_task_run_name + ).map(input=first_increment) + final_results = second_increment.result() + + other_task.map(inputs).wait() + + print(f"Final results: {final_results}") + return final_results + + +async def test_flow_with_mapped_tasks(): + results = double_increment_flow() + assert results == [ + {"number": 3, "is_even": False}, + {"number": 4, "is_even": True}, + ] + assert set(names) == { + "increment_number - input: 1", + "increment_number - input: 2", + "wildcard!", + "wildcard!", + "other 3", + "other 4", + }