Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[backport 2.11] prevent long recursive stack traces #24

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions python/ray/air/_internal/util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import socket
from contextlib import closing
import copy
import logging
import os
import queue
import socket
import threading
from contextlib import closing
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -35,7 +36,13 @@ class StartTraceback(Exception):


def skip_exceptions(exc: Optional[Exception]) -> Exception:
"""Skip all contained `StartTracebacks` to reduce traceback output"""
"""Skip all contained `StartTracebacks` to reduce traceback output.
Returns a shallow copy of the exception with all `StartTracebacks` removed.
If the RAY_AIR_FULL_TRACEBACKS environment variable is set,
the original exception (not a copy) is returned.
"""
should_not_shorten = bool(int(os.environ.get("RAY_AIR_FULL_TRACEBACKS", "0")))

if should_not_shorten:
Expand All @@ -45,12 +52,15 @@ def skip_exceptions(exc: Optional[Exception]) -> Exception:
# If this is a StartTraceback, skip
return skip_exceptions(exc.__cause__)

# Else, make sure nested exceptions are properly skipped
# Perform a shallow copy to prevent recursive __cause__/__context__.
new_exc = copy.copy(exc).with_traceback(exc.__traceback__)

# Make sure nested exceptions are properly skipped.
cause = getattr(exc, "__cause__", None)
if cause:
exc.__cause__ = skip_exceptions(cause)
new_exc.__cause__ = skip_exceptions(cause)

return exc
return new_exc


def exception_cause(exc: Optional[Exception]) -> Optional[Exception]:
Expand Down
43 changes: 42 additions & 1 deletion python/ray/air/tests/test_tracebacks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest

import ray
from ray import cloudpickle
from tblib import pickling_support
from ray.train import ScalingConfig
from ray.air._internal.util import StartTraceback, skip_exceptions
from ray.air._internal.util import StartTraceback, skip_exceptions, exception_cause
from ray.train.data_parallel_trainer import DataParallelTrainer

from ray.tune import Tuner
Expand Down Expand Up @@ -47,6 +49,45 @@ def test_short_traceback(levels):
assert i == levels - start_traceback + 1


def test_recursion():
"""Test that the skipped exception does not point to the original exception."""
root_exception = None

with pytest.raises(StartTraceback) as exc_info:
try:
raise Exception("Root Exception")
except Exception as e:
root_exception = e
raise StartTraceback from root_exception

assert root_exception, "Root exception was not captured."

start_traceback = exc_info.value
skipped_exception = skip_exceptions(start_traceback)

assert (
root_exception != skipped_exception
), "Skipped exception points to the original exception."


def test_tblib():
"""Test that tblib does not cause a maximum recursion error."""

with pytest.raises(Exception) as exc_info:
try:
try:
raise Exception("Root Exception")
except Exception as root_exception:
raise StartTraceback from root_exception
except Exception as start_traceback:
raise skip_exceptions(start_traceback) from exception_cause(start_traceback)

pickling_support.install()
reraised_exception = exc_info.value
# This should not raise a RecursionError/PicklingError.
cloudpickle.dumps(reraised_exception)


def test_traceback_tuner(ray_start_2_cpus):
"""Ensure that the Tuner's stack trace is not too long."""

Expand Down
Loading