Skip to content

Commit

Permalink
PyContext Enter/Exit Callbacks
Browse files Browse the repository at this point in the history
Summary: Add new context event watcher callback system

Differential Revision: D57583773
  • Loading branch information
fried authored and facebook-github-bot committed May 21, 2024
1 parent 7f4c505 commit 603ec6e
Show file tree
Hide file tree
Showing 8 changed files with 404 additions and 0 deletions.
44 changes: 44 additions & 0 deletions Doc/c-api/contextvars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,50 @@ Context object management functions:
current context for the current thread. Returns ``0`` on success,
and ``-1`` on error.
.. c:function:: int PyContext_AddWatcher(PyCode_WatchCallback callback)
Register *callback* as a context object watcher for the current interpreter.
Return an ID which may be passed to :c:func:`PyContext_ClearWatcher`.
In case of error (e.g. no more watcher IDs available),
return ``-1`` and set an exception.
.. versionadded:: 3.14
.. c:function:: int PyContext_ClearWatcher(int watcher_id)
Clear watcher identified by *watcher_id* previously returned from
:c:func:`PyContext_AddWatcher` for the current interpreter.
Return ``0`` on success, or ``-1`` and set an exception on error
(e.g. if the given *watcher_id* was never registered.)
.. versionadded:: 3.14
.. c:type:: PyContextEvent
Enumeration of possible context object watcher events:
- ``PY_CONTEXT_EVENT_ENTER``
- ``PY_CONTEXT_EVENT_EXIT``
.. versionadded:: 3.14
.. c:type:: int (*PyContecxt_WatchCallback)(PyContextEvent event, PyContext* ctx)
Type of a context object watcher callback function.
If *event* is ``PY_CONTEXT_EVENT_ENTER``, then the callback is invoked
after `ctx` has been set as the current context for the current thread.
Otherwise, the callback is invoked before the deactivation of *ctx* as the current context
and the restoration of the previous contex object for the current thread.
If the callback returns with an exception set, it must return ``-1``; this
exception will be printed as an unraisable exception using
:c:func:`PyErr_FormatUnraisable`. Otherwise it should return ``0``.
There may already be a pending exception set on entry to the callback. In
this case, the callback should return ``0`` with the same exception still
set. This means the callback may not call any other API that can set an
exception unless it saves and clears the exception state first, and restores
it before returning.
.. versionadded:: 3.14
Context variable functions:
Expand Down
37 changes: 37 additions & 0 deletions Include/cpython/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,43 @@ PyAPI_FUNC(PyObject *) PyContext_CopyCurrent(void);
PyAPI_FUNC(int) PyContext_Enter(PyObject *);
PyAPI_FUNC(int) PyContext_Exit(PyObject *);

#define PY_FOREACH_CONTEXT_EVENT(V) \
V(ENTER) \
V(EXIT)

typedef enum {
#define PY_DEF_EVENT(op) PY_CONTEXT_EVENT_##op,
PY_FOREACH_CONTEXT_EVENT(PY_DEF_EVENT)
#undef PY_DEF_EVENT
} PyContextEvent;

/*
* A Callback to clue in non-python contexts impls about a
* change in the active python context.
*
* The callback is invoked with the event and a reference to =
* the context after its entered and before its exited.
*
* if the callback returns with an exception set, it must return -1. Otherwise
* it should return 0
*/
typedef int (*PyContext_WatchCallback)(PyContextEvent, PyContext *);

/*
* Register a per-interpreter callback that will be invoked for context object
* enter/exit events.
*
* Returns a handle that may be passed to PyContext_ClearWatcher on success,
* or -1 and sets and error if no more handles are available.
*/
PyAPI_FUNC(int) PyContext_AddWatcher(PyContext_WatchCallback callback);

/*
* Clear the watcher associated with the watcher_id handle.
*
* Returns 0 on success or -1 if no watcher exists for the provided id.
*/
PyAPI_FUNC(int) PyContext_ClearWatcher(int watcher_id);

/* Create a new context variable.
Expand Down
1 change: 1 addition & 0 deletions Include/internal/pycore_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "pycore_hamt.h" /* PyHamtObject */

#define CONTEXT_MAX_WATCHERS 8

extern PyTypeObject _PyContextTokenMissing_Type;

