Skip to content

Commit

Permalink
Add fix for bad result references from the API (#14239)
Browse files: browse the repository at this point in the history
  • Loading branch information
cicdw authored Jun 21, 2024
1 parent 9ba12f1 commit 90b3fd9
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 10 deletions.
14 changes: 12 additions & 2 deletions src/prefect/client/schemas/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,10 @@ def result(self: "State[R]", raise_on_failure: bool = False) -> Union[R, Excepti
...

def result(
self, raise_on_failure: bool = True, fetch: Optional[bool] = None
self,
raise_on_failure: bool = True,
fetch: Optional[bool] = None,
retry_result_failure: bool = True,
) -> Union[R, Exception]:
"""
Retrieve the result attached to this state.
Expand All @@ -191,6 +194,8 @@ def result(
results into data. For synchronous users, this defaults to `True`.
For asynchronous users, this defaults to `False` for backwards
compatibility.
retry_result_failure: a boolean specifying whether to retry on failures to
load the result from result storage
Raises:
TypeError: If the state is failed but the result is not an exception.
Expand Down Expand Up @@ -253,7 +258,12 @@ def result(
"""
from prefect.states import get_state_result

return get_state_result(self, raise_on_failure=raise_on_failure, fetch=fetch)
return get_state_result(
self,
raise_on_failure=raise_on_failure,
fetch=fetch,
retry_result_failure=retry_result_failure,
)

def to_state_create(self):
"""
Expand Down
34 changes: 26 additions & 8 deletions src/prefect/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@


def get_state_result(
state: State[R], raise_on_failure: bool = True, fetch: Optional[bool] = None
state: State[R],
raise_on_failure: bool = True,
fetch: Optional[bool] = None,
retry_result_failure: bool = True,
) -> R:
"""
Get the result from a state.
Expand Down Expand Up @@ -62,37 +65,50 @@ def get_state_result(

return state.data
else:
return _get_state_result(state, raise_on_failure=raise_on_failure)
return _get_state_result(
state,
raise_on_failure=raise_on_failure,
retry_result_failure=retry_result_failure,
)


# Upper bound on the number of attempts made to read a result before the
# last exception is allowed to propagate.
RESULT_READ_MAXIMUM_ATTEMPTS = 10
# Pause, in seconds, between consecutive attempts to read a result.
RESULT_READ_RETRY_DELAY = 0.25


async def _get_state_result_data_with_retries(
    state: State[R], retry_result_failure: bool = True
) -> R:
    """
    Read the result referenced by `state.data`, optionally retrying on failure.

    Args:
        state: the state whose `data` attribute exposes an awaitable `get()`
        retry_result_failure: when `False`, make a single attempt and let any
            exception propagate immediately; otherwise make up to
            `RESULT_READ_MAXIMUM_ATTEMPTS` attempts, sleeping
            `RESULT_READ_RETRY_DELAY` seconds between them

    Returns:
        The value produced by `state.data.get()`.

    Raises:
        Exception: whatever the final attempt raised, re-raised unchanged.
    """
    # Results may be written asynchronously, possibly after their corresponding
    # state has been written and events have been emitted, so we should give some
    # grace here about missing results. The exception below could come in the form
    # of a missing file, a short read, or other types of errors depending on the
    # result storage backend.
    if retry_result_failure is False:
        max_attempts = 1
    else:
        max_attempts = RESULT_READ_MAXIMUM_ATTEMPTS

    for i in range(1, max_attempts + 1):
        try:
            return await state.data.get()
        except Exception as e:
            # Out of attempts: surface the original exception to the caller.
            if i == max_attempts:
                raise
            logger.debug(
                "Exception %r while reading result, retry %s/%s in %ss...",
                e,
                i,
                max_attempts,
                RESULT_READ_RETRY_DELAY,
            )
            await asyncio.sleep(RESULT_READ_RETRY_DELAY)


@sync_compatible
async def _get_state_result(state: State[R], raise_on_failure: bool) -> R:
async def _get_state_result(
state: State[R], raise_on_failure: bool, retry_result_failure: bool = True
) -> R:
"""
Internal implementation for `get_state_result` without async backwards compatibility
"""
Expand All @@ -111,7 +127,9 @@ async def _get_state_result(state: State[R], raise_on_failure: bool) -> R:
raise await get_state_exception(state)

if isinstance(state.data, BaseResult):
result = await _get_state_result_data_with_retries(state)
result = await _get_state_result_data_with_retries(
state, retry_result_failure=retry_result_failure
)

elif state.data is None:
if state.is_failed() or state.is_crashed() or state.is_cancelled():
Expand Down
10 changes: 10 additions & 0 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,16 @@ def begin_run(self):
new_state = Running()
state = self.set_state(new_state)

# TODO: this is temporary until the API stops rejecting state transitions
# and the client / transaction store becomes the source of truth
# this is a bandaid caused by the API storing a Completed state with a bad
# result reference that no longer exists
if state.is_completed():
try:
state.result(retry_result_failure=False, _sync=True)
except Exception:
state = self.set_state(new_state, force=True)

BACKOFF_MAX = 10
backoff_count = 0

Expand Down
28 changes: 28 additions & 0 deletions tests/test_task_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import os
import time
from datetime import timedelta
from pathlib import Path
Expand Down Expand Up @@ -1243,6 +1244,33 @@ def my_param_flow(x: int, other_val: str):
assert third_result not in [first_result, second_result]
assert fourth_result not in [first_result, second_result]

async def test_bad_api_result_references_cause_reruns(self, tmp_path: Path):
    """
    A task run whose stored result file has been deleted is executed again
    when rerun with the same task run, and a fresh result is written.
    """
    fs = LocalFileSystem(basepath=tmp_path)

    # Mutable so the second run can return a different value than the first.
    PAYLOAD = {"return": 42}

    @task(result_storage=fs, result_storage_key="tmp-first")
    async def first():
        return PAYLOAD["return"], get_run_context().task_run

    result, task_run = await run_task_async(first)

    assert result == 42
    assert await fs.read_path("tmp-first")

    # delete record
    path = fs._resolve_path("tmp-first")
    os.unlink(path)
    # The read itself is expected to raise, so no assertion on its value.
    with pytest.raises(ValueError, match="does not exist"):
        await fs.read_path("tmp-first")

    # rerun with same task run ID
    PAYLOAD["return"] = "bar"
    result, task_run = await run_task_async(first, task_run=task_run)

    # A new value and a re-created result file show the task actually re-executed.
    assert result == "bar"
    assert await fs.read_path("tmp-first")


class TestGenerators:
async def test_generator_task(self):
Expand Down

0 comments on commit 90b3fd9

Please sign in to comment.