Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Add diagnostic visualization tools #1631

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/beanmachine/ppl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from beanmachine.ppl.diagnostics.tools import viz
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need to add this to __all__ so that the linter won't complain that this is imported but not use :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, I need to setup pre-commit to point to setup.cfg as that's where the flake8 config exists. Do you know of a tool that will add copyright notices for Python?

from torch.distributions import Distribution

from . import experimental
Expand Down
Empty file.
75 changes: 75 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/accessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Accessor definition for extending Bean Machine `MonteCarloSamples` objects."""
from __future__ import annotations

import contextlib
import warnings
from typing import Callable, TypeVar

from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples


T = TypeVar("T", bound="CachedAccessor")


class CachedAccessor:
"""A descriptor for caching accessors.

Parameters
----------
name : str
Namespace that will be accessed under, e.g. ``samples.accessor_name``.
accessor : cls
Class with the extension methods.
"""

def __init__(self: T, name: str, accessor: object) -> None:
"""Initialize."""
self._name = name
self._accessor = accessor

def __get__(self: T, obj: object, cls: object) -> object:
"""Access the accessor object."""
if obj is None:
return self._accessor

try:
cache = obj._cache # type: ignore
except AttributeError:
cache = obj._cache = {}

try:
return cache[self._name]
except KeyError:
contextlib.suppress(KeyError)

try:
accessor_obj = self._accessor(obj) # type: ignore
except Exception as error:
msg = f"error initializing {self._name!r} accessor."
raise RuntimeError(msg) from error

cache[self._name] = accessor_obj
return accessor_obj # noqa: R504


def _register_accessor(name: str, cls: object) -> Callable:
"""Register the accessor to the object."""

def decorator(accessor: object) -> object:
if hasattr(cls, name):
warnings.warn(
f"registration of accessor {repr(accessor)} under name "
f"{repr(name)} for type {repr(cls)} is overriding a preexisting "
f"attribute with the same name.",
UserWarning,
stacklevel=2,
)
setattr(cls, name, CachedAccessor(name, accessor))
return accessor

return decorator


def register_mcs_accessor(name: str) -> Callable:
"""Register the accessor to object."""
return _register_accessor(name, MonteCarloSamples)
96 changes: 96 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/autocorrelation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Autocorrelation diagnostic tool for a Bean Machine model."""
from __future__ import annotations

from typing import Any, TypeVar

import arviz as az

import beanmachine.ppl.diagnostics.tools.helpers.autocorrelation as tool
from bokeh.models.callbacks import CustomJS
from bokeh.plotting import show

T = TypeVar("T", bound="Autocorrelation")


class Autocorrelation:
"""Autocorrelation diagnostic tool."""

def __init__(self: T, idata: az.InferenceData) -> None:
"""Initialize."""
self.idata = idata
self.rv_identifiers = list(self.idata["posterior"].data_vars)
self.rv_names = sorted(
[str(rv_identifier) for rv_identifier in self.rv_identifiers],
)
self.num_chains = self.idata["posterior"].dims["chain"]
self.num_draws = self.idata["posterior"].dims["draw"]

def modify_doc(self: T, doc: Any) -> None:
"""Modify the Jupyter document in order to display the tool."""
# Initialize the widgets.
rv_name = self.rv_names[0]
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)]

# Compute the initial data displayed in the tool.
rv_data = self.idata["posterior"][rv_identifier].values
computed_data = tool.compute_data(rv_data)

# Create the Bokeh source(s).
sources = tool.create_sources(computed_data)

# Create the figure(s).
figures = tool.create_figures(self.num_chains)

# Create the glyph(s) and attach them to the figure(s).
glyphs = tool.create_glyphs(self.num_chains)
tool.add_glyphs(figures, glyphs, sources)

# Create the annotation(s) and attache them to the figure(s).
annotations = tool.create_annotations(computed_data)
tool.add_annotations(figures, annotations)

# Create the tool tip(s) and attach them to the figure(s).
tooltips = tool.create_tooltips(figures)
tool.add_tooltips(figures, tooltips)

# Create the widget(s) for the tool.
widgets = tool.create_widgets(rv_name, self.rv_names, self.num_draws)