Expand Down
2 changes: 2 additions & 0 deletions Include/internal/pycore_interp.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,10 @@ struct _is {
PyObject *audit_hooks;
PyType_WatchCallback type_watchers[TYPE_MAX_WATCHERS];
PyCode_WatchCallback code_watchers[CODE_MAX_WATCHERS];
PyContext_WatchCallback context_watchers[CONTEXT_MAX_WATCHERS];
// One bit is set for each non-NULL entry in code_watchers
uint8_t active_code_watchers;
uint8_t active_context_watchers;

struct _py_object_state object_state;
struct _Py_unicode_state unicode;
Expand Down
85 changes: 85 additions & 0 deletions Lib/test/test_capi/test_watchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from contextlib import contextmanager, ExitStack
from test.support import catch_unraisable_exception, import_helper
import contextvars


# Skip this test if the _testcapi module isn't available.
Expand Down Expand Up @@ -555,5 +556,89 @@ def test_allocate_too_many_watchers(self):
_testcapi.allocate_too_many_func_watchers()


class TestContextObjectWatchers(unittest.TestCase):
@contextmanager
def context_watcher(self, which_watcher):
wid = _testcapi.add_context_watcher(which_watcher)
try:
yield wid
finally:
_testcapi.clear_context_watcher(wid)

def assert_event_counts(self, exp_enter_0, exp_exit_0,
exp_enter_1, exp_exit_1):
self.assertEqual(
exp_enter_0, _testcapi.get_context_watcher_num_enter_events(0))
self.assertEqual(
exp_exit_0, _testcapi.get_context_watcher_num_exit_events(0))
self.assertEqual(
exp_enter_1, _testcapi.get_context_watcher_num_enter_events(1))
self.assertEqual(
exp_exit_1, _testcapi.get_context_watcher_num_exit_events(1))

def test_context_object_events_dispatched(self):
# verify that all counts are zero before any watchers are registered
self.assert_event_counts(0, 0, 0, 0)

# verify that all counts remain zero when a context object is
# entered and exited with no watchers registered
ctx = contextvars.copy_context()
ctx.run(self.assert_event_counts, 0, 0, 0, 0)
self.assert_event_counts(0, 0, 0, 0)

# verify counts are as expected when first watcher is registered
with self.context_watcher(0):
self.assert_event_counts(0, 0, 0, 0)
ctx.run(self.assert_event_counts, 1, 0, 0, 0)
self.assert_event_counts(1, 1, 0, 0)

# again with second watcher registered
with self.context_watcher(1):
self.assert_event_counts(1, 1, 0, 0)
ctx.run(self.assert_event_counts, 2, 1, 1, 0)
self.assert_event_counts(2, 2, 1, 1)

# verify counts are reset and don't change after both watchers are cleared
ctx.run(self.assert_event_counts, 0, 0, 0, 0)
self.assert_event_counts(0, 0, 0, 0)

def test_enter_error(self):
with self.context_watcher(2):
with catch_unraisable_exception() as cm:
ctx = contextvars.copy_context()
ctx.run(int, 0)
self.assertEqual(
cm.unraisable.object,
ctx
# For main branch
# f"PY_CONTEXT_EVENT_ENTER watcher callback for {ctx!r}"
#
)
self.assertEqual(str(cm.unraisable.exc_value), "boom!")

def test_exit_error(self):
ctx = contextvars.copy_context()
def _in_context(stack):
stack.enter_context(self.context_watcher(2))

with catch_unraisable_exception() as cm:
with ExitStack() as stack:
ctx.run(_in_context, stack)
self.assertEqual(str(cm.unraisable.exc_value), "boom!")

def test_clear_out_of_range_watcher_id(self):
with self.assertRaisesRegex(ValueError, r"Invalid context watcher ID -1"):
_testcapi.clear_context_watcher(-1)
with self.assertRaisesRegex(ValueError, r"Invalid context watcher ID 8"):
_testcapi.clear_context_watcher(8) # CONTEXT_MAX_WATCHERS = 8

def test_clear_unassigned_watcher_id(self):
with self.assertRaisesRegex(ValueError, r"No context watcher set for ID 1"):
_testcapi.clear_context_watcher(1)

def test_allocate_too_many_watchers(self):
with self.assertRaisesRegex(RuntimeError, r"no more context watcher IDs available"):
_testcapi.allocate_too_many_context_watchers()

if __name__ == "__main__":
unittest.main()
152 changes: 152 additions & 0 deletions Modules/_testcapi/watchers.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#define Py_BUILD_CORE
#include "pycore_function.h" // FUNC_MAX_WATCHERS
#include "pycore_code.h" // CODE_MAX_WATCHERS
#include "pycore_context.h" // CONTEXT_MAX_WATCHERS

/*[clinic input]
module _testcapi
Expand Down Expand Up @@ -625,6 +626,147 @@ allocate_too_many_func_watchers(PyObject *self, PyObject *args)
Py_RETURN_NONE;
}

// Test contexct object watchers
#define NUM_CONTEXT_WATCHERS 2
static context_watcher_ids[NUM_CONTEXT_WATCHERS] = {-1, -1};
static int num_context_object_enter_events[NUM_CONTEXT_WATCHERS] = {0, 0};
static int num_context_object_exit_events[NUM_CONTEXT_WATCHERS] = {0, 0};

static int
handle_context_watcher_event(int which_watcher, PyContextEvent event, PyContext *ctx) {
if (event == PY_CONTEXT_EVENT_ENTER) {
num_context_object_enter_events[which_watcher]++;
}
else if (event == PY_CONTEXT_EVENT_EXIT) {
num_context_object_exit_events[which_watcher]++;
}
else {
return -1;
}
return 0;
}

static int
first_context_watcher_callback(PyContextEvent event, PyContext *ctx) {
return handle_context_watcher_event(0, event, ctx);
}

static int
second_context_watcher_callback(PyContextEvent event, PyContext *ctx) {
return handle_context_watcher_event(1, event, ctx);
}

static int
noop_context_event_handler(PyContextEvent event, PyContext *ctx) {
return 0;
}

static int
error_context_event_handler(PyContextEvent event, PyContext *ctx) {
PyErr_SetString(PyExc_RuntimeError, "boom!");
return -1;
}

static PyObject *
add_context_watcher(PyObject *self, PyObject *which_watcher)
{
int watcher_id;
assert(PyLong_Check(which_watcher));
long which_l = PyLong_AsLong(which_watcher);
if (which_l == 0) {
watcher_id = PyContext_AddWatcher(first_context_watcher_callback);
context_watcher_ids[0] = watcher_id;
num_context_object_enter_events[0] = 0;
num_context_object_exit_events[0] = 0;
}
else if (which_l == 1) {
watcher_id = PyContext_AddWatcher(second_context_watcher_callback);
context_watcher_ids[1] = watcher_id;
num_context_object_enter_events[1] = 0;
num_context_object_exit_events[1] = 0;
}
else if (which_l == 2) {
watcher_id = PyContext_AddWatcher(error_context_event_handler);
}
else {
PyErr_Format(PyExc_ValueError, "invalid watcher %d", which_l);
return NULL;
}
if (watcher_id < 0) {
return NULL;
}
return PyLong_FromLong(watcher_id);
}

static PyObject *
clear_context_watcher(PyObject *self, PyObject *watcher_id)
{
assert(PyLong_Check(watcher_id));
long watcher_id_l = PyLong_AsLong(watcher_id);
if (PyContext_ClearWatcher(watcher_id_l) < 0) {
return NULL;
}
// reset static events counters
if (watcher_id_l >= 0) {
for (int i = 0; i < NUM_CONTEXT_WATCHERS; i++) {
if (watcher_id_l == context_watcher_ids[i]) {
context_watcher_ids[i] = -1;
num_context_object_enter_events[i] = 0;
num_context_object_exit_events[i] = 0;
}
}
}
Py_RETURN_NONE;
}

static PyObject *
get_context_watcher_num_enter_events(PyObject *self, PyObject *watcher_id)
{
assert(PyLong_Check(watcher_id));
long watcher_id_l = PyLong_AsLong(watcher_id);
assert(watcher_id_l >= 0 && watcher_id_l < NUM_CONTEXT_WATCHERS);
return PyLong_FromLong(num_context_object_enter_events[watcher_id_l]);
}

static PyObject *
get_context_watcher_num_exit_events(PyObject *self, PyObject *watcher_id)
{
assert(PyLong_Check(watcher_id));
long watcher_id_l = PyLong_AsLong(watcher_id);
assert(watcher_id_l >= 0 && watcher_id_l < NUM_CONTEXT_WATCHERS);
return PyLong_FromLong(num_context_object_exit_events[watcher_id_l]);
}

static PyObject *
allocate_too_many_context_watchers(PyObject *self, PyObject *args)
{
int watcher_ids[CONTEXT_MAX_WATCHERS + 1];
int num_watchers = 0;
for (unsigned long i = 0; i < sizeof(watcher_ids) / sizeof(int); i++) {
int watcher_id = PyContext_AddWatcher(noop_context_event_handler);
if (watcher_id == -1) {
break;
}
watcher_ids[i] = watcher_id;
num_watchers++;
}
PyObject *exc = PyErr_GetRaisedException();
for (int i = 0; i < num_watchers; i++) {
if (PyContext_ClearWatcher(watcher_ids[i]) < 0) {
PyErr_WriteUnraisable(Py_None);
break;
}
}
if (exc) {
PyErr_SetRaisedException(exc);
return NULL;
}
else if (PyErr_Occurred()) {
return NULL;
}
Py_RETURN_NONE;
}

/*[clinic input]
_testcapi.set_func_defaults_via_capi
func: object
Expand Down Expand Up @@ -692,6 +834,16 @@ static PyMethodDef test_methods[] = {
_TESTCAPI_SET_FUNC_KWDEFAULTS_VIA_CAPI_METHODDEF
{"allocate_too_many_func_watchers", allocate_too_many_func_watchers,
METH_NOARGS, NULL},

// Code object watchers.
{"add_context_watcher", add_context_watcher, METH_O, NULL},
{"clear_context_watcher", clear_context_watcher, METH_O, NULL},
{"get_context_watcher_num_enter_events",
get_context_watcher_num_enter_events, METH_O, NULL},
{"get_context_watcher_num_exit_events",
get_context_watcher_num_exit_events, METH_O, NULL},
{"allocate_too_many_context_watchers",
(PyCFunction) allocate_too_many_context_watchers, METH_NOARGS, NULL},
{NULL},
};

Expand Down
Loading

0 comments on commit 603ec6e

Please sign in to comment.