diff --git a/src/elisa/plot/misc.py b/src/elisa/plot/misc.py index ee32259..e00f945 100644 --- a/src/elisa/plot/misc.py +++ b/src/elisa/plot/misc.py @@ -22,7 +22,8 @@ def plot_corner( bins: int | Sequence[int] = 40, hist_bin_factor: float | Sequence[float] = 1.5, params: str | Sequence[str] | None = None, - axes_scale: str | Sequence[str] = 'linear', + plot_range: Sequence[float] | None = None, + axes_scale: str | Sequence[str] = "linear", levels: float | Sequence[float] | None = None, titles: str | Sequence[str] | None = None, labels: str | Sequence[str] | None = None, @@ -46,6 +47,9 @@ def plot_corner( plots to provide more resolution. The default is 1.5. params : str or list of str, optional One or more parameters to be plotted. + plot_range : float, or list of float, optional + A list where each parameter is either a length 2 tuple containing + lower and upper bounds. axes_scale : str, or list of str, optional Scale to use for each parameter dimension. If only one scale is given, use that for all dimensions. Scale must be ``'linear'`` or ``'log'``. @@ -66,7 +70,7 @@ def plot_corner( The figure containing the corner plot. """ - posterior = idata['posterior'] + posterior = idata["posterior"] if params is None: params = list(posterior.data_vars.keys()) @@ -80,7 +84,7 @@ def plot_corner( all_params = posterior.data_vars.keys() not_found = set(params) - set(all_params) if not_found: - raise ValueError(f'parameter {not_found} not found in posterior') + raise ValueError(f"parameter {not_found} not found in posterior") if titles is None: titles = params @@ -103,7 +107,7 @@ def plot_corner( median = {p: float(median[p].values) for p in params} quantile = {p: quantile[p].values.tolist() for p in params} titles = [ - f'{t} = {report_interval(median[p], *quantile[p])}' + f"{t} = {report_interval(median[p], *quantile[p])}" for t, p in zip(titles, params) ] @@ -126,21 +130,20 @@ def plot_corner( # colors1 = [scale_color(to_hex(c), 0.95) for c in colors2] if color is None: - color = '#205295' + color = "#205295" else: color = str(color) - plt.rcParams['axes.formatter.min_exponent'] = 3 + plt.rcParams["axes.formatter.min_exponent"] = 3 c1, c2 = get_contour_colors(color, len(levels), 0.8, 2.0) - vmin = {p: posterior[p].values.min() for p in params} - vmax = {p: posterior[p].values.max() for p in params} - if any(vmin[p] == vmax[p] for p in params): - plot_range = [ - (vmin[p], vmax[p]) if vmin[p] != vmax[p] else 0.99 for p in params - ] - else: - plot_range = None + if plot_range is None: + vmin = {p: posterior[p].values.min() for p in params} + vmax = {p: posterior[p].values.max() for p in params} + if any(vmin[p] == vmax[p] for p in params): + plot_range = [ + (vmin[p], vmax[p]) if vmin[p] != vmax[p] else 0.99 for p in params + ] fig = corner.corner( idata, @@ -157,7 +160,7 @@ def plot_corner( use_math_text=True, labelpad=-0.08, divergences=divergences, - divergences_kwargs={'color': 'red', 'alpha': 0.3, 'ms': 1}, + divergences_kwargs={"color": "red", "alpha": 0.3, "ms": 1}, var_names=params, # kwargs for corner.hist2d levels=levels, @@ -166,14 +169,13 @@ def plot_corner( plot_contours=True, fill_contours=True, no_fill_contours=True, - contour_kwargs={'colors': c1}, - contourf_kwargs={'colors': ['white'] + c2, 'alpha': 0.75}, - data_kwargs={'color': c2[0], 'alpha': 0.75, 'ms': 1.5}, + contour_kwargs={"colors": c1}, + contourf_kwargs={"colors": ["white"] + c2, "alpha": 0.75}, + data_kwargs={"color": c2[0], "alpha": 0.75, "ms": 1.5}, ) return fig - def plot_trace( idata: az.InferenceData, params: str | Sequence[str] | None = None,