# Create the callback(s) for the widget(s).
def update_rv_select(attr: Any, old: str, new: str) -> None:
rv_name = new
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)]
rv_data = self.idata["posterior"][rv_identifier].values
tool.update(rv_data, sources)
end = 10 if self.num_draws <= 2 * 100 else 100
widgets["range_slider"].value = (0, end)

def update_range_slider(
attr: Any,
old: tuple[int, int],
new: tuple[int, int],
) -> None:
fig = figures[list(figures.keys())[0]]
fig.x_range.start, fig.x_range.end = new

widgets["rv_select"].on_change("value", update_rv_select)
# NOTE: We are using Bokeh's CustomJS model in order to reset the ranges of the
# figures.
widgets["rv_select"].js_on_change(
"value",
CustomJS(args={"p": list(figures.values())[0]}, code="p.reset.emit()"),
)
widgets["range_slider"].on_change("value", update_range_slider)

tool_view = tool.create_view(widgets, figures)
doc.add_root(tool_view)

def show_tool(self: T) -> None:
"""Show the diagnostic tool.

Returns
-------
None
Directly displays the tool in Jupyter.
"""
show(self.modify_doc)
87 changes: 87 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/effective_sample_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Effective Sample Size (ESS) diagnostic tool for a Bean Machine model."""
from __future__ import annotations

from typing import Any, TypeVar

import arviz as az

import beanmachine.ppl.diagnostics.tools.helpers.effective_sample_size as tool
from bokeh.core.enums import LegendClickPolicy
from bokeh.models.callbacks import CustomJS
from bokeh.plotting import show


T = TypeVar("T", bound="EffectiveSampleSize")


class EffectiveSampleSize:
"""Effective Sample Size (ESS) diagnostic tool."""

def __init__(self: T, idata: az.InferenceData) -> None:
"""Initialize."""
self.idata = idata
self.rv_identifiers = list(self.idata["posterior"].data_vars)
self.rv_names = sorted(
[str(rv_identifier) for rv_identifier in self.rv_identifiers],
)
self.num_chains = self.idata["posterior"].dims["chain"]

def modify_doc(self: T, doc: Any) -> None:
"""Modify the Jupyter document in order to display the tool."""
# Initialize the widgets.
rv_name = self.rv_names[0]
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)]

# Compute the initial data displayed in the tool.
rv_data = self.idata["posterior"][rv_identifier].values
computed_data = tool.compute_data(rv_data)

# Create the Bokeh source(s).
sources = tool.create_sources(computed_data)

# Create the figure(s).
figures = tool.create_figures()

# Create the glyph(s) and attach them to the figure(s).
glyphs = tool.create_glyphs()
tool.add_glyphs(figures, glyphs, sources)

# Create the annotation(s) and attache them to the figure(s).
annotations = tool.create_annotations(figures)
annotations["ess"]["legend"].click_policy = LegendClickPolicy.hide
tool.add_annotations(figures, annotations)

# Create the tool tip(s) and attach them to the figure(s).
tooltips = tool.create_tooltips(figures)
tool.add_tooltips(figures, tooltips)

# Create the widget(s) for the tool.
widgets = tool.create_widgets(rv_name, self.rv_names)

# Create the callback(s) for the widget(s).
def update_rv_select(attr: Any, old: str, new: str) -> None:
rv_name = new
rv_identifier = self.rv_identifiers[self.rv_names.index(rv_name)]
rv_data = self.idata["posterior"][rv_identifier].values
tool.update(rv_data, sources)

widgets["rv_select"].on_change("value", update_rv_select)
# NOTE: We are using Bokeh's CustomJS model in order to reset the ranges of the
# figures.
widgets["rv_select"].js_on_change(
"value",
CustomJS(args={"p": list(figures.values())[0]}, code="p.reset.emit()"),
)

tool_view = tool.create_view(widgets, figures)
doc.add_root(tool_view)

def show_tool(self: T) -> None:
"""Show the diagnostic tool.

Returns
-------
None
Directly displays the tool in Jupyter.
"""
show(self.modify_doc)
Empty file.
Loading