Skip to content

Commit

Permalink
add plot_range setting to elisa.plot.misc.plot_corner
Browse files Browse the repository at this point in the history
  • Loading branch information
xiesl97 authored Sep 20, 2024
1 parent 3fd11dd commit 5a00141
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions src/elisa/plot/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def plot_corner(
bins: int | Sequence[int] = 40,
hist_bin_factor: float | Sequence[float] = 1.5,
params: str | Sequence[str] | None = None,
axes_scale: str | Sequence[str] = 'linear',
plot_range: Sequence[float] | None = None,
axes_scale: str | Sequence[str] = "linear",
levels: float | Sequence[float] | None = None,
titles: str | Sequence[str] | None = None,
labels: str | Sequence[str] | None = None,
Expand All @@ -46,6 +47,9 @@ def plot_corner(
plots to provide more resolution. The default is 1.5.
params : str or list of str, optional
One or more parameters to be plotted.
plot_range : float, or list of float, optional
A list where each parameter is either a length 2 tuple containing
lower and upper bounds.
axes_scale : str, or list of str, optional
Scale to use for each parameter dimension. If only one scale is given,
use that for all dimensions. Scale must be ``'linear'`` or ``'log'``.
Expand All @@ -66,7 +70,7 @@ def plot_corner(
The figure containing the corner plot.
"""
posterior = idata['posterior']
posterior = idata["posterior"]

if params is None:
params = list(posterior.data_vars.keys())
Expand All @@ -80,7 +84,7 @@ def plot_corner(
all_params = posterior.data_vars.keys()
not_found = set(params) - set(all_params)
if not_found:
raise ValueError(f'parameter {not_found} not found in posterior')
raise ValueError(f"parameter {not_found} not found in posterior")

if titles is None:
titles = params
Expand All @@ -103,7 +107,7 @@ def plot_corner(
median = {p: float(median[p].values) for p in params}
quantile = {p: quantile[p].values.tolist() for p in params}
titles = [
f'{t} = {report_interval(median[p], *quantile[p])}'
f"{t} = {report_interval(median[p], *quantile[p])}"
for t, p in zip(titles, params)
]

Expand All @@ -126,21 +130,20 @@ def plot_corner(
# colors1 = [scale_color(to_hex(c), 0.95) for c in colors2]

if color is None:
color = '#205295'
color = "#205295"
else:
color = str(color)

plt.rcParams['axes.formatter.min_exponent'] = 3
plt.rcParams["axes.formatter.min_exponent"] = 3
c1, c2 = get_contour_colors(color, len(levels), 0.8, 2.0)

vmin = {p: posterior[p].values.min() for p in params}
vmax = {p: posterior[p].values.max() for p in params}
if any(vmin[p] == vmax[p] for p in params):
plot_range = [
(vmin[p], vmax[p]) if vmin[p] != vmax[p] else 0.99 for p in params
]
else:
plot_range = None
if plot_range is None:
vmin = {p: posterior[p].values.min() for p in params}
vmax = {p: posterior[p].values.max() for p in params}
if any(vmin[p] == vmax[p] for p in params):
plot_range = [
(vmin[p], vmax[p]) if vmin[p] != vmax[p] else 0.99 for p in params
]

fig = corner.corner(
idata,
Expand All @@ -157,7 +160,7 @@ def plot_corner(
use_math_text=True,
labelpad=-0.08,
divergences=divergences,
divergences_kwargs={'color': 'red', 'alpha': 0.3, 'ms': 1},
divergences_kwargs={"color": "red", "alpha": 0.3, "ms": 1},
var_names=params,
# kwargs for corner.hist2d
levels=levels,
Expand All @@ -166,14 +169,13 @@ def plot_corner(
plot_contours=True,
fill_contours=True,
no_fill_contours=True,
contour_kwargs={'colors': c1},
contourf_kwargs={'colors': ['white'] + c2, 'alpha': 0.75},
data_kwargs={'color': c2[0], 'alpha': 0.75, 'ms': 1.5},
contour_kwargs={"colors": c1},
contourf_kwargs={"colors": ["white"] + c2, "alpha": 0.75},
data_kwargs={"color": c2[0], "alpha": 0.75, "ms": 1.5},
)

return fig


def plot_trace(
idata: az.InferenceData,
params: str | Sequence[str] | None = None,
Expand Down

0 comments on commit 5a00141

Please sign in to comment.