from itertools import zip_longest
import numpy as np
from astropy.visualization import simple_norm
from .priors import f_hyperprior_lira
# Public plotting API of this module (used by ``from ... import *`` and by the docs).
__all__ = [
    "plot_example_dataset",
    "plot_parameter_traces",
    "plot_parameter_distributions",
    "plot_pixel_trace",
    "plot_pixel_trace_neighbours",
    "plot_hyperpriors_lira",
]
def get_grid_figsize(width, ncols, nrows):
    """Compute a figure size for a grid of ``nrows`` x ``ncols`` panels.

    The height is scaled by ``nrows / ncols`` so that each grid cell is
    ``width / ncols`` wide and equally tall.

    Parameters
    ----------
    width : float
        Figure width
    ncols : int
        Number of columns
    nrows : int
        Number of rows

    Returns
    -------
    figsize : tuple of float
        ``(width, height)``
    """
    aspect = nrows / ncols
    return width, aspect * width
def plot_example_dataset(data, figsize=(12, 7), **kwargs):
    """Plot example dataset

    Every entry of ``data`` is shown as an image with its own colorbar. An
    optional ``"wcs"`` entry is not plotted; it is used as the projection of
    all panels instead. The grid has 3 columns and at least 2 rows. More rows
    are added when there are more than 6 entries, and unused panels are
    hidden.

    Parameters
    ----------
    data : dict of `~numpy.ndarray`
        Data
    figsize : tuple
        Figure size
    **kwargs : dict
        Keyword arguments passed to `~matplotlib.pyplot.imshow`

    Returns
    -------
    axes : `~numpy.ndarray` of `~matplotlib.pyplot.Axes`
        Plotting axes
    """
    import matplotlib.pyplot as plt

    # Shallow copy so that popping "wcs" does not modify the caller's dict
    data = data.copy()
    wcs = data.pop("wcs", None)

    ncols = 3
    # Keep the 2x3 layout by default. Add rows instead of running out of axes
    # (which previously failed with more than 6 entries).
    nrows = max(2, -(-len(data) // ncols))

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=figsize,
        subplot_kw={"projection": wcs},
    )

    for name, ax in zip_longest(data, axes.flat):
        if name is None:
            # More panels than images: hide the leftovers
            ax.set_visible(False)
            continue

        im = ax.imshow(data[name], origin="lower", **kwargs)
        ax.set_title(name.title())
        fig.colorbar(im, ax=ax)

    return axes
def plot_trace(ax, idx, trace, n_burn_in, **kwargs):
    """Plot a single parameter trace

    The first ``n_burn_in`` samples are drawn faded and labelled "Burn in".
    The remaining samples are labelled "Valid". They are summarised by their
    mean (horizontal line) and a band of one standard deviation around it.

    Parameters
    ----------
    ax : `~matplotlib.pyplot.Axes`
        Plot axes
    idx : `~numpy.ndarray`
        Iteration
    trace : `~numpy.ndarray`
        Trace to plot
    n_burn_in : int
        Number of burn in iterations
    **kwargs : dict
        Keyword arguments passed to `~matplotlib.pyplot.plot`
    """
    burn_in = slice(0, n_burn_in)
    # `None` keeps the final sample; a stop of -1 would silently drop it
    valid = slice(n_burn_in, None)

    ax.plot(idx[burn_in], trace[burn_in], alpha=0.3, label="Burn in", **kwargs)
    ax.plot(idx[valid], trace[valid], label="Valid", **kwargs)
    ax.set_xlabel("Number of Iterations")

    idx_valid, trace_valid = idx[valid], trace[valid]

    if len(trace_valid) == 0:
        # Every sample is burn in, so there are no statistics to summarise
        return

    mean = np.mean(trace_valid)
    std = np.std(trace_valid)

    # Span exactly the iterations of the valid part of the chain
    ax.hlines(
        mean,
        idx_valid[0],
        idx_valid[-1],
        color="tab:orange",
        zorder=10,
        label="Mean",
    )

    ax.fill_between(
        idx_valid,
        np.array([mean - std]),
        np.array([mean + std]),
        color="tab:orange",
        alpha=0.2,
        zorder=9,
        label=r"1 $\sigma$ Std. Deviation",
    )
def plot_parameter_traces(
    parameter_trace, config=None, figsize=None, ncols=3, **kwargs
):
    """Plot parameter traces

    Draws one panel per parameter column. The bookkeeping columns
    ``iteration``, ``stepSize``, ``cycleSpinRow`` and ``cycleSpinCol`` are
    skipped when present. Unused panels in the grid are hidden. A legend is
    added to the ``logPost`` panel.

    Parameters
    ----------
    parameter_trace : `~astropy.table.Table`
        Parameter trace table
    config : dict
        Config dictionary. Defaults to the table's ``meta``. Only the
        ``n_burn_in`` key is read; it defaults to 0.
    figsize : tuple of float
        Figure size
    ncols : int
        Number of columns to plot.
    **kwargs : dict
        Keyword arguments passed to `~matplotlib.pyplot.plot`

    Returns
    -------
    axes : `~numpy.ndarray` of `~matplotlib.pyplot.Axes`
        Plotting axes
    """
    import matplotlib.pyplot as plt

    table = parameter_trace.copy()
    # Drop bookkeeping columns that are not model parameters; tolerate absence
    bookkeeping = ["iteration", "stepSize", "cycleSpinRow", "cycleSpinCol"]
    table.remove_columns([name for name in bookkeeping if name in table.colnames])

    if config is None:
        config = table.meta

    kwargs.setdefault("color", "tab:blue")

    # Ceiling division: exactly enough rows for all columns, with no empty
    # trailing row when the count is a multiple of ncols
    nrows = max(1, -(-len(table.colnames) // ncols))

    if figsize is None:
        figsize = get_grid_figsize(width=16, ncols=ncols, nrows=nrows)

    _, axes = plt.subplots(
        ncols=ncols, nrows=nrows, figsize=figsize, gridspec_kw={"hspace": 0.25}
    )

    # A 1x1 grid returns a bare Axes; wrap it so `.flat` works below
    if not isinstance(axes, np.ndarray):
        axes = np.asanyarray([axes])

    n_burn_in = config.get("n_burn_in", 0)
    idx = np.arange(len(table))

    for name, ax in zip_longest(table.colnames, axes.flat):
        if name is None:
            ax.set_visible(False)
            continue

        plot_trace(ax=ax, trace=table[name], idx=idx, n_burn_in=n_burn_in, **kwargs)
        ax.set_title(name.title())

        if name == "logPost":
            ax.legend()

    return axes
def plot_parameter_distributions(
    parameter_trace, config=None, figsize=None, ncols=3, **kwargs
):
    """Plot parameter distributions

    Draws one histogram panel per parameter column. The bookkeeping columns
    ``iteration``, ``stepSize``, ``cycleSpinRow``, ``cycleSpinCol`` and
    ``logPost`` are skipped when present. Samples after the burn-in are
    labelled "Valid" and summarised by their mean and a band of one standard
    deviation. Burn-in samples are shown faded when there are any. Non-finite
    values are ignored. A legend is added to the first panel only.

    Parameters
    ----------
    parameter_trace : `~astropy.table.Table`
        Parameter trace table
    config : dict
        Config dictionary. Defaults to the table's ``meta``. Only the
        ``n_burn_in`` key is read; it defaults to 0.
    figsize : tuple of float
        Figure size
    ncols : int
        Number of columns to plot.
    **kwargs : dict
        Keyword arguments passed to `~matplotlib.pyplot.hist`

    Returns
    -------
    axes : `~numpy.ndarray` of `~matplotlib.pyplot.Axes`
        Plotting axes
    """
    import matplotlib.pyplot as plt

    table = parameter_trace.copy()
    bookkeeping = ["iteration", "stepSize", "cycleSpinRow", "cycleSpinCol", "logPost"]
    table.remove_columns([name for name in bookkeeping if name in table.colnames])

    if config is None:
        config = table.meta

    n_burn_in = config.get("n_burn_in", 0)

    # Ceiling division: no empty trailing row when the count divides ncols
    nrows = max(1, -(-len(table.colnames) // ncols))

    if figsize is None:
        figsize = get_grid_figsize(width=16, ncols=ncols, nrows=nrows)

    _, axes = plt.subplots(
        ncols=ncols, nrows=nrows, figsize=figsize, gridspec_kw={"hspace": 0.25}
    )

    # A 1x1 grid returns a bare Axes; wrap it so `.flat` works below
    if not isinstance(axes, np.ndarray):
        axes = np.asanyarray([axes])

    kwargs.setdefault("color", "tab:blue")
    kwargs.setdefault("density", True)
    # At least one bin, so that an empty table does not request bins=0
    kwargs.setdefault("bins", max(1, int(np.sqrt(len(table)))))

    has_legend = False

    for name, ax in zip_longest(table.colnames, axes.flat):
        if name is None:
            ax.set_visible(False)
            continue

        column = table[name][n_burn_in:]
        column = column[np.isfinite(column)]
        n_vals, _, _ = ax.hist(column, label="Valid", **kwargs)
        heights = [n_vals]

        column_burn_in = table[name][:n_burn_in]
        column_burn_in = column_burn_in[np.isfinite(column_burn_in)]

        # Only histogram a non-empty burn-in: a density histogram of no data
        # gives nan counts, which would poison y_max below
        if len(column_burn_in) > 0:
            n_vals_burn_in, _, _ = ax.hist(
                column_burn_in, alpha=0.3, label="Burn in", **kwargs
            )
            heights.append(n_vals_burn_in)

        ax.set_title(name.title())
        # The histogram's x axis holds parameter values, not iterations
        ax.set_xlabel("Value")

        if len(column) > 0:
            y_max = max(np.max(height) for height in heights)
            mean = np.mean(column)
            std = np.std(column)

            ax.vlines(mean, 0, y_max, color="tab:orange", zorder=10, label="Mean")
            ax.fill_betweenx(
                np.linspace(0, y_max, 10),
                np.array([mean - std]),
                np.array([mean + std]),
                color="tab:orange",
                alpha=0.2,
                zorder=9,
                label=r"1 $\sigma$ Std. Deviation",
            )

        if not has_legend:
            ax.legend()
            has_legend = True

    return axes
def plot_pixel_trace(image_trace, center_pix, ax=None, config=None, **kwargs):
    """Plot the trace of a single pixel over the iterations.

    Parameters
    ----------
    image_trace : `~numpy.ndarray`
        Image traces array, unpacked as ``(n_iter, n_y, n_x)``.
    center_pix : tuple of int
        Pixel indices, order is (x, y). Any length-2 sequence is accepted.
    ax : `~matplotlib.pyplot.Axes`
        Plotting axes. Defaults to the current axes.
    config : dict
        Configuration dictionary. Only the ``n_burn_in`` key is read; it
        defaults to 0.
    **kwargs : dict
        Keyword arguments passed to `~matplotlib.pyplot.plot`

    Returns
    -------
    ax : `~matplotlib.pyplot.Axes`
        Plotting axes
    """
    import matplotlib.pyplot as plt

    if config is None:
        config = {}

    if ax is None:
        ax = plt.gca()

    # Unpacking enforces a 3D array
    n_iter, _, _ = image_trace.shape
    n_burn_in = config.get("n_burn_in", 0)
    idx = np.arange(n_iter)

    # center_pix is (x, y) while the array is indexed [..., y, x], hence the
    # reversal. tuple() lets lists and arrays be passed as well.
    trace = image_trace[(Ellipsis,) + tuple(center_pix)[::-1]]

    kwargs.setdefault("color", "tab:blue")
    plot_trace(ax=ax, trace=trace, idx=idx, n_burn_in=n_burn_in, **kwargs)

    ax.set_title(f"Pixel trace for {center_pix}")
    ax.set_xlabel("Number of Iterations")
    ax.legend()
    return ax
def plot_pixel_trace_neighbours(
    image_trace, center_pix, radius_pix=1, cmap="Greys", ax=None, **kwargs
):
    """Plot the traces of the pixels surrounding a given pixel.

    Pixels are included when their distance to ``center_pix`` is larger than
    zero and at most ``radius_pix``, so the center pixel itself is not
    plotted. The distance to the center is encoded in the line color: a
    linear normalization of the selected distances is passed through the
    colormap.

    Parameters
    ----------
    image_trace : `~numpy.ndarray`
        Image traces array, unpacked as ``(n_iter, n_y, n_x)``.
    center_pix : tuple of int
        Pixel indices of the center, order is (x, y).
    radius_pix : float
        Maximum distance (inclusive, in pixels) of the plotted traces.
    cmap : str or `~matplotlib.colors.Colormap`
        Colormap to plot the traces with.
    ax : `~matplotlib.pyplot.Axes`
        Plotting axes. Defaults to the current axes.
    **kwargs : dict
        Keyword arguments forwarded to `~matplotlib.pyplot.plot`

    Returns
    -------
    ax : `~matplotlib.pyplot.Axes`
        Plotting axes
    """
    import matplotlib.pyplot as plt

    if ax is None:
        ax = plt.gca()

    _, ny, nx = image_trace.shape
    x_center, y_center = center_pix[0], center_pix[1]

    # Column vector of rows and row vector of columns broadcast to (ny, nx)
    rows, cols = np.arange(ny).reshape((-1, 1)), np.arange(nx)
    offset_pix = np.sqrt((rows - y_center) ** 2 + (cols - x_center) ** 2)

    # Inclusive radius: the direct neighbours are exactly 1 pixel away, so an
    # exclusive `<` selected nothing for the default radius_pix=1
    iy, ix = np.nonzero((offset_pix <= radius_pix) & (offset_pix > 0))

    if iy.size == 0:
        # Nothing to plot; simple_norm would fail on empty data
        return ax

    offsets = offset_pix[iy, ix]

    # pyplot.get_cmap works across versions (matplotlib.cm.get_cmap was
    # removed in matplotlib 3.9)
    colormap = plt.get_cmap(cmap)
    norm = simple_norm(data=offsets)

    kwargs.setdefault("zorder", 0)
    kwargs.setdefault("alpha", 0.5)

    for row, col, offset in zip(iy, ix, offsets):
        trace = image_trace[:, row, col]
        # norm() of a scalar returns a 1-element array, so the colormap
        # yields a (1, 4) RGBA array; take its single row
        color = tuple(colormap(norm(offset))[0])
        ax.plot(trace, color=color, **kwargs)

    return ax
def plot_hyperpriors_lira(
    figsize=(12, 4),
    ncols=2,
    alphas=None,
    ms_al_kap1=0,
    ms_al_kap2=1000,
    ms_al_kap3=3,
    **kwargs,
):
    """Plot hyperprior distributions

    The three multiscale prior parameters are broadcast against each other.
    One panel is drawn per resulting parameter set, evaluated with
    ``f_hyperprior_lira`` over ``alphas``. Panels in the grid beyond the
    number of parameter sets are hidden.

    Parameters
    ----------
    figsize : tuple of float
        Figure size
    ncols : int
        Number of columns
    alphas : `~numpy.ndarray`
        Alpha values. Defaults to 100 values evenly spaced in [0, 3].
    ms_al_kap1 : float or `~numpy.ndarray`
        Multiscale prior parameter.
    ms_al_kap2 : float or `~numpy.ndarray`
        Multiscale prior parameter.
    ms_al_kap3 : float or `~numpy.ndarray`
        Multiscale prior parameter.
    **kwargs : dict
        Keyword arguments passed to `~matplotlib.pyplot.plot`

    Returns
    -------
    axes : `~numpy.ndarray` of `~matplotlib.pyplot.Axes`
        Plotting axes
    """
    import matplotlib.pyplot as plt

    # Broadcast so that scalar and array parameters can be mixed freely;
    # incompatible lengths raise a ValueError here instead of failing later
    kap1, kap2, kap3 = (
        np.ravel(arr)
        for arr in np.broadcast_arrays(
            np.atleast_1d(ms_al_kap1),
            np.atleast_1d(ms_al_kap2),
            np.atleast_1d(ms_al_kap3),
        )
    )

    if alphas is None:
        alphas = np.linspace(0, 3, 100)

    n_panels = kap1.size
    nrows = 1 + (n_panels - 1) // ncols
    _, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)

    # A 1x1 grid returns a bare Axes; wrap it so `.flat` works below
    if not isinstance(axes, np.ndarray):
        axes = np.asanyarray([axes])

    for idx, ax in enumerate(axes.flat):
        if idx >= n_panels:
            # The grid can have more panels than parameter sets (e.g. the
            # default single set on a 1x2 grid); indexing would fail
            ax.set_visible(False)
            continue

        values = f_hyperprior_lira(
            alpha=alphas,
            ms_al_kap1=kap1[idx],
            ms_al_kap2=kap2[idx],
            ms_al_kap3=kap3[idx],
        )
        ax.plot(alphas, values, **kwargs)
        ax.set_xlabel("Alpha")
        ax.set_ylabel("PDF / A.U.")

    return axes